diff --git a/python/python/lance/vector.py b/python/python/lance/vector.py index 7e3e38151b..51b665499b 100644 --- a/python/python/lance/vector.py +++ b/python/python/lance/vector.py @@ -14,7 +14,6 @@ from tqdm.auto import tqdm from . import write_dataset -from .cuvs.kmeans import KMeans as KMeansCuVS from .dependencies import ( _CAGRA_AVAILABLE, _RAFT_COMMON_AVAILABLE, @@ -22,8 +21,6 @@ torch, ) from .dependencies import numpy as np -from .torch.data import LanceDataset as TorchDataset -from .torch.kmeans import KMeans if TYPE_CHECKING: from pathlib import Path @@ -144,6 +141,9 @@ def train_pq_codebook_on_accelerator( ) -> (np.ndarray, List[Any]): """Use accelerator (GPU or MPS) to train pq codebook.""" + from .torch.data import LanceDataset as TorchDataset + from .torch.kmeans import KMeans + # cuvs not particularly useful for only 256 centroids without more work if accelerator == "cuvs": accelerator = "cuda" @@ -212,6 +212,11 @@ def train_ivf_centroids_on_accelerator( filter_nan: bool = True, ) -> (np.ndarray, Any): """Use accelerator (GPU or MPS) to train kmeans.""" + + from .cuvs.kmeans import KMeans as KMeansCuVS + from .torch.data import LanceDataset as TorchDataset + from .torch.kmeans import KMeans + if isinstance(accelerator, str) and ( not ( CUDA_REGEX.match(accelerator) @@ -321,6 +326,7 @@ def compute_pq_codes( str The absolute path of the pq codes dataset. """ + from .torch.data import LanceDataset as TorchDataset torch.backends.cuda.matmul.allow_tf32 = allow_cuda_tf32 @@ -451,6 +457,8 @@ def compute_partitions( str The absolute path of the partition dataset. """ + from .torch.data import LanceDataset as TorchDataset + torch.backends.cuda.matmul.allow_tf32 = allow_cuda_tf32 num_rows = dataset.count_rows() @@ -640,6 +648,8 @@ def one_pass_assign_ivf_pq_on_accelerator( str The absolute path of the ivfpq codes dataset, as precomputed partition buffers. """ + from .torch.data import LanceDataset as TorchDataset + torch.backends.cuda.matmul.allow_tf32 = allow_cuda_tf32 num_rows = dataset.count_rows() diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index ac49b6f90f..e5e77ba314 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -2182,6 +2182,7 @@ def test_scan_with_row_ids(tmp_path: Path): assert tbl2["a"] == tbl["a"] +@pytest.mark.cuda def test_random_dataset_recall_accelerated(tmp_path: Path): dims = 32 schema = pa.schema([pa.field("a", pa.list_(pa.float32(), dims), False)]) @@ -2207,6 +2208,7 @@ def test_random_dataset_recall_accelerated(tmp_path: Path): validate_vector_index(dataset, "a", pass_threshold=0.5) +@pytest.mark.cuda def test_random_dataset_recall_accelerated_one_pass(tmp_path: Path): dims = 32 schema = pa.schema([pa.field("a", pa.list_(pa.float32(), dims), False)]) @@ -2233,6 +2235,7 @@ def test_random_dataset_recall_accelerated_one_pass(tmp_path: Path): validate_vector_index(dataset, "a", pass_threshold=0.5) +@pytest.mark.cuda def test_count_index_rows_accelerated(tmp_path: Path): dims = 32 schema = pa.schema([pa.field("a", pa.list_(pa.float32(), dims), False)]) @@ -2277,6 +2280,7 @@ def test_count_index_rows_accelerated(tmp_path: Path): assert dataset.stats.index_stats(index_name)["num_indexed_rows"] == 512 +@pytest.mark.cuda def test_count_index_rows_accelerated_one_pass(tmp_path: Path): dims = 32 schema = pa.schema([pa.field("a", pa.list_(pa.float32(), dims), False)]) diff --git a/python/python/tests/test_huggingface.py b/python/python/tests/test_huggingface.py index 6aa39aa21d..0bfeb2daae 100644 --- a/python/python/tests/test_huggingface.py +++ b/python/python/tests/test_huggingface.py @@ -4,7 +4,6 @@ from pathlib import Path import lance -import lance.torch.data import numpy as np import pytest @@ -26,7 +25,10 @@ def test_write_hf_dataset(tmp_path: Path): assert ds.schema == hf_ds.features.arrow_schema +@pytest.mark.cuda def test_image_hf_dataset(tmp_path: Path): + import lance.torch.data + ds = datasets.Dataset.from_dict( {"i": [np.zeros(shape=(16, 16, 3), dtype=np.uint8)]}, features=datasets.Features({"i": datasets.Image()}),