Skip to content

Commit

Permalink
chore: don't require torch (#3007)
Browse files (browse the repository at this point in the history)
  • Loading branch information
westonpace authored Oct 16, 2024
1 parent 19d947e commit b6e42f7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
16 changes: 13 additions & 3 deletions python/python/lance/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@
from tqdm.auto import tqdm

from . import write_dataset
from .cuvs.kmeans import KMeans as KMeansCuVS
from .dependencies import (
_CAGRA_AVAILABLE,
_RAFT_COMMON_AVAILABLE,
_check_for_numpy,
torch,
)
from .dependencies import numpy as np
from .torch.data import LanceDataset as TorchDataset
from .torch.kmeans import KMeans

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -144,6 +141,9 @@ def train_pq_codebook_on_accelerator(
) -> (np.ndarray, List[Any]):
"""Use accelerator (GPU or MPS) to train pq codebook."""

from .torch.data import LanceDataset as TorchDataset
from .torch.kmeans import KMeans

# cuvs not particularly useful for only 256 centroids without more work
if accelerator == "cuvs":
accelerator = "cuda"
Expand Down Expand Up @@ -212,6 +212,11 @@ def train_ivf_centroids_on_accelerator(
filter_nan: bool = True,
) -> (np.ndarray, Any):
"""Use accelerator (GPU or MPS) to train kmeans."""

from .cuvs.kmeans import KMeans as KMeansCuVS
from .torch.data import LanceDataset as TorchDataset
from .torch.kmeans import KMeans

if isinstance(accelerator, str) and (
not (
CUDA_REGEX.match(accelerator)
Expand Down Expand Up @@ -321,6 +326,7 @@ def compute_pq_codes(
str
The absolute path of the pq codes dataset.
"""
from .torch.data import LanceDataset as TorchDataset

torch.backends.cuda.matmul.allow_tf32 = allow_cuda_tf32

Expand Down Expand Up @@ -451,6 +457,8 @@ def compute_partitions(
str
The absolute path of the partition dataset.
"""
from .torch.data import LanceDataset as TorchDataset

torch.backends.cuda.matmul.allow_tf32 = allow_cuda_tf32

num_rows = dataset.count_rows()
Expand Down Expand Up @@ -640,6 +648,8 @@ def one_pass_assign_ivf_pq_on_accelerator(
str
The absolute path of the ivfpq codes dataset, as precomputed partition buffers.
"""
from .torch.data import LanceDataset as TorchDataset

torch.backends.cuda.matmul.allow_tf32 = allow_cuda_tf32

num_rows = dataset.count_rows()
Expand Down
4 changes: 4 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,7 @@ def test_scan_with_row_ids(tmp_path: Path):
assert tbl2["a"] == tbl["a"]


@pytest.mark.cuda
def test_random_dataset_recall_accelerated(tmp_path: Path):
dims = 32
schema = pa.schema([pa.field("a", pa.list_(pa.float32(), dims), False)])
Expand All @@ -2207,6 +2208,7 @@ def test_random_dataset_recall_accelerated(tmp_path: Path):
validate_vector_index(dataset, "a", pass_threshold=0.5)


@pytest.mark.cuda
def test_random_dataset_recall_accelerated_one_pass(tmp_path: Path):
dims = 32
schema = pa.schema([pa.field("a", pa.list_(pa.float32(), dims), False)])
Expand All @@ -2233,6 +2235,7 @@ def test_random_dataset_recall_accelerated_one_pass(tmp_path: Path):
validate_vector_index(dataset, "a", pass_threshold=0.5)


@pytest.mark.cuda
def test_count_index_rows_accelerated(tmp_path: Path):
dims = 32
schema = pa.schema([pa.field("a", pa.list_(pa.float32(), dims), False)])
Expand Down Expand Up @@ -2277,6 +2280,7 @@ def test_count_index_rows_accelerated(tmp_path: Path):
assert dataset.stats.index_stats(index_name)["num_indexed_rows"] == 512


@pytest.mark.cuda
def test_count_index_rows_accelerated_one_pass(tmp_path: Path):
dims = 32
schema = pa.schema([pa.field("a", pa.list_(pa.float32(), dims), False)])
Expand Down
4 changes: 3 additions & 1 deletion python/python/tests/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pathlib import Path

import lance
import lance.torch.data
import numpy as np
import pytest

Expand All @@ -26,7 +25,10 @@ def test_write_hf_dataset(tmp_path: Path):
assert ds.schema == hf_ds.features.arrow_schema


@pytest.mark.cuda
def test_image_hf_dataset(tmp_path: Path):
import lance.torch.data

ds = datasets.Dataset.from_dict(
{"i": [np.zeros(shape=(16, 16, 3), dtype=np.uint8)]},
features=datasets.Features({"i": datasets.Image()}),
Expand Down

0 comments on commit b6e42f7

Please sign in to comment.