diff --git a/.github/workflows/pr-title.yml b/.github/workflows/pr-title.yml index 80d75b0c30..dc404a1441 100644 --- a/.github/workflows/pr-title.yml +++ b/.github/workflows/pr-title.yml @@ -36,7 +36,7 @@ jobs: with: # pull_request_target checks out the base branch by default, not # the PR branch. - ref: "${{ github.event.pull_request.merge_commit_sha }}" + ref: "${{ github.event.pull_request.head.sha }}" path: pr - uses: actions/setup-python@v5 with: diff --git a/Cargo.toml b/Cargo.toml index bbf2ee5892..a9f7e86cc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ exclude = ["python"] resolver = "2" [workspace.package] -version = "0.18.3" +version = "0.19.2" edition = "2021" authors = ["Lance Devs "] license = "Apache-2.0" @@ -44,21 +44,21 @@ categories = [ rust-version = "1.78" [workspace.dependencies] -lance = { version = "=0.18.3", path = "./rust/lance" } -lance-arrow = { version = "=0.18.3", path = "./rust/lance-arrow" } -lance-core = { version = "=0.18.3", path = "./rust/lance-core" } -lance-datafusion = { version = "=0.18.3", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=0.18.3", path = "./rust/lance-datagen" } -lance-encoding = { version = "=0.18.3", path = "./rust/lance-encoding" } -lance-encoding-datafusion = { version = "=0.18.3", path = "./rust/lance-encoding-datafusion" } -lance-file = { version = "=0.18.3", path = "./rust/lance-file" } -lance-index = { version = "=0.18.3", path = "./rust/lance-index" } -lance-io = { version = "=0.18.3", path = "./rust/lance-io" } -lance-jni = { version = "=0.18.3", path = "./java/core/lance-jni" } -lance-linalg = { version = "=0.18.3", path = "./rust/lance-linalg" } -lance-table = { version = "=0.18.3", path = "./rust/lance-table" } -lance-test-macros = { version = "=0.18.3", path = "./rust/lance-test-macros" } -lance-testing = { version = "=0.18.3", path = "./rust/lance-testing" } +lance = { version = "=0.19.2", path = "./rust/lance" } +lance-arrow = { version = "=0.19.2", path = "./rust/lance-arrow" } +lance-core = { version = "=0.19.2", path = "./rust/lance-core" } +lance-datafusion = { version = "=0.19.2", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=0.19.2", path = "./rust/lance-datagen" } +lance-encoding = { version = "=0.19.2", path = "./rust/lance-encoding" } +lance-encoding-datafusion = { version = "=0.19.2", path = "./rust/lance-encoding-datafusion" } +lance-file = { version = "=0.19.2", path = "./rust/lance-file" } +lance-index = { version = "=0.19.2", path = "./rust/lance-index" } +lance-io = { version = "=0.19.2", path = "./rust/lance-io" } +lance-jni = { version = "=0.19.2", path = "./java/core/lance-jni" } +lance-linalg = { version = "=0.19.2", path = "./rust/lance-linalg" } +lance-table = { version = "=0.19.2", path = "./rust/lance-table" } +lance-test-macros = { version = "=0.19.2", path = "./rust/lance-test-macros" } +lance-testing = { version = "=0.19.2", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "52.2", optional = false, features = ["prettyprint"] } @@ -111,7 +111,7 @@ datafusion-physical-expr = { version = "41.0", features = [ ] } deepsize = "0.2.0" either = "1.0" -fsst = { version = "=0.18.3", path = "./rust/lance-encoding/src/compression_algo/fsst" } +fsst = { version = "=0.19.2", path = "./rust/lance-encoding/src/compression_algo/fsst" } futures = "0.3" http = "0.2.9" hyperloglogplus = { version = "0.4.1", features = ["const-loop"] } @@ -141,7 +141,7 @@ serde = { version = "^1" } serde_json = { 
version = "1" } shellexpand = "3.0" snafu = "0.7.5" -tantivy = "0.22.0" +tantivy = { version = "0.22.0", features = ["stopwords"] } tempfile = "3" test-log = { version = "0.2.15" } tokio = { version = "1.23", features = [ diff --git a/java/core/pom.xml b/java/core/pom.xml index c2d82fd5bb..64b70812bf 100644 --- a/java/core/pom.xml +++ b/java/core/pom.xml @@ -8,7 +8,7 @@ com.lancedb lance-parent - 0.18.3 + 0.19.2 ../pom.xml diff --git a/java/pom.xml b/java/pom.xml index 462c48b4e7..536e8521e9 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -6,7 +6,7 @@ com.lancedb lance-parent - 0.18.3 + 0.19.2 pom Lance Parent diff --git a/java/spark/pom.xml b/java/spark/pom.xml index cc3d390feb..76bd97c937 100644 --- a/java/spark/pom.xml +++ b/java/spark/pom.xml @@ -8,7 +8,7 @@ com.lancedb lance-parent - 0.18.3 + 0.19.2 ../pom.xml @@ -40,7 +40,7 @@ com.lancedb lance-core - 0.18.3 + 0.19.2 org.apache.spark diff --git a/python/Cargo.toml b/python/Cargo.toml index d451581c20..2ceb46da69 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "0.18.3" +version = "0.19.2" edition = "2021" authors = ["Lance Devs "] rust-version = "1.65" diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 2909f23e7f..7a3e91a711 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -7,9 +7,7 @@ import json import logging import os -import pickle import random -import sqlite3 import time import warnings from abc import ABC, abstractmethod @@ -24,7 +22,6 @@ Iterator, List, Literal, - NamedTuple, Optional, Set, TypedDict, @@ -39,7 +36,6 @@ from .dependencies import ( _check_for_hugging_face, _check_for_numpy, - _check_for_pandas, torch, ) from .dependencies import numpy as np @@ -60,6 +56,10 @@ from .lance import _Session as Session from .optimize import Compaction from .schema import LanceSchema +from .types import _coerce_reader +from .udf import BatchUDF, normalize_transform +from .udf import BatchUDFCheckpoint as BatchUDFCheckpoint +from .udf import batch_udf as batch_udf from .util import td_to_micros if TYPE_CHECKING: @@ -67,15 +67,7 @@ from .commit import CommitLock from .progress import FragmentWriteProgress - - ReaderLike = Union[ - pd.Timestamp, - pa.Table, - pa.dataset.Dataset, - pa.dataset.Scanner, - Iterable[RecordBatch], - pa.RecordBatchReader, - ] + from .types import ReaderLike QueryVectorLike = Union[ pd.Series, @@ -917,6 +909,7 @@ def add_columns( transforms: Dict[str, str] | BatchUDF | ReaderLike, read_columns: List[str] | None = None, reader_schema: Optional[pa.Schema] = None, + batch_size: Optional[int] = None, ): """ Add new columns with defined values. @@ -948,6 +941,9 @@ def add_columns( reader_schema: pa.Schema, optional Only valid if transforms is a `ReaderLike` object. This will be used to determine the schema of the reader. + batch_size: int, optional + The number of rows to read at a time from the source dataset when applying + the transform. This is ignored if the dataset is a v1 dataset. Examples -------- @@ -977,42 +973,16 @@ def add_columns( LanceDataset.merge : Merge a pre-computed set of columns into the dataset. 
""" - if isinstance(transforms, BatchUDF): - if transforms.output_schema is None: - # Infer the schema based on the first batch - sample_batch = transforms( - next(iter(self.to_batches(limit=1, columns=read_columns))) - ) - if isinstance(sample_batch, pd.DataFrame): - sample_batch = pa.RecordBatch.from_pandas(sample_batch) - transforms.output_schema = sample_batch.schema - del sample_batch - elif isinstance(transforms, dict): - for k, v in transforms.items(): - if not isinstance(k, str): - raise TypeError(f"Column names must be a string. Got {type(k)}") - if not isinstance(v, str): - raise TypeError( - f"Column expressions must be a string. Got {type(k)}" - ) + transforms = normalize_transform(transforms, self, read_columns, reader_schema) + if isinstance(transforms, pa.RecordBatchReader): + self._ds.add_columns_from_reader(transforms, batch_size) + return else: - try: - reader = _coerce_reader(transforms, reader_schema) - self._ds.add_columns_from_reader(reader) - return + self._ds.add_columns(transforms, read_columns, batch_size) - except TypeError as inner_err: - raise TypeError( - "transforms must be a dict, AddColumnsUDF, or a ReaderLike value. " - f"Received {type(transforms)}. Could not coerce to a " - f"reader: {inner_err}" - ) - - self._ds.add_columns(transforms, read_columns) - - if isinstance(transforms, BatchUDF): - if transforms.cache is not None: - transforms.cache.cleanup() + if isinstance(transforms, BatchUDF): + if transforms.cache is not None: + transforms.cache.cleanup() def drop_columns(self, columns: List[str]): """Drop one or more columns from the dataset @@ -1379,6 +1349,31 @@ def create_scalar_index( query. This will significantly increase the index size. It won't impact the performance of non-phrase queries even if it is set to True. + base_tokenizer: str, default "simple" + This is for the ``INVERTED`` index. The base tokenizer to use. The value + can be: + * "simple": splits tokens on whitespace and punctuation. + * "whitespace": splits tokens on whitespace. + * "raw": no tokenization. + language: str, default "English" + This is for the ``INVERTED`` index. The language for stemming + and stop words. This is only used when `stem` or `remove_stop_words` is true + max_token_length: Optional[int], default 40 + This is for the ``INVERTED`` index. The maximum token length. + Any token longer than this will be removed. + lower_case: bool, default True + This is for the ``INVERTED`` index. If True, the index will convert all + text to lowercase. + stem: bool, default False + This is for the ``INVERTED`` index. If True, the index will stem the + tokens. + remove_stop_words: bool, default False + This is for the ``INVERTED`` index. If True, the index will remove + stop words. + ascii_folding: bool, default False + This is for the ``INVERTED`` index. If True, the index will convert + non-ascii characters to ascii characters if possible. + This would remove accents like "é" -> "e". 
Examples -------- @@ -3240,46 +3235,6 @@ def write_dataset( return ds -def _coerce_reader( - data_obj: ReaderLike, schema: Optional[pa.Schema] = None -) -> pa.RecordBatchReader: - if _check_for_pandas(data_obj) and isinstance(data_obj, pd.DataFrame): - return pa.Table.from_pandas(data_obj, schema=schema).to_reader() - elif isinstance(data_obj, pa.Table): - return data_obj.to_reader() - elif isinstance(data_obj, pa.RecordBatch): - return pa.Table.from_batches([data_obj]).to_reader() - elif isinstance(data_obj, LanceDataset): - return data_obj.scanner().to_reader() - elif isinstance(data_obj, pa.dataset.Dataset): - return pa.dataset.Scanner.from_dataset(data_obj).to_reader() - elif isinstance(data_obj, pa.dataset.Scanner): - return data_obj.to_reader() - elif isinstance(data_obj, pa.RecordBatchReader): - return data_obj - elif ( - type(data_obj).__module__.startswith("polars") - and data_obj.__class__.__name__ == "DataFrame" - ): - return data_obj.to_arrow().to_reader() - # for other iterables, assume they are of type Iterable[RecordBatch] - elif isinstance(data_obj, Iterable): - if schema is not None: - data = _casting_recordbatch_iter(data_obj, schema) - return pa.RecordBatchReader.from_batches(schema, data) - else: - raise ValueError( - "Must provide schema to write dataset from RecordBatch iterable" - ) - else: - raise TypeError( - f"Unknown data type {type(data_obj)}. " - "Please check " - "https://lancedb.github.io/lance/read_and_write.html " - "to see supported types." - ) - - def _coerce_query_vector(query: QueryVectorLike): if isinstance(query, pa.Scalar): if isinstance(query, pa.ExtensionScalar): @@ -3341,175 +3296,3 @@ def _validate_metadata(metadata: dict): ) elif isinstance(v, dict): _validate_metadata(v) - - -def _casting_recordbatch_iter( - input_iter: Iterable[pa.RecordBatch], schema: pa.Schema -) -> Iterable[pa.RecordBatch]: - """ - Wrapper around an iterator of record batches. If the batches don't match the - schema, try to cast them to the schema. If that fails, raise an error. - - This is helpful for users who might have written the iterator with default - data types in PyArrow, but specified more specific types in the schema. For - example, PyArrow defaults to float64 for floating point types, but Lance - uses float32 for vectors. - """ - for batch in input_iter: - if not isinstance(batch, pa.RecordBatch): - raise TypeError(f"Expected RecordBatch, got {type(batch)}") - if batch.schema != schema: - try: - # RecordBatch doesn't have a cast method, but table does. - batch = pa.Table.from_batches([batch]).cast(schema).to_batches()[0] - except pa.lib.ArrowInvalid: - raise ValueError( - f"Input RecordBatch iterator yielded a batch with schema that " - f"does not match the expected schema.\nExpected:\n{schema}\n" - f"Got:\n{batch.schema}" - ) - yield batch - - -class BatchUDF: - """A user-defined function that can be passed to :meth:`LanceDataset.add_columns`. - - Use :func:`lance.add_columns_udf` decorator to wrap a function with this class. - """ - - def __init__(self, func, output_schema=None, checkpoint_file=None): - self.func = func - self.output_schema = output_schema - if checkpoint_file is not None: - self.cache = BatchUDFCheckpoint(checkpoint_file) - else: - self.cache = None - - def __call__(self, batch: pa.RecordBatch): - # Directly call inner function. This is to allow the user to test the - # function and have it behave exactly as it was written. 
- return self.func(batch) - - def _call(self, batch: pa.RecordBatch): - if self.output_schema is None: - raise ValueError( - "output_schema must be provided when using a function that " - "returns a RecordBatch" - ) - result = self.func(batch) - - if _check_for_pandas(result): - if isinstance(result, pd.DataFrame): - result = pa.RecordBatch.from_pandas(result) - assert result.schema == self.output_schema, ( - f"Output schema of function does not match the expected schema. " - f"Expected:\n{self.output_schema}\nGot:\n{result.schema}" - ) - return result - - -def batch_udf(output_schema=None, checkpoint_file=None): - """ - Create a user defined function (UDF) that adds columns to a dataset. - - This function is used to add columns to a dataset. It takes a function that - takes a single argument, a RecordBatch, and returns a RecordBatch. The - function is called once for each batch in the dataset. The function should - not modify the input batch, but instead create a new batch with the new - columns added. - - Parameters - ---------- - output_schema : Schema, optional - The schema of the output RecordBatch. This is used to validate the - output of the function. If not provided, the schema of the first output - RecordBatch will be used. - checkpoint_file : str or Path, optional - If specified, this file will be used as a cache for unsaved results of - this UDF. If the process fails, and you call add_columns again with this - same file, it will resume from the last saved state. This is useful for - long running processes that may fail and need to be resumed. This file - may get very large. It will hold up to an entire data files' worth of - results on disk, which can be multiple gigabytes of data. - - Returns - ------- - AddColumnsUDF - """ - - def inner(func): - return BatchUDF(func, output_schema, checkpoint_file) - - return inner - - -class BatchUDFCheckpoint: - """A cache for BatchUDF results to avoid recomputation. - - This is backed by a SQLite database. - """ - - class BatchInfo(NamedTuple): - fragment_id: int - batch_index: int - - def __init__(self, path): - self.path = path - # We don't re-use the connection because it's not thread safe - conn = sqlite3.connect(path) - # One table to store the results for each batch. - conn.execute( - """ - CREATE TABLE IF NOT EXISTS batches - (fragment_id INT, batch_index INT, result BLOB) - """ - ) - # One table to store fully written (but not committed) fragments. - conn.execute( - "CREATE TABLE IF NOT EXISTS fragments (fragment_id INT, data BLOB)" - ) - conn.commit() - - def cleanup(self): - os.remove(self.path) - - def get_batch(self, info: BatchInfo) -> Optional[pa.RecordBatch]: - conn = sqlite3.connect(self.path) - cursor = conn.execute( - "SELECT result FROM batches WHERE fragment_id = ? 
AND batch_index = ?", - (info.fragment_id, info.batch_index), - ) - row = cursor.fetchone() - if row is not None: - return pickle.loads(row[0]) - return None - - def insert_batch(self, info: BatchInfo, batch: pa.RecordBatch): - conn = sqlite3.connect(self.path) - conn.execute( - "INSERT INTO batches (fragment_id, batch_index, result) VALUES (?, ?, ?)", - (info.fragment_id, info.batch_index, pickle.dumps(batch)), - ) - conn.commit() - - def get_fragment(self, fragment_id: int) -> Optional[str]: - """Retrieves a fragment as a JSON string.""" - conn = sqlite3.connect(self.path) - cursor = conn.execute( - "SELECT data FROM fragments WHERE fragment_id = ?", (fragment_id,) - ) - row = cursor.fetchone() - if row is not None: - return row[0] - return None - - def insert_fragment(self, fragment_id: int, fragment: str): - """Save a JSON string of a fragment to the cache.""" - # Clear all batches for the fragment - conn = sqlite3.connect(self.path) - conn.execute( - "INSERT INTO fragments (fragment_id, data) VALUES (?, ?)", - (fragment_id, fragment), - ) - conn.execute("DELETE FROM batches WHERE fragment_id = ?", (fragment_id,)) - conn.commit() diff --git a/python/python/lance/fragment.py b/python/python/lance/fragment.py index 808d25c822..25e4f271ec 100644 --- a/python/python/lance/fragment.py +++ b/python/python/lance/fragment.py @@ -27,6 +27,7 @@ from .lance import _Fragment, _write_fragments from .lance import _FragmentMetadata as _FragmentMetadata from .progress import FragmentWriteProgress, NoopFragmentWriteProgress +from .udf import BatchUDF, normalize_transform if TYPE_CHECKING: from .dataset import LanceDataset, LanceScanner, ReaderLike @@ -361,9 +362,13 @@ def to_table( def merge_columns( self, - value_func: Callable[[pa.RecordBatch], pa.RecordBatch], + value_func: Dict[str, str] + | BatchUDF + | ReaderLike + | Callable[[pa.RecordBatch], pa.RecordBatch], columns: Optional[list[str]] = None, batch_size: Optional[int] = None, + reader_schema: Optional[pa.Schema] = None, ) -> Tuple[FragmentMetadata, LanceSchema]: """Add columns to this Fragment. @@ -371,13 +376,12 @@ def merge_columns( Internal API. This method is not intended to be used by end users. - Parameters - ---------- - value_func: Callable. - A function that takes a RecordBatch as input and returns a RecordBatch. - columns: Optional[list[str]]. - If specified, only the columns in this list will be passed to the - value_func. Otherwise, all columns will be passed to the value_func. + The parameters and their interpretation are the same as in the + :meth:`lance.dataset.LanceDataset.add_columns` operation. + + The only difference is that, instead of modifying the dataset, a new + fragment is created. The new schema of the fragment is returned as well. + These can be used in a later operation to commit the changes to the dataset. See Also -------- @@ -390,63 +394,26 @@ def merge_columns( Tuple[FragmentMetadata, LanceSchema] A new fragment with the added column(s) and the final schema. """ - updater = self._fragment.updater(columns, batch_size) - - while True: - batch = updater.next() - if batch is None: - break - new_value = value_func(batch) - if not isinstance(new_value, pa.RecordBatch): + transforms = normalize_transform(value_func, self, columns, reader_schema) + + if isinstance(transforms, BatchUDF): + if transforms.cache is not None: raise ValueError( - f"value_func must return a Pyarrow RecordBatch, " - f"got {type(new_value)}" + "A checkpoint file cannot be used when applying a UDF with " + "LanceFragment.merge_columns. 
You must apply your own " + "checkpointing for fragment-level operations." ) - updater.update(new_value) - metadata = updater.finish() - schema = updater.schema() - return FragmentMetadata.from_metadata(metadata), schema - - def add_columns( - self, - value_func: Callable[[pa.RecordBatch], pa.RecordBatch], - columns: Optional[list[str]] = None, - ) -> FragmentMetadata: - """Add columns to this Fragment. - - .. deprecated:: 0.10.14 - Use :meth:`merge_columns` instead. - - .. warning:: - - Internal API. This method is not intended to be used by end users. - - Parameters - ---------- - value_func: Callable. - A function that takes a RecordBatch as input and returns a RecordBatch. - columns: Optional[list[str]]. - If specified, only the columns in this list will be passed to the - value_func. Otherwise, all columns will be passed to the value_func. - - See Also - -------- - lance.dataset.LanceOperation.Merge : - The operation used to commit these changes to the dataset. See the - doc page for an example of using this API. + if isinstance(transforms, pa.RecordBatchReader): + metadata, schema = self._fragment.add_columns_from_reader( + transforms, batch_size + ) + else: + metadata, schema = self._fragment.add_columns( + transforms, columns, batch_size + ) - Returns - ------- - FragmentMetadata - A new fragment with the added column(s). - """ - warnings.warn( - "LanceFragment.add_columns is deprecated, use LanceFragment.merge_columns " - "instead", - DeprecationWarning, - ) - return self.merge_columns(value_func, columns)[0] + return FragmentMetadata.from_metadata(metadata), schema def delete(self, predicate: str) -> FragmentMetadata | None: """Delete rows from this Fragment. diff --git a/python/python/lance/torch/distance.py b/python/python/lance/torch/distance.py index 6d1390270a..c31d637ed0 100644 --- a/python/python/lance/torch/distance.py +++ b/python/python/lance/torch/distance.py @@ -220,7 +220,7 @@ def l2_distance( A tuple of Tensors, for centroids id, and distance to the centroids. """ split = _suggest_batch_size(centroids) - while split >= 256: + while split >= 128: try: return _l2_distance(vectors, centroids, split_size=split, y2=y2) except RuntimeError as e: # noqa: PERF203 diff --git a/python/python/lance/types.py b/python/python/lance/types.py new file mode 100644 index 0000000000..b0559c5ff1 --- /dev/null +++ b/python/python/lance/types.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterable, Optional, Union + +import pyarrow as pa +from pyarrow import RecordBatch + +from . import dataset +from .dependencies import _check_for_pandas +from .dependencies import pandas as pd + +if TYPE_CHECKING: + ReaderLike = Union[ + pd.Timestamp, + pa.Table, + pa.dataset.Dataset, + pa.dataset.Scanner, + Iterable[RecordBatch], + pa.RecordBatchReader, + ] + + +def _casting_recordbatch_iter( + input_iter: Iterable[pa.RecordBatch], schema: pa.Schema +) -> Iterable[pa.RecordBatch]: + """ + Wrapper around an iterator of record batches. If the batches don't match the + schema, try to cast them to the schema. If that fails, raise an error. + + This is helpful for users who might have written the iterator with default + data types in PyArrow, but specified more specific types in the schema. For + example, PyArrow defaults to float64 for floating point types, but Lance + uses float32 for vectors. 
+ """ + for batch in input_iter: + if not isinstance(batch, pa.RecordBatch): + raise TypeError(f"Expected RecordBatch, got {type(batch)}") + if batch.schema != schema: + try: + # RecordBatch doesn't have a cast method, but table does. + batch = pa.Table.from_batches([batch]).cast(schema).to_batches()[0] + except pa.lib.ArrowInvalid: + raise ValueError( + f"Input RecordBatch iterator yielded a batch with schema that " + f"does not match the expected schema.\nExpected:\n{schema}\n" + f"Got:\n{batch.schema}" + ) + yield batch + + +def _coerce_reader( + data_obj: ReaderLike, schema: Optional[pa.Schema] = None +) -> pa.RecordBatchReader: + if _check_for_pandas(data_obj) and isinstance(data_obj, pd.DataFrame): + return pa.Table.from_pandas(data_obj, schema=schema).to_reader() + elif isinstance(data_obj, pa.Table): + return data_obj.to_reader() + elif isinstance(data_obj, pa.RecordBatch): + return pa.Table.from_batches([data_obj]).to_reader() + elif isinstance(data_obj, dataset.LanceDataset): + return data_obj.scanner().to_reader() + elif isinstance(data_obj, pa.dataset.Dataset): + return pa.dataset.Scanner.from_dataset(data_obj).to_reader() + elif isinstance(data_obj, pa.dataset.Scanner): + return data_obj.to_reader() + elif isinstance(data_obj, pa.RecordBatchReader): + return data_obj + elif ( + type(data_obj).__module__.startswith("polars") + and data_obj.__class__.__name__ == "DataFrame" + ): + return data_obj.to_arrow().to_reader() + # for other iterables, assume they are of type Iterable[RecordBatch] + elif isinstance(data_obj, Iterable): + if schema is not None: + data = _casting_recordbatch_iter(data_obj, schema) + return pa.RecordBatchReader.from_batches(schema, data) + else: + raise ValueError( + "Must provide schema to write dataset from RecordBatch iterable" + ) + else: + raise TypeError( + f"Unknown data type {type(data_obj)}. " + "Please check " + "https://lancedb.github.io/lance/read_and_write.html " + "to see supported types." + ) diff --git a/python/python/lance/udf.py b/python/python/lance/udf.py new file mode 100644 index 0000000000..57475195c5 --- /dev/null +++ b/python/python/lance/udf.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +from __future__ import annotations + +import os +import pickle +import sqlite3 +from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional + +import pyarrow as pa + +from .dependencies import ( + _check_for_pandas, +) +from .dependencies import pandas as pd +from .types import _coerce_reader + +if TYPE_CHECKING: + from .dataset import LanceDataset, LanceFragment + from .types import ReaderLike + + +class BatchUDF: + """A user-defined function that can be passed to :meth:`LanceDataset.add_columns`. + + Use :func:`lance.add_columns_udf` decorator to wrap a function with this class. + """ + + def __init__(self, func, output_schema=None, checkpoint_file=None): + self.func = func + self.output_schema = output_schema + if checkpoint_file is not None: + self.cache = BatchUDFCheckpoint(checkpoint_file) + else: + self.cache = None + + def __call__(self, batch: pa.RecordBatch): + # Directly call inner function. This is to allow the user to test the + # function and have it behave exactly as it was written. 
+ return self.func(batch) + + def _call(self, batch: pa.RecordBatch): + if self.output_schema is None: + raise ValueError( + "output_schema must be provided when using a function that " + "returns a RecordBatch" + ) + result = self.func(batch) + + if _check_for_pandas(result): + if isinstance(result, pd.DataFrame): + result = pa.RecordBatch.from_pandas(result) + assert result.schema == self.output_schema, ( + f"Output schema of function does not match the expected schema. " + f"Expected:\n{self.output_schema}\nGot:\n{result.schema}" + ) + return result + + +def batch_udf(output_schema=None, checkpoint_file=None): + """ + Create a user defined function (UDF) that adds columns to a dataset. + + This function is used to add columns to a dataset. It takes a function that + takes a single argument, a RecordBatch, and returns a RecordBatch. The + function is called once for each batch in the dataset. The function should + not modify the input batch, but instead create a new batch with the new + columns added. + + Parameters + ---------- + output_schema : Schema, optional + The schema of the output RecordBatch. This is used to validate the + output of the function. If not provided, the schema of the first output + RecordBatch will be used. + checkpoint_file : str or Path, optional + If specified, this file will be used as a cache for unsaved results of + this UDF. If the process fails, and you call add_columns again with this + same file, it will resume from the last saved state. This is useful for + long running processes that may fail and need to be resumed. This file + may get very large. It will hold up to an entire data files' worth of + results on disk, which can be multiple gigabytes of data. + + Returns + ------- + AddColumnsUDF + """ + + def inner(func): + return BatchUDF(func, output_schema, checkpoint_file) + + return inner + + +class BatchUDFCheckpoint: + """A cache for BatchUDF results to avoid recomputation. + + This is backed by a SQLite database. + """ + + class BatchInfo(NamedTuple): + fragment_id: int + batch_index: int + + def __init__(self, path): + self.path = path + # We don't re-use the connection because it's not thread safe + conn = sqlite3.connect(path) + # One table to store the results for each batch. + conn.execute( + """ + CREATE TABLE IF NOT EXISTS batches + (fragment_id INT, batch_index INT, result BLOB) + """ + ) + # One table to store fully written (but not committed) fragments. + conn.execute( + "CREATE TABLE IF NOT EXISTS fragments (fragment_id INT, data BLOB)" + ) + conn.commit() + + def cleanup(self): + os.remove(self.path) + + def get_batch(self, info: BatchInfo) -> Optional[pa.RecordBatch]: + conn = sqlite3.connect(self.path) + cursor = conn.execute( + "SELECT result FROM batches WHERE fragment_id = ? 
AND batch_index = ?", + (info.fragment_id, info.batch_index), + ) + row = cursor.fetchone() + if row is not None: + return pickle.loads(row[0]) + return None + + def insert_batch(self, info: BatchInfo, batch: pa.RecordBatch): + conn = sqlite3.connect(self.path) + conn.execute( + "INSERT INTO batches (fragment_id, batch_index, result) VALUES (?, ?, ?)", + (info.fragment_id, info.batch_index, pickle.dumps(batch)), + ) + conn.commit() + + def get_fragment(self, fragment_id: int) -> Optional[str]: + """Retrieves a fragment as a JSON string.""" + conn = sqlite3.connect(self.path) + cursor = conn.execute( + "SELECT data FROM fragments WHERE fragment_id = ?", (fragment_id,) + ) + row = cursor.fetchone() + if row is not None: + return row[0] + return None + + def insert_fragment(self, fragment_id: int, fragment: str): + """Save a JSON string of a fragment to the cache.""" + # Clear all batches for the fragment + conn = sqlite3.connect(self.path) + conn.execute( + "INSERT INTO fragments (fragment_id, data) VALUES (?, ?)", + (fragment_id, fragment), + ) + conn.execute("DELETE FROM batches WHERE fragment_id = ?", (fragment_id,)) + conn.commit() + + +def normalize_transform( + udf_like: Dict[str, str] | BatchUDF | ReaderLike, + data_source: LanceDataset | LanceFragment, + read_columns: Optional[List[str]] = None, + reader_schema: Optional[pa.Schema] = None, +): + if isinstance(udf_like, BatchUDF): + if udf_like.output_schema is None: + # Infer the schema based on the first batch + sample_batch = udf_like( + next(iter(data_source.to_batches(limit=1, columns=read_columns))) + ) + if isinstance(sample_batch, pd.DataFrame): + sample_batch = pa.RecordBatch.from_pandas(sample_batch) + udf_like.output_schema = sample_batch.schema + + return udf_like + elif isinstance(udf_like, dict): + for k, v in udf_like.items(): + if not isinstance(k, str): + raise TypeError(f"Column names must be a string. Got {type(k)}") + if not isinstance(v, str): + raise TypeError(f"Column expressions must be a string. Got {type(k)}") + + return udf_like + # Is this a callable/function that is not a BatchUDF? If so, wrap in a BatchUDF + elif callable(udf_like): + try: + sample_batch = udf_like( + next(iter(data_source.to_batches(limit=1, columns=read_columns))) + ) + if isinstance(sample_batch, pd.DataFrame): + sample_batch = pa.RecordBatch.from_pandas(sample_batch) + udf_like = BatchUDF(udf_like, output_schema=sample_batch.schema) + + return udf_like + except Exception as inner_err: + raise TypeError( + "transforms must be a BatchUDF, dict, map function, or ReaderLike " + f"value. Received {type(udf_like)}, which is callable, but gave " + f"an error when called with a batch of data: {inner_err}" + ) + # Last thing we check is to see if we can coerce into a RecordBatchReader + else: + try: + reader = _coerce_reader(udf_like, reader_schema) + return reader + + except TypeError as inner_err: + raise TypeError( + "transforms must be a BatchUDF, dict, map function, or ReaderLike " + f"value. Received {type(udf_like)}. 
Could not coerce to a " + f"reader: {inner_err}" + ) diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index e5e77ba314..072c8e201f 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -706,36 +706,6 @@ def test_pickle_fragment(tmp_path: Path): assert fragment.to_table() == unpickled.to_table() -def test_add_columns(tmp_path: Path): - table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) - base_dir = tmp_path / "test" - lance.write_dataset(table, base_dir) - - dataset = lance.dataset(base_dir) - fragments = dataset.get_fragments() - - fragment = fragments[0] - - def adder(batch: pa.RecordBatch) -> pa.RecordBatch: - c_array = pa.compute.multiply(batch.column(0), 2) - return pa.RecordBatch.from_arrays([c_array], names=["c"]) - - fragment_metadata, schema = fragment.merge_columns(adder, columns=["a"]) - - operation = lance.LanceOperation.Overwrite(schema.to_pyarrow(), [fragment_metadata]) - dataset = lance.LanceDataset.commit(base_dir, operation) - assert dataset.schema == schema.to_pyarrow() - - tbl = dataset.to_table() - assert tbl == pa.Table.from_pydict( - { - "a": range(100), - "b": range(100), - "c": pa.array(range(0, 200, 2), pa.int64()), - } - ) - - def test_cleanup_old_versions(tmp_path): table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) base_dir = tmp_path / "test" @@ -921,12 +891,9 @@ def test_merge_with_commit(tmp_path: Path): lance.write_dataset(table, base_dir) fragment = lance.dataset(base_dir).get_fragments()[0] - # add_columns is deprecated, but we can make sure it still works - # for now. - with pytest.deprecated_call(): - merged = fragment.add_columns( - lambda _: pa.RecordBatch.from_pydict({"c": range(100)}) - ) + merged = fragment.merge_columns( + lambda _: pa.RecordBatch.from_pydict({"c": range(100)}) + )[0] expected = pa.Table.from_pydict({"a": range(100), "b": range(100), "c": range(100)}) @@ -940,33 +907,6 @@ def test_merge_with_commit(tmp_path: Path): assert tbl == expected -def test_merge_batch_size(tmp_path: Path): - # Create dataset with 10 fragments with 100 rows each - table = pa.table({"a": range(1000)}) - for batch_size in [1, 10, 100, 1000]: - ds_path = str(tmp_path / str(batch_size)) - dataset = lance.write_dataset(table, ds_path, max_rows_per_file=100) - fragments = [] - - def mutate(batch): - assert batch.num_rows <= batch_size - return pa.RecordBatch.from_pydict({"b": batch.column("a")}) - - for frag in dataset.get_fragments(): - merged, schema = frag.merge_columns(mutate, batch_size=batch_size) - fragments.append(merged) - - merge = lance.LanceOperation.Merge(fragments, schema) - dataset = lance.LanceDataset.commit( - ds_path, merge, read_version=dataset.version - ) - - dataset.validate() - tbl = dataset.to_table() - expected = pa.table({"a": range(1000), "b": range(1000)}) - assert tbl == expected - - def test_merge_with_schema_holes(tmp_path: Path): # Create table with 3 cols table = pa.table({"a": range(10)}) diff --git a/python/python/tests/test_schema_evolution.py b/python/python/tests/test_schema_evolution.py index 7a3b4eb488..b2b6acbcfe 100644 --- a/python/python/tests/test_schema_evolution.py +++ b/python/python/tests/test_schema_evolution.py @@ -11,6 +11,7 @@ import pyarrow as pa import pyarrow.compute as pc import pytest +from lance import LanceDataset from lance.file import LanceFileReader, LanceFileWriter @@ -48,7 +49,48 @@ def test_drop_columns(tmp_path: Path): dataset.drop_columns(["c"]) -def test_add_columns_udf(tmp_path): +# The 
LanceDataset.add_columns and LanceFragment.merge_columns should be mostly the +# same. Tests that test these methods can use this fixture to test both methods. +def check_add_columns( + dataset: LanceDataset, expected: pa.Table, use_fragments: bool, *args, **kwargs +): + if use_fragments: + # Ensure we are working with latest dataset version + dataset = lance.dataset(dataset.uri) + new_frags = [] + for fragment in dataset.get_fragments(): + # the parameter name is different in `merge_columns` (backwards compat.) + if "read_columns" in kwargs: + kwargs["columns"] = kwargs.pop("read_columns") + new_frag, schema = fragment.merge_columns(*args, **kwargs) + new_frags.append(new_frag) + op = lance.LanceOperation.Merge(new_frags, schema) + dataset = LanceDataset.commit(dataset.uri, op, read_version=dataset.version) + assert dataset.to_table() == expected + else: + dataset.add_columns(*args, **kwargs) + assert dataset.to_table() == expected + + +def check_add_columns_fails( + dataset: LanceDataset, + use_fragments: bool, + expected_exception: any, + match: str, + *args, + **kwargs, +): + if use_fragments: + with pytest.raises(expected_exception, match=match): + frag = dataset.get_fragments()[0] + frag.merge_columns(*args, **kwargs) + else: + with pytest.raises(expected_exception, match=match): + dataset.add_columns(*args, **kwargs) + + +@pytest.mark.parametrize("use_fragments", [False, True]) +def test_add_columns_udf(tmp_path, use_fragments): tab = pa.table({"a": range(100), "b": range(100)}) dataset = lance.write_dataset(tab, tmp_path, max_rows_per_file=25) @@ -61,10 +103,8 @@ def double_a(batch): [pa.array([2 * x.as_py() for x in batch["a"]])], ["double_a"] ) - dataset.add_columns(double_a, read_columns=["a"]) - expected = tab.append_column("double_a", pa.array([2 * x for x in range(100)])) - assert expected == dataset.to_table() + check_add_columns(dataset, expected, use_fragments, double_a, read_columns=["a"]) # Check: errors if produces inconsistent schema @lance.batch_udf() @@ -72,20 +112,21 @@ def make_new_col(batch): col_name = str(uuid.uuid4()) return pa.record_batch([batch["a"]], [col_name]) - with pytest.raises( - Exception, match="Output schema of function does not match the expected schema" - ): - dataset.add_columns(make_new_col) + check_add_columns_fails( + dataset, + use_fragments, + Exception, + "Output schema of function does not match the expected schema", + make_new_col, + ) # Schema inference and Pandas conversion @lance.batch_udf() def triple_a(batch): return pd.DataFrame({"triple_a": [3 * x.as_py() for x in batch["a"]]}) - dataset.add_columns(triple_a, read_columns=["a"]) - expected = expected.append_column("triple_a", pa.array([3 * x for x in range(100)])) - assert expected == dataset.to_table() + check_add_columns(dataset, expected, use_fragments, triple_a, read_columns=["a"]) def test_add_columns_from_rbr(tmp_path): @@ -217,11 +258,12 @@ def double_a(batch): assert "cache.sqlite" not in os.listdir(tmp_path) -def test_add_columns_exprs(tmp_path): +@pytest.mark.parametrize("use_fragments", [False, True]) +def test_add_columns_exprs(tmp_path, use_fragments): tab = pa.table({"a": range(100)}) dataset = lance.write_dataset(tab, tmp_path) - dataset.add_columns({"b": "a + 1"}) - assert dataset.to_table() == pa.table({"a": range(100), "b": range(1, 101)}) + expected = pa.table({"a": range(100), "b": range(1, 101)}) + check_add_columns(dataset, expected, use_fragments, {"b": "a + 1"}) def test_add_many_columns(tmp_path: Path): @@ -232,6 +274,19 @@ def 
test_add_many_columns(tmp_path: Path): assert dataset.to_table().num_rows == 3 +@pytest.mark.parametrize("use_fragments", [False, True]) +def test_add_columns_callable(tmp_path: Path, use_fragments): + table = pa.table({"a": range(100)}) + dataset = lance.write_dataset(table, tmp_path) + + def mapper(batch: pa.RecordBatch): + plus_one = pc.add(batch["a"], 1) + return pa.record_batch([plus_one], names=["b"]) + + expected = pa.table({"a": range(100), "b": range(1, 101)}) + check_add_columns(dataset, expected, use_fragments, mapper) + + def test_query_after_merge(tmp_path): # https://github.com/lancedb/lance/issues/1905 tab = pa.table( @@ -323,3 +378,137 @@ def test_alter_columns(tmp_path: Path): match="At least one of name, nullable, or data_type must be specified", ): dataset.alter_columns({"path": "x"}) + + +def test_merge_columns(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + base_dir = tmp_path / "test" + lance.write_dataset(table, base_dir) + + dataset = lance.dataset(base_dir) + fragments = dataset.get_fragments() + + fragment = fragments[0] + + def adder(batch: pa.RecordBatch) -> pa.RecordBatch: + c_array = pa.compute.multiply(batch.column(0), 2) + return pa.RecordBatch.from_arrays([c_array], names=["c"]) + + fragment_metadata, schema = fragment.merge_columns(adder, columns=["a"]) + + operation = lance.LanceOperation.Overwrite(schema.to_pyarrow(), [fragment_metadata]) + dataset = lance.LanceDataset.commit(base_dir, operation) + assert dataset.schema == schema.to_pyarrow() + + tbl = dataset.to_table() + assert tbl == pa.Table.from_pydict( + { + "a": range(100), + "b": range(100), + "c": pa.array(range(0, 200, 2), pa.int64()), + } + ) + + +def test_merge_columns_from_reader(tmp_path: Path): + table = pa.Table.from_pydict({"a": range(100), "b": range(100)}) + base_dir = tmp_path / "test" + lance.write_dataset(table, base_dir) + + dataset = lance.dataset(base_dir) + fragments = dataset.get_fragments() + + fragment = fragments[0] + + with LanceFileWriter(tmp_path / "some_file") as writer: + writer.write_batch(pa.table({"c": range(100), "d": range(100)})) + + def datareader(): + reader = LanceFileReader(str(tmp_path / "some_file")) + for batch in reader.read_all(batch_size=10).to_batches(): + yield batch + + fragment_metadata, schema = fragment.merge_columns( + datareader(), + reader_schema=pa.schema([pa.field("c", pa.int64()), pa.field("d", pa.int64())]), + batch_size=15, + ) + + operation = lance.LanceOperation.Overwrite(schema.to_pyarrow(), [fragment_metadata]) + dataset = lance.LanceDataset.commit(base_dir, operation) + assert dataset.schema == schema.to_pyarrow() + + tbl = dataset.to_table() + assert tbl == pa.Table.from_pydict( + { + "a": range(100), + "b": range(100), + "c": range(100), + "d": range(100), + } + ) + + +def test_merge_batch_size(tmp_path: Path): + # Create dataset with 10 fragments with 100 rows each + table = pa.table({"a": range(1000)}) + for batch_size in [1, 10, 100, 1000]: + ds_path = str(tmp_path / str(batch_size)) + dataset = lance.write_dataset(table, ds_path, max_rows_per_file=100) + fragments = [] + + def mutate(batch): + assert batch.num_rows <= batch_size + return pa.RecordBatch.from_pydict({"b": batch.column("a")}) + + for frag in dataset.get_fragments(): + merged, schema = frag.merge_columns(mutate, batch_size=batch_size) + fragments.append(merged) + + merge = lance.LanceOperation.Merge(fragments, schema) + dataset = lance.LanceDataset.commit( + ds_path, merge, read_version=dataset.version + ) + + dataset.validate() 
+ tbl = dataset.to_table() + expected = pa.table({"a": range(1000), "b": range(1000)}) + assert tbl == expected + + +def test_add_cols_batch_size(tmp_path: Path): + # Same test as `test_merge_batch_size` but using LanceDataset.add_columns instead + table = pa.table({"a": range(1000)}) + for batch_size in [1, 10, 100, 1000]: + ds_path = str(tmp_path / str(batch_size)) + dataset = lance.write_dataset(table, ds_path, max_rows_per_file=100) + + def mutate(batch): + assert batch.num_rows <= batch_size + return pa.RecordBatch.from_pydict({"b": batch.column("a")}) + + dataset.add_columns(mutate, batch_size=batch_size) + + dataset.validate() + tbl = dataset.to_table() + expected = pa.table({"a": range(1000), "b": range(1000)}) + assert tbl == expected + + +def test_no_checkpoint_merge_columns(tmp_path: Path): + tab = pa.table( + { + "a": range(100), + "b": range(100), + } + ) + dataset = lance.write_dataset(tab, tmp_path, max_rows_per_file=20) + + @lance.batch_udf(checkpoint_file=tmp_path / "cache.sqlite") + def some_udf(batch): + return batch + + frag = dataset.get_fragments()[0] + + with pytest.raises(ValueError, match="A checkpoint file cannot be used"): + frag.merge_columns(some_udf, columns=["a"]) diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 6655747cfc..92d0b0efb1 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -362,6 +362,50 @@ impl Operation { } } +pub fn transforms_from_python(transforms: &PyAny) -> PyResult<NewColumnTransform> { + if let Ok(transforms) = transforms.extract::<&PyDict>() { + let expressions = transforms + .iter() + .map(|(k, v)| { + let col = k.extract::<String>()?; + let expr = v.extract::<String>()?; + Ok((col, expr)) + }) + .collect::<PyResult<Vec<_>>>()?; + Ok(NewColumnTransform::SqlExpressions(expressions)) + } else { + let append_schema: PyArrowType<ArrowSchema> = + transforms.getattr("output_schema")?.extract()?; + let output_schema = Arc::new(append_schema.0); + + let result_checkpoint: Option<PyObject> = transforms.getattr("cache")?.extract()?; + let result_checkpoint = result_checkpoint.map(|c| PyBatchUDFCheckpointWrapper { inner: c }); + + let udf_obj = transforms.to_object(transforms.py()); + let mapper = move |batch: &RecordBatch| -> lance::Result<RecordBatch> { + Python::with_gil(|py| { + let py_batch: PyArrowType<RecordBatch> = PyArrowType(batch.clone()); + let result = udf_obj + .call_method1(py, "_call", (py_batch,)) + .map_err(|err| { + lance::Error::io(format_python_error(err, py).unwrap(), location!()) + })?; + let result_batch: PyArrowType<RecordBatch> = result + .extract(py) + .map_err(|err| lance::Error::io(err.to_string(), location!()))?; + Ok(result_batch.0) + }) + }; + + Ok(NewColumnTransform::BatchUDF(BatchUDF { + mapper: Box::new(mapper), + output_schema, + result_checkpoint: result_checkpoint + .map(|c| Arc::new(c) as Arc<dyn UDFCheckpointStore>), + })) + } +} + /// Lance Dataset that will be wrapped by another class in Python #[pyclass(name = "_Dataset", module = "_lib")] #[derive(Clone)] @@ -1231,6 +1275,43 @@ impl Dataset { if let Some(with_position) = kwargs.get_item("with_position")? { params.with_position = with_position.extract()?; } + if let Some(base_tokenizer) = kwargs.get_item("base_tokenizer")? { + params.tokenizer_config = params + .tokenizer_config + .base_tokenizer(base_tokenizer.extract()?); + } + if let Some(language) = kwargs.get_item("language")?
{ + let language = language.extract()?; + params.tokenizer_config = + params.tokenizer_config.language(language).map_err(|e| { + PyValueError::new_err(format!( + "can't set tokenizer language to {}: {:?}", + language, e + )) + })?; + } + if let Some(max_token_length) = kwargs.get_item("max_token_length")? { + params.tokenizer_config = params + .tokenizer_config + .max_token_length(max_token_length.extract()?); + } + if let Some(lower_case) = kwargs.get_item("lower_case")? { + params.tokenizer_config = + params.tokenizer_config.lower_case(lower_case.extract()?); + } + if let Some(stem) = kwargs.get_item("stem")? { + params.tokenizer_config = params.tokenizer_config.stem(stem.extract()?); + } + if let Some(remove_stop_words) = kwargs.get_item("remove_stop_words")? { + params.tokenizer_config = params + .tokenizer_config + .remove_stop_words(remove_stop_words.extract()?); + } + if let Some(ascii_folding) = kwargs.get_item("ascii_folding")? { + params.tokenizer_config = params + .tokenizer_config + .ascii_folding(ascii_folding.extract()?); + } } Box::new(params) } @@ -1381,7 +1462,11 @@ impl Dataset { Ok(()) } - fn add_columns_from_reader(&mut self, reader: &Bound<'_, PyAny>) -> PyResult<()> { + fn add_columns_from_reader( + &mut self, + reader: &Bound<'_, PyAny>, + batch_size: Option<u32>, + ) -> PyResult<()> { let batches = ArrowArrayStreamReader::from_pyarrow_bound(reader)?; let transforms = NewColumnTransform::Reader(Box::new(batches)); @@ -1389,7 +1474,7 @@ impl Dataset { let mut new_self = self.ds.as_ref().clone(); let new_self = RT .spawn(None, async move { - new_self.add_columns(transforms, None).await?; + new_self.add_columns(transforms, None, batch_size).await?; Ok(new_self) })? .map_err(|err: lance::Error| PyIOError::new_err(err.to_string()))?; @@ -1402,55 +1487,16 @@ impl Dataset { &mut self, transforms: &PyAny, read_columns: Option<Vec<String>>, + batch_size: Option<u32>, ) -> PyResult<()> { - println!("add_columns"); - let transforms = if let Ok(transforms) = transforms.extract::<&PyDict>() { - let expressions = transforms - .iter() - .map(|(k, v)| { - let col = k.extract::<String>()?; - let expr = v.extract::<String>()?; - Ok((col, expr)) - }) - .collect::<PyResult<Vec<_>>>()?; - NewColumnTransform::SqlExpressions(expressions) - } else { - let append_schema: PyArrowType<ArrowSchema> = - transforms.getattr("output_schema")?.extract()?; - let output_schema = Arc::new(append_schema.0); - - let result_checkpoint: Option<PyObject> = transforms.getattr("cache")?.extract()?; - let result_checkpoint = - result_checkpoint.map(|c| PyBatchUDFCheckpointWrapper { inner: c }); - - let udf_obj = transforms.to_object(transforms.py()); - let mapper = move |batch: &RecordBatch| -> lance::Result<RecordBatch> { - Python::with_gil(|py| { - let py_batch: PyArrowType<RecordBatch> = PyArrowType(batch.clone()); - let result = udf_obj - .call_method1(py, "_call", (py_batch,)) - .map_err(|err| { - lance::Error::io(format_python_error(err, py).unwrap(), location!()) - })?; - let result_batch: PyArrowType<RecordBatch> = result - .extract(py) - .map_err(|err| lance::Error::io(err.to_string(), location!()))?; - Ok(result_batch.0) - }) - }; - - NewColumnTransform::BatchUDF(BatchUDF { - mapper: Box::new(mapper), - output_schema, - result_checkpoint: result_checkpoint - .map(|c| Arc::new(c) as Arc<dyn UDFCheckpointStore>), - }) - }; + let transforms = transforms_from_python(transforms)?; let mut new_self = self.ds.as_ref().clone(); let new_self = RT .spawn(None, async move { - new_self.add_columns(transforms, read_columns).await?; + new_self + .add_columns(transforms, read_columns, batch_size) + .await?; Ok(new_self) })?
.map_err(|err: lance::Error| PyIOError::new_err(err.to_string()))?; diff --git a/python/src/fragment.rs b/python/src/fragment.rs index 1e69d3b354..e5d2a39199 100644 --- a/python/src/fragment.rs +++ b/python/src/fragment.rs @@ -21,13 +21,15 @@ use arrow_array::RecordBatchReader; use arrow_schema::Schema as ArrowSchema; use futures::TryFutureExt; use lance::dataset::fragment::FileFragment as LanceFragment; +use lance::dataset::NewColumnTransform; use lance_table::format::{DataFile as LanceDataFile, Fragment as LanceFragmentMetadata}; use lance_table::io::deletion::deletion_file_path; use pyo3::prelude::*; use pyo3::{exceptions::*, pyclass::CompareOp, types::PyDict}; -use crate::dataset::get_write_params; -use crate::updater::Updater; +use crate::dataset::{get_write_params, transforms_from_python}; +use crate::error::PythonErrorExt; +use crate::schema::LanceSchema; use crate::{Dataset, Scanner, RT}; #[pyclass(name = "_Fragment", module = "_lib")] @@ -210,14 +212,43 @@ impl FileFragment { Ok(Scanner::new(scn)) } - fn updater(&self, columns: Option<Vec<String>>, batch_size: Option<u32>) -> PyResult<Updater> { - let cols = columns.as_deref(); - let inner = RT - .block_on(None, async { - self.fragment.updater(cols, None, batch_size).await + fn add_columns_from_reader( + &mut self, + reader: &Bound<'_, PyAny>, + batch_size: Option<u32>, + ) -> PyResult<(FragmentMetadata, LanceSchema)> { + let batches = ArrowArrayStreamReader::from_pyarrow_bound(reader)?; + + let transforms = NewColumnTransform::Reader(Box::new(batches)); + + let fragment = self.fragment.clone(); + let (fragment, schema) = RT + .spawn(None, async move { + fragment.add_columns(transforms, None, batch_size).await })? - .map_err(|err| PyIOError::new_err(err.to_string()))?; - Ok(Updater::new(inner)) + .infer_error()?; + + Ok((FragmentMetadata::new(fragment), LanceSchema(schema))) + } + + fn add_columns( + &mut self, + transforms: &PyAny, + read_columns: Option<Vec<String>>, + batch_size: Option<u32>, + ) -> PyResult<(FragmentMetadata, LanceSchema)> { + let transforms = transforms_from_python(transforms)?; + + let fragment = self.fragment.clone(); + let (fragment, schema) = RT + .spawn(None, async move { + fragment + .add_columns(transforms, read_columns, batch_size) + .await + })? + .infer_error()?; + + Ok((FragmentMetadata::new(fragment), LanceSchema(schema))) } fn delete(&self, predicate: &str) -> PyResult<Option<Self>> { diff --git a/python/src/lib.rs b/python/src/lib.rs index 39746fd07e..67c00b7692 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -64,7 +64,6 @@ pub(crate) mod scanner; pub(crate) mod schema; pub(crate) mod session; pub(crate) mod tracing; -pub(crate) mod updater; pub(crate) mod utils; pub use crate::arrow::{bfloat16_array, BFloat16}; diff --git a/python/src/updater.rs b/python/src/updater.rs deleted file mode 100644 index c5e47016d2..0000000000 --- a/python/src/updater.rs +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2023 Lance Developers. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.
- -use arrow::pyarrow::PyArrowType; -use arrow_array::RecordBatch; -use pyo3::{exceptions::*, prelude::*}; - -use lance::dataset::updater::Updater as LanceUpdater; - -use crate::{fragment::FragmentMetadata, schema::LanceSchema, RT}; - -#[pyclass(name = "_Updater", module = "_lib")] -pub struct Updater { - inner: LanceUpdater, -} - -impl Updater { - pub(super) fn new(updater: LanceUpdater) -> Self { - Self { inner: updater } - } -} - -#[pymethods] -impl Updater { - /// Return the next batch as input data. - #[pyo3(signature=())] - fn next(&mut self, py: Python<'_>) -> PyResult>> { - let batch = { - RT.block_on(Some(py), async { self.inner.next().await })? - .map_err(|err| PyIOError::new_err(err.to_string()))? - }; - Ok(batch.map(|b| PyArrowType(b.clone()))) - } - - /// Update one batch - fn update(&mut self, batch: PyArrowType) -> PyResult<()> { - let batch = batch.0; - RT.block_on(None, async { - self.inner - .update(batch) - .await - .map_err(|e| PyIOError::new_err(e.to_string())) - })? - } - - fn finish(&mut self) -> PyResult { - let fragment = RT.block_on(None, async { - self.inner - .finish() - .await - .map_err(|e| PyIOError::new_err(e.to_string())) - })??; - - Ok(FragmentMetadata::new(fragment)) - } - - fn schema(&self) -> Option { - self.inner.schema().map(|s| LanceSchema(s.clone())) - } -} diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 1d38fa6929..27a86ea9f4 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -4,6 +4,7 @@ //! Scalar indices for metadata search & filtering use std::collections::HashMap; +use std::fmt::Debug; use std::{any::Any, ops::Bound, sync::Arc}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; @@ -17,6 +18,7 @@ use datafusion_common::{scalar::ScalarValue, Column}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; use deepsize::DeepSizeOf; +use inverted::TokenizerConfig; use lance_core::utils::mask::RowIdTreeMap; use lance_core::{Error, Result}; use snafu::{location, Location}; @@ -91,19 +93,36 @@ impl IndexParams for ScalarIndexParams { } } -#[derive(Debug, Clone, DeepSizeOf)] +#[derive(Clone)] pub struct InvertedIndexParams { /// If true, store the position of the term in the document /// This can significantly increase the size of the index /// If false, only store the frequency of the term in the document /// Default is true pub with_position: bool, + + pub tokenizer_config: TokenizerConfig, +} + +impl Debug for InvertedIndexParams { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InvertedIndexParams") + .field("with_position", &self.with_position) + .finish() + } +} + +impl DeepSizeOf for InvertedIndexParams { + fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize { + 0 + } } impl Default for InvertedIndexParams { fn default() -> Self { Self { with_position: true, + tokenizer_config: TokenizerConfig::default(), } } } diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index 974905a047..32371773a3 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -3,11 +3,13 @@ mod builder; mod index; +mod tokenizer; mod wand; pub use builder::InvertedIndexBuilder; pub use index::*; use lance_core::Result; +pub use tokenizer::*; use super::btree::TrainingSource; use super::{IndexStore, InvertedIndexParams}; diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index 
a934750e70..4a73cc4dda 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -10,9 +10,9 @@ use std::sync::Arc; use crate::scalar::lance_format::LanceIndexStore; use crate::scalar::{IndexReader, IndexStore, IndexWriter, InvertedIndexParams}; use crate::vector::graph::OrderedFloat; -use arrow::array::AsArray; +use arrow::array::{ArrayBuilder, AsArray, Int32Builder, StringBuilder}; use arrow::datatypes; -use arrow_array::RecordBatch; +use arrow_array::{Int32Array, RecordBatch, StringArray}; use arrow_schema::SchemaRef; use crossbeam_queue::ArrayQueue; use datafusion::execution::SendableRecordBatchStream; @@ -139,8 +139,8 @@ impl InvertedIndexBuilder { senders.push(sender); result_futs.push(tokio::spawn({ async move { - while let Some((row_id, tokens)) = receiver.recv().await { - worker.add(row_id, tokens).await?; + while let Some((row_id, tokens, positions)) = receiver.recv().await { + worker.add(row_id, tokens, positions).await?; } let reader = worker.into_reader(inverted_list).await?; Result::Ok(reader) @@ -151,18 +151,14 @@ impl InvertedIndexBuilder { let start = std::time::Instant::now(); let senders = Arc::new(senders); let tokenizer_pool = Arc::new(ArrayQueue::new(num_shards)); - let token_buffers_pool = Arc::new(ArrayQueue::new(num_shards)); + let tokenizer = self.params.tokenizer_config.build()?; for _ in 0..num_shards { - let _ = tokenizer_pool.push(TOKENIZER.clone()); - token_buffers_pool - .push(vec![Vec::new(); num_shards]) - .unwrap(); + let _ = tokenizer_pool.push(tokenizer.clone()); } let mut stream = stream .map(move |batch| { let senders = senders.clone(); let tokenizer_pool = tokenizer_pool.clone(); - let token_buffers_pool = token_buffers_pool.clone(); CPU_RUNTIME.spawn_blocking(move || { let batch = batch?; let doc_iter = iter_str_array(batch.column(0)); @@ -172,11 +168,22 @@ impl InvertedIndexBuilder { .filter_map(|(doc, row_id)| doc.map(|doc| (doc, *row_id))); let mut tokenizer = tokenizer_pool.pop().unwrap(); - let mut token_buffers = token_buffers_pool.pop().unwrap(); let num_tokens = docs .map(|(doc, row_id)| { // tokenize the document + let predicted_num_tokens = doc.len() / 5 / num_shards; + let mut token_buffers = std::iter::repeat_with(|| { + ( + StringBuilder::with_capacity( + predicted_num_tokens, + doc.len() / num_shards, + ), + Int32Builder::with_capacity(predicted_num_tokens), + ) + }) + .take(num_shards) + .collect_vec(); let mut num_tokens = 0; let mut token_stream = tokenizer.token_stream(doc); while token_stream.advance() { @@ -184,17 +191,25 @@ impl InvertedIndexBuilder { let mut hasher = DefaultHasher::new(); hasher.write(token.text.as_bytes()); let shard = hasher.finish() as usize % num_shards; - token_buffers[shard] - .push((std::mem::take(&mut token.text), token.position as i32)); + let (ref mut token_builder, ref mut position_builder) = + &mut token_buffers[shard]; + token_builder.append_value(&token.text); + position_builder.append_value(token.position as i32); num_tokens += 1; } - for (shard, buffer) in token_buffers.iter_mut().enumerate() { - if buffer.is_empty() { + for (shard, (token_builder, position_builder)) in + token_buffers.iter_mut().enumerate() + { + if token_builder.is_empty() { continue; } - let buffer = std::mem::take(buffer); - senders[shard].blocking_send((row_id, buffer)).unwrap(); + + let tokens = token_builder.finish(); + let positions = position_builder.finish(); + senders[shard] + .blocking_send((row_id, tokens, positions)) + .unwrap(); } (row_id, num_tokens) @@ 
-202,7 +217,6 @@ impl InvertedIndexBuilder { .collect_vec(); let _ = tokenizer_pool.push(tokenizer); - token_buffers_pool.push(token_buffers).unwrap(); Result::Ok(num_tokens) }) }) @@ -350,7 +364,10 @@ impl InvertedIndexBuilder { ("max_scores".to_owned(), serde_json::to_string(&max_scores)?), ]); writer.finish_with_metadata(metadata).await?; - log::info!("finished writing posting lists"); + log::info!( + "finished writing posting lists, elapsed: {:?}", + start.elapsed() + ); Ok(()) } @@ -363,7 +380,10 @@ impl InvertedIndexBuilder { let batch = tokens.to_batch()?; let mut writer = store.new_index_file(TOKENS_FILE, batch.schema()).await?; writer.write_record_batch(batch).await?; - writer.finish().await?; + + let tokenizer = serde_json::to_string(&self.params.tokenizer_config)?; + let metadata = HashMap::from_iter(vec![("tokenizer".to_owned(), tokenizer)]); + writer.finish_with_metadata(metadata).await?; log::info!("finished writing tokens"); Ok(()) @@ -429,13 +449,18 @@ impl IndexWorker { self.schema.column_with_name(POSITION_COL).is_some() } - async fn add(&mut self, row_id: u64, tokens: Vec<(String, i32)>) -> Result<()> { + async fn add(&mut self, row_id: u64, tokens: StringArray, positions: Int32Array) -> Result<()> { let mut token_occurrences = HashMap::new(); - for (token, position) in tokens { + for (token, position) in tokens.iter().zip(positions.values().into_iter()) { + let token = if let Some(token) = token { + token + } else { + continue; + }; token_occurrences .entry(token) .or_insert_with(Vec::new) - .push(position); + .push(*position); } let with_position = self.has_position(); token_occurrences @@ -443,7 +468,7 @@ impl IndexWorker { .for_each(|(token, term_positions)| { let posting_list = self .posting_lists - .entry(token.clone()) + .entry(token.to_owned()) .or_insert_with(|| PostingListBuilder::empty(with_position)); let old_size = if posting_list.is_empty() { @@ -499,6 +524,7 @@ impl IndexWorker { Ok(()) } + #[instrument(level = "debug", skip_all)] async fn flush_posting_list(&mut self, token: String) -> Result { if let Some(posting_list) = self.posting_lists.remove(&token) { let size = posting_list.size(); @@ -710,6 +736,7 @@ mod tests { use lance_io::object_store::ObjectStore; use object_store::path::Path; + use crate::scalar::inverted::TokenizerConfig; use crate::scalar::lance_format::LanceIndexStore; use crate::scalar::{FullTextSearchQuery, SargableQuery, ScalarIndex}; @@ -717,13 +744,15 @@ mod tests { async fn create_index( with_position: bool, + tokenizer: TokenizerConfig, ) -> Arc { let tempdir = tempfile::tempdir().unwrap(); let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap(); let cache = FileMetadataCache::with_capacity(128 * 1024 * 1024, CapacityMode::Bytes); let store = LanceIndexStore::new(ObjectStore::local(), index_dir, cache); - let params = super::InvertedIndexParams::default().with_position(with_position); + let mut params = super::InvertedIndexParams::default().with_position(with_position); + params.tokenizer_config = tokenizer; let mut invert_index = super::InvertedIndexBuilder::new(params); let doc_col = GenericStringArray::::from(vec![ "lance database the search", @@ -732,6 +761,7 @@ mod tests { "database search", "unrelated doc", "unrelated", + "mots accentués", ]); let row_id_col = UInt64Array::from(Vec::from_iter(0..doc_col.len() as u64)); let batch = RecordBatch::try_new( @@ -758,7 +788,7 @@ mod tests { } async fn test_inverted_index() { - let invert_index = create_index::(false).await; + let invert_index = create_index::(false, 
TokenizerConfig::default()).await; let row_ids = invert_index .search(&SargableQuery::FullTextSearch( FullTextSearchQuery::new("lance".to_owned()).limit(Some(3)), @@ -808,7 +838,7 @@ mod tests { assert!(results.unwrap_err().to_string().contains("position is not found but required for phrase queries, try recreating the index with position")); // recreate the index with position - let invert_index = create_index::(true).await; + let invert_index = create_index::(true, TokenizerConfig::default()).await; let row_ids = invert_index .search(&SargableQuery::FullTextSearch( FullTextSearchQuery::new("lance database".to_owned()).limit(Some(10)), @@ -865,4 +895,43 @@ mod tests { async fn test_inverted_index_with_large_string() { test_inverted_index::().await; } + + #[tokio::test] + async fn test_accented_chars() { + let invert_index = create_index::(false, TokenizerConfig::default()).await; + let row_ids = invert_index + .search(&SargableQuery::FullTextSearch( + FullTextSearchQuery::new("accentués".to_owned()).limit(Some(3)), + )) + .await + .unwrap(); + assert_eq!(row_ids.len(), Some(1)); + + let row_ids = invert_index + .search(&SargableQuery::FullTextSearch( + FullTextSearchQuery::new("accentues".to_owned()).limit(Some(3)), + )) + .await + .unwrap(); + assert_eq!(row_ids.len(), Some(0)); + + // with ascii folding enabled, the search should be accent-insensitive + let invert_index = + create_index::(true, TokenizerConfig::default().ascii_folding(true)).await; + let row_ids = invert_index + .search(&SargableQuery::FullTextSearch( + FullTextSearchQuery::new("accentués".to_owned()).limit(Some(3)), + )) + .await + .unwrap(); + assert_eq!(row_ids.len(), Some(1)); + + let row_ids = invert_index + .search(&SargableQuery::FullTextSearch( + FullTextSearchQuery::new("accentues".to_owned()).limit(Some(3)), + )) + .await + .unwrap(); + assert_eq!(row_ids.len(), Some(1)); + } } diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index 1b4a2a82af..e0b62db96c 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; use std::sync::Arc; use arrow::array::{ @@ -27,11 +28,10 @@ use lazy_static::lazy_static; use moka::future::Cache; use roaring::RoaringBitmap; use snafu::{location, Location}; -use tantivy::tokenizer::Language; use tracing::instrument; use super::builder::inverted_list_schema; -use super::{wand::*, InvertedIndexBuilder}; +use super::{wand::*, InvertedIndexBuilder, TokenizerConfig}; use crate::prefilter::{NoFilter, PreFilter}; use crate::scalar::{ AnyQuery, FullTextSearchQuery, IndexReader, IndexStore, SargableQuery, ScalarIndex, @@ -57,26 +57,30 @@ pub const K1: f32 = 1.2; pub const B: f32 = 0.75; lazy_static! 
{ - pub static ref TOKENIZER: tantivy::tokenizer::TextAnalyzer = { - tantivy::tokenizer::TextAnalyzer::builder(tantivy::tokenizer::SimpleTokenizer::default()) - .filter(tantivy::tokenizer::RemoveLongFilter::limit(40)) - .filter(tantivy::tokenizer::LowerCaser) - .filter(tantivy::tokenizer::Stemmer::new(Language::English)) - .build() - }; static ref CACHE_SIZE: usize = std::env::var("LANCE_INVERTED_CACHE_SIZE") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(512 * 1024 * 1024); } -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct InvertedIndex { + tokenizer: tantivy::tokenizer::TextAnalyzer, tokens: TokenSet, inverted_list: Arc, docs: DocSet, } +impl Debug for InvertedIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("InvertedIndex") + .field("tokens", &self.tokens) + .field("inverted_list", &self.inverted_list) + .field("docs", &self.docs) + .finish() + } +} + impl DeepSizeOf for InvertedIndex { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { self.tokens.deep_size_of_children(context) @@ -102,7 +106,8 @@ impl InvertedIndex { query: &FullTextSearchQuery, prefilter: Arc, ) -> Result> { - let tokens = collect_tokens(&query.query); + let mut tokenizer = self.tokenizer.clone(); + let tokens = collect_tokens(&query.query, &mut tokenizer); let token_ids = self.map(&tokens).into_iter(); let token_ids = if !is_phrase_query(&query.query) { token_ids.sorted_unstable().dedup().collect() @@ -239,8 +244,16 @@ impl ScalarIndex for InvertedIndex { let store = store.clone(); async move { let token_reader = store.open_index_file(TOKENS_FILE).await?; + let tokenizer = token_reader + .schema() + .metadata + .get("tokenizer") + .map(|s| serde_json::from_str::(s)) + .transpose()? + .unwrap_or_default() + .build()?; let tokens = TokenSet::load(token_reader).await?; - Result::Ok(tokens) + Result::Ok((tokenizer, tokens)) } }); let invert_list_fut = tokio::spawn({ @@ -260,11 +273,12 @@ impl ScalarIndex for InvertedIndex { } }); - let tokens = tokens_fut.await??; + let (tokenizer, tokens) = tokens_fut.await??; let inverted_list = invert_list_fut.await??; let docs = docs_fut.await??; Ok(Arc::new(Self { + tokenizer, tokens, inverted_list, docs, @@ -959,13 +973,16 @@ fn do_flat_full_text_search( query: &str, ) -> Result> { let mut results = Vec::new(); - let query_tokens = collect_tokens(query).into_iter().collect::>(); + let mut tokenizer = TokenizerConfig::default().build()?; + let query_tokens = collect_tokens(query, &mut tokenizer) + .into_iter() + .collect::>(); for batch in batches { let row_id_array = batch[ROW_ID].as_primitive::(); let doc_array = batch[doc_col].as_string::(); for i in 0..row_id_array.len() { let doc = doc_array.value(i); - let doc_tokens = collect_tokens(doc); + let doc_tokens = collect_tokens(doc, &mut tokenizer); if doc_tokens.iter().any(|token| query_tokens.contains(token)) { results.push(row_id_array.value(i)); assert!(doc.contains(query)); @@ -976,8 +993,7 @@ fn do_flat_full_text_search( Ok(results) } -pub fn collect_tokens(text: &str) -> Vec { - let mut tokenizer = TOKENIZER.clone(); +pub fn collect_tokens(text: &str, tokenizer: &mut tantivy::tokenizer::TextAnalyzer) -> Vec { let mut stream = tokenizer.token_stream(text); let mut tokens = Vec::new(); while let Some(token) = stream.next() { diff --git a/rust/lance-index/src/scalar/inverted/tokenizer.rs b/rust/lance-index/src/scalar/inverted/tokenizer.rs new file mode 100644 index 0000000000..1796091f12 --- /dev/null +++ 
b/rust/lance-index/src/scalar/inverted/tokenizer.rs
@@ -0,0 +1,149 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+use lance_core::{Error, Result};
+use serde::{Deserialize, Serialize};
+use snafu::{location, Location};
+
+/// Tokenizer configs
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TokenizerConfig {
+    /// base tokenizer:
+    /// - `simple`: splits tokens on whitespace and punctuation
+    /// - `whitespace`: splits tokens on whitespace
+    /// - `raw`: no tokenization
+    ///
+    /// `simple` is recommended for most cases and is the default value
+    base_tokenizer: String,
+
+    /// language for stemming and stop words
+    /// this is only used when `stem` or `remove_stop_words` is true
+    language: tantivy::tokenizer::Language,
+
+    /// maximum token length
+    /// - `None`: no limit
+    /// - `Some(n)`: remove tokens longer than `n`
+    max_token_length: Option<usize>,
+
+    /// whether to convert tokens to lower case
+    lower_case: bool,
+
+    /// whether to apply stemming
+    stem: bool,
+
+    /// whether to remove stop words
+    remove_stop_words: bool,
+
+    /// whether to apply ascii folding
+    ascii_folding: bool,
+}
+
+impl Default for TokenizerConfig {
+    fn default() -> Self {
+        Self::new("simple".to_owned(), tantivy::tokenizer::Language::English)
+    }
+}
+
+impl TokenizerConfig {
+    pub fn new(base_tokenizer: String, language: tantivy::tokenizer::Language) -> Self {
+        TokenizerConfig {
+            base_tokenizer,
+            language,
+            max_token_length: Some(40),
+            lower_case: true,
+            stem: false,
+            remove_stop_words: false,
+            ascii_folding: false,
+        }
+    }
+
+    pub fn base_tokenizer(mut self, base_tokenizer: String) -> Self {
+        self.base_tokenizer = base_tokenizer;
+        self
+    }
+
+    pub fn language(mut self, language: &str) -> Result<Self> {
+        // wrap in quotes so the language name parses as a valid JSON string
+        let language = serde_json::from_str(format!("\"{}\"", language).as_str())?;
+        self.language = language;
+        Ok(self)
+    }
+
+    pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
+        self.max_token_length = max_token_length;
+        self
+    }
+
+    pub fn lower_case(mut self, lower_case: bool) -> Self {
+        self.lower_case = lower_case;
+        self
+    }
+
+    pub fn stem(mut self, stem: bool) -> Self {
+        self.stem = stem;
+        self
+    }
+
+    pub fn remove_stop_words(mut self, remove_stop_words: bool) -> Self {
+        self.remove_stop_words = remove_stop_words;
+        self
+    }
+
+    pub fn ascii_folding(mut self, ascii_folding: bool) -> Self {
+        self.ascii_folding = ascii_folding;
+        self
+    }
+
+    pub fn build(&self) -> Result<tantivy::tokenizer::TextAnalyzer> {
+        let mut builder = build_base_tokenizer_builder(&self.base_tokenizer)?;
+        if let Some(max_token_length) = self.max_token_length {
+            builder = builder.filter_dynamic(tantivy::tokenizer::RemoveLongFilter::limit(
+                max_token_length,
+            ));
+        }
+        if self.lower_case {
+            builder = builder.filter_dynamic(tantivy::tokenizer::LowerCaser);
+        }
+        if self.stem {
+            builder = builder.filter_dynamic(tantivy::tokenizer::Stemmer::new(self.language));
+        }
+        if self.remove_stop_words {
+            let stop_word_filter = tantivy::tokenizer::StopWordFilter::new(self.language)
+                .ok_or_else(|| {
+                    Error::invalid_input(
+                        format!(
+                            "removing stop words for language {:?} is not supported yet",
+                            self.language
                        ),
+                        location!(),
+                    )
+                })?;
+            builder = builder.filter_dynamic(stop_word_filter);
+        }
+        if self.ascii_folding {
+            builder = builder.filter_dynamic(tantivy::tokenizer::AsciiFoldingFilter);
+        }
+        Ok(builder.build())
+    }
+}
+
+fn build_base_tokenizer_builder(name: &str) -> Result<tantivy::tokenizer::TextAnalyzerBuilder> {
+    match name {
+        "simple" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
tantivy::tokenizer::SimpleTokenizer::default(), + ) + .dynamic()), + "whitespace" => Ok(tantivy::tokenizer::TextAnalyzer::builder( + tantivy::tokenizer::WhitespaceTokenizer::default(), + ) + .dynamic()), + "raw" => Ok(tantivy::tokenizer::TextAnalyzer::builder( + tantivy::tokenizer::RawTokenizer::default(), + ) + .dynamic()), + _ => Err(Error::invalid_input( + format!("unknown base tokenizer {}", name), + location!(), + )), + } +} diff --git a/rust/lance/examples/full_text_search.rs b/rust/lance/examples/full_text_search.rs index fcbfcd61bb..00a65f5faf 100644 --- a/rust/lance/examples/full_text_search.rs +++ b/rust/lance/examples/full_text_search.rs @@ -58,9 +58,7 @@ async fn main() { let mut dataset = Dataset::write(batches, dataset_dir.as_ref(), None) .await .unwrap(); - let params = InvertedIndexParams { - with_position: true, - }; + let params = InvertedIndexParams::default(); let start = std::time::Instant::now(); dataset .create_index( diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 15eba77318..13f1c141bb 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -1393,8 +1393,9 @@ impl Dataset { &mut self, transforms: NewColumnTransform, read_columns: Option>, + batch_size: Option, ) -> Result<()> { - schema_evolution::add_columns(self, transforms, read_columns).await + schema_evolution::add_columns(self, transforms, read_columns, batch_size).await } /// Modify columns in the dataset, changing their name, type, or nullability. diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index 6d15bbbffb..5bb5fe48d7 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -46,7 +46,7 @@ use super::hash_joiner::HashJoiner; use super::rowids::load_row_id_sequence; use super::scanner::Scanner; use super::updater::Updater; -use super::WriteParams; +use super::{schema_evolution, NewColumnTransform, WriteParams}; use crate::arrow::*; use crate::dataset::Dataset; @@ -1154,7 +1154,12 @@ impl FileFragment { /// and the full schema (the target schema after the update). If the write /// schema is None, it is inferred from the first batch of results. The full /// schema is inferred by appending the write schema to the existing schema. - pub async fn updater>( + /// + /// The `batch_size` parameter can be used to influence how much data is processed + /// at a time. This can be useful to control memory usage when processing very large + /// fields. The batch_size will only be used if the dataset is a v2 dataset. It will + /// be ignored for v1 datasets. + pub(crate) async fn updater>( &self, columns: Option<&[T]>, schemas: Option<(Schema, Schema)>, @@ -1207,6 +1212,27 @@ impl FileFragment { Ok(self) } + /// Append new columns to the fragment + /// + /// This is the fragment-level version of [`Dataset::add_columns`]. + pub async fn add_columns( + &self, + transforms: NewColumnTransform, + read_columns: Option>, + batch_size: Option, + ) -> Result<(Fragment, Schema)> { + let (fragments, schema) = schema_evolution::add_columns_to_fragments( + self.dataset.as_ref(), + transforms, + read_columns, + &[self.clone()], + batch_size, + ) + .await?; + assert_eq!(fragments.len(), 1); + Ok((fragments.into_iter().next().unwrap(), schema)) + } + /// Delete rows from the fragment. /// /// If all rows are deleted, returns `Ok(None)`. 
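The `TokenizerConfig` introduced in `tokenizer.rs` above is a serde-serializable builder; the index stores it as JSON in the tokens-file metadata and rebuilds the same analyzer when the index is loaded. A hedged sketch of configuring it follows, based on the builder methods and the in-crate tests in this patch; the `lance_index::scalar::inverted` module path and the direct `tokenizer_config` field assignment are assumptions drawn from that test code.

use lance_index::scalar::inverted::TokenizerConfig;
use lance_index::scalar::InvertedIndexParams;

// Illustrative only: keep the default "simple" base tokenizer and English
// language, but enable stemming, stop-word removal, and ASCII folding so that
// accented and unaccented spellings match (as exercised by test_accented_chars).
fn make_fts_params() -> lance_core::Result<InvertedIndexParams> {
    let tokenizer = TokenizerConfig::default()
        .stem(true)
        .remove_stop_words(true)
        .ascii_folding(true);

    // Building eagerly validates the configuration, e.g. an unsupported
    // stop-word language is rejected here instead of at indexing time.
    let _analyzer: tantivy::tokenizer::TextAnalyzer = tokenizer.build()?;

    let mut params = InvertedIndexParams::default().with_position(true);
    params.tokenizer_config = tokenizer; // set the same way as in the tests above
    Ok(params)
}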
Otherwise, returns a new diff --git a/rust/lance/src/dataset/schema_evolution.rs b/rust/lance/src/dataset/schema_evolution.rs index b9c7ff5279..696bdf2f77 100644 --- a/rust/lance/src/dataset/schema_evolution.rs +++ b/rust/lance/src/dataset/schema_evolution.rs @@ -16,6 +16,7 @@ use lance_datafusion::utils::reader_to_stream; use lance_table::format::Fragment; use snafu::{location, Location}; +use super::fragment::FileFragment; use super::{ transaction::{Operation, Transaction}, Dataset, @@ -123,11 +124,13 @@ fn is_upcast_downcast(from_type: &DataType, to_type: &DataType) -> bool { } } -pub(super) async fn add_columns( - dataset: &mut Dataset, +pub(super) async fn add_columns_to_fragments( + dataset: &Dataset, transforms: NewColumnTransform, read_columns: Option>, -) -> Result<()> { + fragments: &[FileFragment], + batch_size: Option, +) -> Result<(Vec, Schema)> { // Check names early (before calling add_columns_impl) to avoid extra work if // the names are wrong. let check_names = |output_schema: &ArrowSchema| { @@ -147,9 +150,10 @@ pub(super) async fn add_columns( NewColumnTransform::BatchUDF(udf) => { check_names(udf.output_schema.as_ref())?; let fragments = add_columns_impl( - dataset, + fragments, read_columns, udf.mapper, + batch_size, udf.result_checkpoint, None, ) @@ -217,20 +221,21 @@ pub(super) async fn add_columns( let mapper = Box::new(mapper); let read_columns = Some(read_schema.field_names().into_iter().cloned().collect()); - let fragments = add_columns_impl(dataset, read_columns, mapper, None, None).await?; + let fragments = + add_columns_impl(fragments, read_columns, mapper, batch_size, None, None).await?; Ok((output_schema, fragments)) } NewColumnTransform::Stream(stream) => { let output_schema = stream.schema(); check_names(output_schema.as_ref())?; - let fragments = add_columns_from_stream(dataset, stream, None, None).await?; + let fragments = add_columns_from_stream(fragments, stream, None, batch_size).await?; Ok((output_schema, fragments)) } NewColumnTransform::Reader(reader) => { let output_schema = reader.schema(); check_names(output_schema.as_ref())?; let stream = reader_to_stream(reader); - let fragments = add_columns_from_stream(dataset, stream, None, None).await?; + let fragments = add_columns_from_stream(fragments, stream, None, batch_size).await?; Ok((output_schema, fragments)) } }?; @@ -238,6 +243,24 @@ pub(super) async fn add_columns( let mut schema = dataset.schema().merge(output_schema.as_ref())?; schema.set_field_id(Some(dataset.manifest.max_field_id())); + Ok((fragments, schema)) +} + +pub(super) async fn add_columns( + dataset: &mut Dataset, + transforms: NewColumnTransform, + read_columns: Option>, + batch_size: Option, +) -> Result<()> { + let (fragments, schema) = add_columns_to_fragments( + dataset, + transforms, + read_columns, + &dataset.get_fragments(), + batch_size, + ) + .await?; + let operation = Operation::Merge { fragments, schema }; let transaction = Transaction::new(dataset.manifest.version, operation, None); let new_manifest = commit_transaction( @@ -258,15 +281,16 @@ pub(super) async fn add_columns( #[allow(clippy::type_complexity)] async fn add_columns_impl( - dataset: &Dataset, + fragments: &[FileFragment], read_columns: Option>, mapper: Box Result + Send + Sync>, + batch_size: Option, result_cache: Option>, schemas: Option<(Schema, Schema)>, ) -> Result> { let read_columns_ref = read_columns.as_deref(); let mapper_ref = mapper.as_ref(); - let fragments = futures::stream::iter(dataset.get_fragments()) + let fragments = 
futures::stream::iter(fragments) .then(|fragment| { let cache_ref = result_cache.clone(); let schemas_ref = &schemas; @@ -280,7 +304,7 @@ async fn add_columns_impl( } let mut updater = fragment - .updater(read_columns_ref, schemas_ref.clone(), None) + .updater(read_columns_ref, schemas_ref.clone(), batch_size) .await?; let mut batch_index = 0; @@ -323,12 +347,11 @@ async fn add_columns_impl( } async fn add_columns_from_stream( - dataset: &Dataset, + fragments: &[FileFragment], mut stream: SendableRecordBatchStream, schemas: Option<(Schema, Schema)>, batch_size: Option, ) -> Result> { - let fragments = dataset.get_fragments(); let mut new_fragments = Vec::with_capacity(fragments.len()); let mut last_seen_batch: Option = None; for fragment in fragments { @@ -510,10 +533,11 @@ pub(super) async fn alter_columns( let mapper = Box::new(mapper); let fragments = add_columns_impl( - dataset, + &dataset.get_fragments(), Some(read_columns), mapper, None, + None, Some((new_col_schema, new_schema.clone())), ) .await?; @@ -658,6 +682,7 @@ mod test { let fut = dataset.add_columns( NewColumnTransform::SqlExpressions(vec![("id".into(), "id + 1".into())]), None, + None, ); // (Quick validation that the future is Send) let res = require_send(fut).await; @@ -668,6 +693,7 @@ mod test { .add_columns( NewColumnTransform::SqlExpressions(vec![("value".into(), "2 * random()".into())]), None, + None, ) .await?; @@ -676,6 +702,7 @@ mod test { .add_columns( NewColumnTransform::SqlExpressions(vec![("double_id".into(), "2 * id".into())]), None, + None, ) .await?; @@ -687,6 +714,7 @@ mod test { "id + double_id".into(), )]), None, + None, ) .await?; @@ -750,7 +778,7 @@ mod test { )])), result_checkpoint: None, }); - let res = dataset.add_columns(transforms, None).await; + let res = dataset.add_columns(transforms, None, None).await; assert!(matches!(res, Err(Error::InvalidInput { .. 
}))); // Can add a column that independent (empty read_schema) @@ -773,7 +801,7 @@ mod test { output_schema, result_checkpoint: None, }); - dataset.add_columns(transforms, None).await?; + dataset.add_columns(transforms, None, None).await?; // Can add a column that depends on another column (double id) let output_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( @@ -800,7 +828,7 @@ mod test { output_schema, result_checkpoint: None, }); - dataset.add_columns(transforms, None).await?; + dataset.add_columns(transforms, None, None).await?; // These can be read back, the dataset is valid dataset.validate().await?; @@ -927,7 +955,7 @@ mod test { output_schema, result_checkpoint: Some(request_counter.clone()), }); - dataset.add_columns(transforms, None).await?; + dataset.add_columns(transforms, None, None).await?; // Should have requested both fragments assert_eq!( @@ -1437,6 +1465,7 @@ mod test { .add_columns( NewColumnTransform::SqlExpressions(vec![("x".into(), "i + 1".into())]), Some(vec!["i".into()]), + None, ) .await?; assert_eq!(dataset.manifest.max_field_id(), 1); @@ -1448,6 +1477,7 @@ mod test { .add_columns( NewColumnTransform::SqlExpressions(vec![("y".into(), "2 * i".into())]), Some(vec!["i".into()]), + None, ) .await?; assert_eq!(dataset.manifest.max_field_id(), 1); @@ -1473,6 +1503,7 @@ mod test { ("b".into(), "i + 7".into()), ]), Some(vec!["i".into()]), + None, ) .await?; assert_eq!(dataset.manifest.max_field_id(), 2); @@ -1486,6 +1517,7 @@ mod test { .add_columns( NewColumnTransform::SqlExpressions(vec![("c".into(), "i + 11".into())]), Some(vec!["i".into()]), + None, ) .await?; assert_eq!(dataset.manifest.max_field_id(), 3); diff --git a/rust/lance/src/dataset/updater.rs b/rust/lance/src/dataset/updater.rs index d8a6ae96bb..f12b201de8 100644 --- a/rust/lance/src/dataset/updater.rs +++ b/rust/lance/src/dataset/updater.rs @@ -72,9 +72,9 @@ impl Updater { let batch_size = match (&legacy_batch_size, batch_size) { // If this is a v1 dataset we must use the row group size of the file - (Some(num_rows), _) => *num_rows, + (Some(legacy_batch_size), _) => *legacy_batch_size, // If this is a v2 dataset, let the user pick the batch size - (None, Some(legacy_batch_size)) => legacy_batch_size, + (None, Some(user_specified_batch_size)) => user_specified_batch_size, // Otherwise, default to 1024 if the user didn't specify anything (None, None) => get_default_batch_size().unwrap_or(1024) as u32, };
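The `Updater` change above picks the effective batch size by precedence: a v1 file's row-group size, then the caller-supplied `batch_size`, then a default of 1024 rows. The sketch below restates that precedence (simplified: the real code consults `get_default_batch_size()` before falling back to 1024) and shows a hypothetical call site for the new third argument to `Dataset::add_columns`; the import paths and column names are illustrative.

use lance::dataset::{Dataset, NewColumnTransform};

// Simplified restatement of the batch-size precedence used by the updater.
fn effective_batch_size(legacy_row_group_size: Option<u32>, requested: Option<u32>) -> u32 {
    match (legacy_row_group_size, requested) {
        (Some(row_group_size), _) => row_group_size, // v1 files dictate the batch size
        (None, Some(requested)) => requested,        // v2 files honor the caller
        (None, None) => 1024,                        // fallback default
    }
}

// Hypothetical call site for the new `batch_size` argument; passing `None`
// keeps the previous behavior.
async fn add_doubled_column(dataset: &mut Dataset) -> lance::Result<()> {
    dataset
        .add_columns(
            NewColumnTransform::SqlExpressions(vec![("doubled".into(), "2 * id".into())]),
            Some(vec!["id".into()]), // read_columns
            Some(4096),              // batch_size, only honored for v2 datasets
        )
        .await
}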