Skip to content

Commit

Permalink
feat: one-pass IVF_PQ accelerated builds (#3001)
Browse files Browse the repository at this point in the history
This feature improves disk IO dependence, but it is quite limited. This
only works if the index type is IVF_PQ, and it will not work efficiently
for local PQ in the future (unless we store _all_ the PQ models in
VRAM).
Importantly, this allows us to bypass local temp storage for storing
residuals. However, this still stores PQ codes locally temporarily due
to how we've implemented accelerator support, but these are much smaller
(exact ratio depends on params).

I tested on my local machine, which is sufficiently fast that the
accelerated builds are mostly IO limited (but IO is also fast). I used
wikipedia-40M

New feature disabled:

![results_static_20241011_224535_plot_dataset_wikipedia-few-queries_k_10](https://github.com/user-attachments/assets/9a9285e1-1814-4215-a4c9-2a3f3a16c874)
ivf training time: 52s
ivf transform time: 89s
pq training time: 18s
pq assignment time: 143s
create_index rust time: 8.9s

New feature enabled:

![results_static_20241011_203303_plot_dataset_wikipedia-few-queries_k_10](https://github.com/user-attachments/assets/9d94f50b-e3b6-42f8-8357-3cb477e6279b)
combined training time: 63.7s (not actually sure why this is faster, but
it's not the big part anyway)
combined transform time: 158.8s
create_index rust time: 8.6s

Improvement should be more noticeable for bigger datasets, as usual.
  • Loading branch information
jacketsj authored Oct 14, 2024
1 parent cdac5de commit d207aa8
Show file tree
Hide file tree
Showing 4 changed files with 455 additions and 16 deletions.
52 changes: 52 additions & 0 deletions python/python/benchmarks/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,41 @@ def test_create_ivf_pq(test_dataset, benchmark):
)


@pytest.mark.benchmark(group="create_index")
def test_create_ivf_pq_torch_cpu(test_dataset, benchmark):
from lance.dependencies import torch

benchmark(
test_dataset.create_index,
column="vector",
index_type="IVF_PQ",
metric_type="L2",
num_partitions=8,
num_sub_vectors=2,
num_bits=8,
replace=True,
accelerator=torch.device("cpu"),
)


@pytest.mark.benchmark(group="create_index")
def test_create_ivf_pq_torch_cpu_one_pass(test_dataset, benchmark):
from lance.dependencies import torch

benchmark(
test_dataset.create_index,
column="vector",
index_type="IVF_PQ",
metric_type="L2",
num_partitions=8,
num_sub_vectors=2,
num_bits=8,
replace=True,
accelerator=torch.device("cpu"),
one_pass_ivfpq=True,
)


@pytest.mark.benchmark(group="create_index")
@pytest.mark.cuda
def test_create_ivf_pq_cuda(test_dataset, benchmark):
Expand All @@ -70,6 +105,23 @@ def test_create_ivf_pq_cuda(test_dataset, benchmark):
)


@pytest.mark.benchmark(group="create_index")
@pytest.mark.cuda
def test_create_ivf_pq_cuda_one_pass(test_dataset, benchmark):
benchmark(
test_dataset.create_index,
column="vector",
index_type="IVF_PQ",
metric_type="L2",
num_partitions=8,
num_sub_vectors=2,
num_bits=8,
accelerator="cuda",
replace=True,
one_pass_ivfpq=True,
)


@pytest.mark.benchmark(group="optimize_index")
@pytest.mark.parametrize("num_partitions", [256, 512])
@pytest.mark.parametrize("num_small_indexes", [5])
Expand Down
71 changes: 63 additions & 8 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,7 @@ def create_index(
precomputed_partition_dataset: Optional[str] = None,
storage_options: Optional[Dict[str, str]] = None,
filter_nan: bool = True,
one_pass_ivfpq: bool = False,
**kwargs,
) -> LanceDataset:
"""Create index on column.
Expand Down Expand Up @@ -1508,6 +1509,8 @@ def create_index(
Defaults to True. False is UNSAFE, and will cause a crash if any null/nan
values are present (and otherwise will not). Disables the null filter used
for nullable columns. Obtains a small speed boost.
one_pass_ivfpq: bool
Defaults to False. If enabled, index type must be "IVF_PQ". Reduces disk IO.
kwargs :
Parameters passed to the index building process.
Expand Down Expand Up @@ -1631,6 +1634,58 @@ def create_index(
raise NotImplementedError(
f"Only {valid_index_types} index types supported. " f"Got {index_type}"
)
if index_type != "IVF_PQ" and one_pass_ivfpq:
raise ValueError(
f'one_pass_ivfpq requires index_type="IVF_PQ", got {index_type}'
)

# Handle timing for various parts of accelerated builds
timers = {}
if one_pass_ivfpq and accelerator is not None:
from .vector import (
one_pass_assign_ivf_pq_on_accelerator,
one_pass_train_ivf_pq_on_accelerator,
)

logging.info("Doing one-pass ivfpq accelerated computations")

timers["ivf+pq_train:start"] = time.time()
ivf_centroids, ivf_kmeans, pq_codebook, pq_kmeans_list = (
one_pass_train_ivf_pq_on_accelerator(
self,
column[0],
num_partitions,
metric,
accelerator,
num_sub_vectors=num_sub_vectors,
batch_size=20480,
filter_nan=filter_nan,
)
)
timers["ivf+pq_train:end"] = time.time()
ivfpq_train_time = timers["ivf+pq_train:end"] - timers["ivf+pq_train:start"]
logging.info("ivf+pq training time: %ss", ivfpq_train_time)
timers["ivf+pq_assign:start"] = time.time()
shuffle_output_dir, shuffle_buffers = one_pass_assign_ivf_pq_on_accelerator(
self,
column[0],
metric,
accelerator,
ivf_kmeans,
pq_kmeans_list,
batch_size=20480,
filter_nan=filter_nan,
)
timers["ivf+pq_assign:end"] = time.time()
ivfpq_assign_time = (
timers["ivf+pq_assign:end"] - timers["ivf+pq_assign:start"]
)
logging.info("ivf+pq transform time: %ss", ivfpq_assign_time)

kwargs["precomputed_shuffle_buffers"] = shuffle_buffers
kwargs["precomputed_shuffle_buffers_path"] = os.path.join(
shuffle_output_dir, "data"
)
if index_type.startswith("IVF"):
if (ivf_centroids is not None) and (ivf_centroids_file is not None):
raise ValueError(
Expand Down Expand Up @@ -1659,9 +1714,6 @@ def create_index(
)
kwargs["num_partitions"] = num_partitions

# Handle timing for various parts of accelerated builds
timers = {}

if (precomputed_partition_dataset is not None) and (ivf_centroids is None):
raise ValueError(
"ivf_centroids must be provided when"
Expand Down Expand Up @@ -1692,7 +1744,7 @@ def create_index(
)
kwargs["precomputed_partitions_file"] = precomputed_partition_dataset

if accelerator is not None and ivf_centroids is None:
if accelerator is not None and ivf_centroids is None and not one_pass_ivfpq:
logging.info("Computing new precomputed partition dataset")
# Use accelerator to train ivf centroids
from .vector import (
Expand Down Expand Up @@ -1773,6 +1825,7 @@ def create_index(
pq_codebook is None
and accelerator is not None
and "precomputed_partitions_file" in kwargs
and not one_pass_ivfpq
):
logging.info("Computing new precomputed shuffle buffers for PQ.")
partitions_file = kwargs["precomputed_partitions_file"]
Expand Down Expand Up @@ -1852,13 +1905,15 @@ def create_index(
if shuffle_partition_concurrency is not None:
kwargs["shuffle_partition_concurrency"] = shuffle_partition_concurrency

times = []
times.append(time.time())
timers["final_create_index:start"] = time.time()
self._ds.create_index(
column, index_type, name, replace, storage_options, kwargs
)
times.append(time.time())
logging.info("Final create_index time: %ss", times[1] - times[0])
timers["final_create_index:end"] = time.time()
final_create_index_time = (
timers["final_create_index:end"] - timers["final_create_index:start"]
)
logging.info("Final create_index rust time: %ss", final_create_index_time)
# Save disk space
if "precomputed_shuffle_buffers_path" in kwargs.keys() and os.path.exists(
kwargs["precomputed_shuffle_buffers_path"]
Expand Down
Loading

0 comments on commit d207aa8

Please sign in to comment.