feat: support to customize tokenizer
Signed-off-by: BubbleCal <[email protected]>
BubbleCal committed Oct 10, 2024
1 parent fe5fcf3 commit 6ec08a7
Showing 9 changed files with 268 additions and 39 deletions.
36 changes: 18 additions & 18 deletions Cargo.toml
@@ -21,7 +21,7 @@ exclude = ["python"]
resolver = "2"

[workspace.package]
version = "0.18.3"
version = "0.18.4"
edition = "2021"
authors = ["Lance Devs <[email protected]>"]
license = "Apache-2.0"
@@ -44,21 +44,21 @@ categories = [
rust-version = "1.78"

[workspace.dependencies]
lance = { version = "=0.18.3", path = "./rust/lance" }
lance-arrow = { version = "=0.18.3", path = "./rust/lance-arrow" }
lance-core = { version = "=0.18.3", path = "./rust/lance-core" }
lance-datafusion = { version = "=0.18.3", path = "./rust/lance-datafusion" }
lance-datagen = { version = "=0.18.3", path = "./rust/lance-datagen" }
lance-encoding = { version = "=0.18.3", path = "./rust/lance-encoding" }
lance-encoding-datafusion = { version = "=0.18.3", path = "./rust/lance-encoding-datafusion" }
lance-file = { version = "=0.18.3", path = "./rust/lance-file" }
lance-index = { version = "=0.18.3", path = "./rust/lance-index" }
lance-io = { version = "=0.18.3", path = "./rust/lance-io" }
lance-jni = { version = "=0.18.3", path = "./java/core/lance-jni" }
lance-linalg = { version = "=0.18.3", path = "./rust/lance-linalg" }
lance-table = { version = "=0.18.3", path = "./rust/lance-table" }
lance-test-macros = { version = "=0.18.3", path = "./rust/lance-test-macros" }
lance-testing = { version = "=0.18.3", path = "./rust/lance-testing" }
lance = { version = "=0.18.4", path = "./rust/lance" }
lance-arrow = { version = "=0.18.4", path = "./rust/lance-arrow" }
lance-core = { version = "=0.18.4", path = "./rust/lance-core" }
lance-datafusion = { version = "=0.18.4", path = "./rust/lance-datafusion" }
lance-datagen = { version = "=0.18.4", path = "./rust/lance-datagen" }
lance-encoding = { version = "=0.18.4", path = "./rust/lance-encoding" }
lance-encoding-datafusion = { version = "=0.18.4", path = "./rust/lance-encoding-datafusion" }
lance-file = { version = "=0.18.4", path = "./rust/lance-file" }
lance-index = { version = "=0.18.4", path = "./rust/lance-index" }
lance-io = { version = "=0.18.4", path = "./rust/lance-io" }
lance-jni = { version = "=0.18.4", path = "./java/core/lance-jni" }
lance-linalg = { version = "=0.18.4", path = "./rust/lance-linalg" }
lance-table = { version = "=0.18.4", path = "./rust/lance-table" }
lance-test-macros = { version = "=0.18.4", path = "./rust/lance-test-macros" }
lance-testing = { version = "=0.18.4", path = "./rust/lance-testing" }
approx = "0.5.1"
# Note that this one does not include pyarrow
arrow = { version = "52.2", optional = false, features = ["prettyprint"] }
@@ -111,7 +111,7 @@ datafusion-physical-expr = { version = "41.0", features = [
] }
deepsize = "0.2.0"
either = "1.0"
fsst = { version = "=0.18.3", path = "./rust/lance-encoding/compression-algo/fsst" }
fsst = { version = "=0.18.4", path = "./rust/lance-encoding/compression-algo/fsst" }
futures = "0.3"
http = "0.2.9"
hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
@@ -141,7 +141,7 @@ serde = { version = "^1" }
serde_json = { version = "1" }
shellexpand = "3.0"
snafu = "0.7.5"
tantivy = "0.22.0"
tantivy = { version = "0.22.0", features = ["stopwords"] }
tempfile = "3"
test-log = { version = "0.2.15" }
tokio = { version = "1.23", features = [
2 changes: 1 addition & 1 deletion python/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "pylance"
version = "0.18.3"
version = "0.18.4"
edition = "2021"
authors = ["Lance Devs <[email protected]>"]
rust-version = "1.65"
21 changes: 21 additions & 0 deletions python/python/lance/dataset.py
@@ -1337,6 +1337,27 @@ def create_scalar_index(
query. This will significantly increase the index size.
It won't impact the performance of non-phrase queries even if it is set to
True.
base_tokenizer: str, default "simple"
This is for the ``INVERTED`` index. The base tokenizer to use. The value
can be:
* "simple": splits tokens on whitespace and punctuation.
* "whitespace": splits tokens on whitespace.
* "raw": no tokenization.
language: str, default "English"
This is for the ``INVERTED`` index. The language for stemming
and stop words. This is only used when `stem` or `remove_stop_words` is True.
max_token_length: Optional[int], default 40
This is for the ``INVERTED`` index. The maximum token length.
Any token longer than this will be removed.
lower_case: bool, default True
This is for the ``INVERTED`` index. If True, the index will convert all
text to lowercase.
stem: bool, default False
This is for the ``INVERTED`` index. If True, the index will stem the
tokens.
remove_stop_words: bool, default False
This is for the ``INVERTED`` index. If True, the index will remove
stop words.
Examples
--------
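For context, here is a minimal, hypothetical sketch of how these new tokenizer options could be passed through create_scalar_index once this change ships. The dataset URI and column name are placeholders, and the keyword values simply mirror the documented defaults above rather than anything verified against a released build.

import lance

# Hypothetical paths and column name; keyword arguments follow the docstring above.
ds = lance.dataset("./docs.lance")
ds.create_scalar_index(
    "text",                    # column containing the documents (placeholder name)
    index_type="INVERTED",
    base_tokenizer="simple",   # or "whitespace" / "raw"
    language="English",        # only consulted when stem or remove_stop_words is enabled
    max_token_length=40,       # longer tokens are dropped from the index
    lower_case=True,
    stem=True,
    remove_stop_words=True,
    with_position=False,       # positions are only needed for phrase queries
)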
32 changes: 32 additions & 0 deletions python/src/dataset.rs
@@ -1205,6 +1205,38 @@ impl Dataset {
if let Some(with_position) = kwargs.get_item("with_position")? {
params.with_position = with_position.extract()?;
}
if let Some(base_tokenizer) = kwargs.get_item("base_tokenizer")? {
params.tokenizer_config = params
.tokenizer_config
.base_tokenizer(base_tokenizer.extract()?);
}
if let Some(language) = kwargs.get_item("language")? {
let language = language.extract()?;
params.tokenizer_config =
params.tokenizer_config.language(language).map_err(|e| {
PyValueError::new_err(format!(
"can't set tokenizer language to {}: {:?}",
language, e
))
})?;
}
if let Some(max_token_length) = kwargs.get_item("max_token_length")? {
params.tokenizer_config = params
.tokenizer_config
.max_token_length(max_token_length.extract()?);
}
if let Some(lower_case) = kwargs.get_item("lower_case")? {
params.tokenizer_config =
params.tokenizer_config.lower_case(lower_case.extract()?);
}
if let Some(stem) = kwargs.get_item("stem")? {
params.tokenizer_config = params.tokenizer_config.stem(stem.extract()?);
}
if let Some(remove_stop_words) = kwargs.get_item("remove_stop_words")? {
params.tokenizer_config = params
.tokenizer_config
.remove_stop_words(remove_stop_words.extract()?);
}
}
Box::new(params)
}
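Since the binding above converts tokenizer-configuration failures into PyValueError, an unsupported language surfaces as a plain ValueError on the Python side. A hedged sketch of that error path, reusing the hypothetical dataset from the earlier example:

# Hypothetical: "Klingon" is not a supported stemmer language, so configuration fails.
try:
    ds.create_scalar_index(
        "text",
        index_type="INVERTED",
        language="Klingon",
        stem=True,
    )
except ValueError as err:
    # Expected to read roughly like: "can't set tokenizer language to Klingon: ..."
    print(err)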
21 changes: 20 additions & 1 deletion rust/lance-index/src/scalar.rs
@@ -4,6 +4,7 @@
//! Scalar indices for metadata search & filtering

use std::collections::HashMap;
use std::fmt::Debug;
use std::{any::Any, ops::Bound, sync::Arc};

use arrow::buffer::{OffsetBuffer, ScalarBuffer};
@@ -17,6 +18,7 @@ use datafusion_common::{scalar::ScalarValue, Column};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::Expr;
use deepsize::DeepSizeOf;
use inverted::TokenizerConfig;
use lance_core::utils::mask::RowIdTreeMap;
use lance_core::{Error, Result};
use snafu::{location, Location};
@@ -91,19 +93,36 @@ impl IndexParams for ScalarIndexParams {
}
}

#[derive(Debug, Clone, DeepSizeOf)]
#[derive(Clone)]
pub struct InvertedIndexParams {
/// If true, store the position of the term in the document
/// This can significantly increase the size of the index
/// If false, only store the frequency of the term in the document
/// Default is true
pub with_position: bool,

pub tokenizer_config: TokenizerConfig,
}

impl Debug for InvertedIndexParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("InvertedIndexParams")
.field("with_position", &self.with_position)
.finish()
}
}

impl DeepSizeOf for InvertedIndexParams {
fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
0
}
}

impl Default for InvertedIndexParams {
fn default() -> Self {
Self {
with_position: true,
tokenizer_config: TokenizerConfig::default(),
}
}
}
2 changes: 2 additions & 0 deletions rust/lance-index/src/scalar/inverted.rs
@@ -3,11 +3,13 @@

mod builder;
mod index;
mod tokenizer;
mod wand;

pub use builder::InvertedIndexBuilder;
pub use index::*;
use lance_core::Result;
pub use tokenizer::*;

use super::btree::TrainingSource;
use super::{IndexStore, InvertedIndexParams};
8 changes: 6 additions & 2 deletions rust/lance-index/src/scalar/inverted/builder.rs
@@ -144,8 +144,9 @@ impl InvertedIndexBuilder {
let senders = Arc::new(senders);
let tokenizer_pool = Arc::new(ArrayQueue::new(num_shards));
let token_buffers_pool = Arc::new(ArrayQueue::new(num_shards));
let tokenizer = self.params.tokenizer_config.build()?;
for _ in 0..num_shards {
let _ = tokenizer_pool.push(TOKENIZER.clone());
let _ = tokenizer_pool.push(tokenizer.clone());
token_buffers_pool
.push(vec![Vec::new(); num_shards])
.unwrap();
@@ -355,7 +356,10 @@ impl InvertedIndexBuilder {
let batch = tokens.to_batch()?;
let mut writer = store.new_index_file(TOKENS_FILE, batch.schema()).await?;
writer.write_record_batch(batch).await?;
writer.finish().await?;

let tokenizer = serde_json::to_string(&self.params.tokenizer_config)?;
let metadata = HashMap::from_iter(vec![("tokenizer".to_owned(), tokenizer)]);
writer.finish_with_metadata(metadata).await?;

log::info!("finished writing tokens");
Ok(())
50 changes: 33 additions & 17 deletions rust/lance-index/src/scalar/inverted/index.rs
@@ -2,6 +2,7 @@
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::sync::Arc;

use arrow::array::{
@@ -27,11 +28,10 @@ use lazy_static::lazy_static;
use moka::future::Cache;
use roaring::RoaringBitmap;
use snafu::{location, Location};
use tantivy::tokenizer::Language;
use tracing::instrument;

use super::builder::inverted_list_schema;
use super::{wand::*, InvertedIndexBuilder};
use super::{wand::*, InvertedIndexBuilder, TokenizerConfig};
use crate::prefilter::{NoFilter, PreFilter};
use crate::scalar::{
AnyQuery, FullTextSearchQuery, IndexReader, IndexStore, SargableQuery, ScalarIndex,
@@ -57,26 +57,30 @@ pub const K1: f32 = 1.2;
pub const B: f32 = 0.75;

lazy_static! {
pub static ref TOKENIZER: tantivy::tokenizer::TextAnalyzer = {
tantivy::tokenizer::TextAnalyzer::builder(tantivy::tokenizer::SimpleTokenizer::default())
.filter(tantivy::tokenizer::RemoveLongFilter::limit(40))
.filter(tantivy::tokenizer::LowerCaser)
.filter(tantivy::tokenizer::Stemmer::new(Language::English))
.build()
};
static ref CACHE_SIZE: usize = std::env::var("LANCE_INVERTED_CACHE_SIZE")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(512 * 1024 * 1024);
}

#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct InvertedIndex {
tokenizer: tantivy::tokenizer::TextAnalyzer,
tokens: TokenSet,
inverted_list: Arc<InvertedListReader>,
docs: DocSet,
}

impl Debug for InvertedIndex {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("InvertedIndex")
.field("tokens", &self.tokens)
.field("inverted_list", &self.inverted_list)
.field("docs", &self.docs)
.finish()
}
}

impl DeepSizeOf for InvertedIndex {
fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
self.tokens.deep_size_of_children(context)
@@ -102,7 +106,8 @@ impl InvertedIndex {
query: &FullTextSearchQuery,
prefilter: Arc<dyn PreFilter>,
) -> Result<Vec<(u64, f32)>> {
let tokens = collect_tokens(&query.query);
let mut tokenizer = self.tokenizer.clone();
let tokens = collect_tokens(&query.query, &mut tokenizer);
let token_ids = self.map(&tokens).into_iter();
let token_ids = if !is_phrase_query(&query.query) {
token_ids.sorted_unstable().dedup().collect()
@@ -239,8 +244,16 @@ impl ScalarIndex for InvertedIndex {
let store = store.clone();
async move {
let token_reader = store.open_index_file(TOKENS_FILE).await?;
let tokenizer = token_reader
.schema()
.metadata
.get("tokenizer")
.map(|s| serde_json::from_str::<TokenizerConfig>(s))
.transpose()?
.unwrap_or_default()
.build()?;
let tokens = TokenSet::load(token_reader).await?;
Result::Ok(tokens)
Result::Ok((tokenizer, tokens))
}
});
let invert_list_fut = tokio::spawn({
@@ -260,11 +273,12 @@ impl ScalarIndex for InvertedIndex {
}
});

let tokens = tokens_fut.await??;
let (tokenizer, tokens) = tokens_fut.await??;
let inverted_list = invert_list_fut.await??;
let docs = docs_fut.await??;

Ok(Arc::new(Self {
tokenizer,
tokens,
inverted_list,
docs,
@@ -959,13 +973,16 @@ fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
query: &str,
) -> Result<Vec<u64>> {
let mut results = Vec::new();
let query_tokens = collect_tokens(query).into_iter().collect::<HashSet<_>>();
let mut tokenizer = TokenizerConfig::default().build()?;
let query_tokens = collect_tokens(query, &mut tokenizer)
.into_iter()
.collect::<HashSet<_>>();
for batch in batches {
let row_id_array = batch[ROW_ID].as_primitive::<UInt64Type>();
let doc_array = batch[doc_col].as_string::<Offset>();
for i in 0..row_id_array.len() {
let doc = doc_array.value(i);
let doc_tokens = collect_tokens(doc);
let doc_tokens = collect_tokens(doc, &mut tokenizer);
if doc_tokens.iter().any(|token| query_tokens.contains(token)) {
results.push(row_id_array.value(i));
assert!(doc.contains(query));
@@ -976,8 +993,7 @@ fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
Ok(results)
}

pub fn collect_tokens(text: &str) -> Vec<String> {
let mut tokenizer = TOKENIZER.clone();
pub fn collect_tokens(text: &str, tokenizer: &mut tantivy::tokenizer::TextAnalyzer) -> Vec<String> {
let mut stream = tokenizer.token_stream(text);
let mut tokens = Vec::new();
while let Some(token) = stream.next() {