feat: support create IVF_HNSW_PQ index in Python (#2127)
Signed-off-by: BubbleCal <[email protected]>
BubbleCal authored Mar 30, 2024
1 parent cf9c5c5 commit da1d236
Showing 3 changed files with 168 additions and 110 deletions.
27 changes: 17 additions & 10 deletions python/python/lance/dataset.py
@@ -1282,11 +1282,11 @@ def create_index(
- **max_opq_iterations**: the maximum number of iterations for training OPQ.
- **ivf_centroids**: K-mean centroids for IVF clustering.
If ``index_type`` is "DISKANN", then the following parameters are optional:
- **r**: out-degree bound
- **l**: number of levels in the graph.
- **alpha**: distance threshold for the graph.
Optional parameters for "IVF_HNSW_PQ":
- **max_level**: the maximum number of levels in the graph.
- **m**: the number of edges per node in the graph.
- **m_max**: the maximum number of edges per node in the graph.
- **ef_construction**: the number of candidate nodes to examine during graph construction.
Examples
--------
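A minimal usage sketch of the new index type, assuming a dataset with a 128-dimensional "vector" column; the path and all parameter values below are illustrative, not taken from this commit:

```python
import lance

# Hypothetical dataset location; any dataset with a fixed-size-list
# vector column works the same way.
ds = lance.dataset("/tmp/vectors.lance")
ds.create_index(
    "vector",
    index_type="IVF_HNSW_PQ",
    metric="l2",
    num_partitions=256,    # IVF: number of coarse partitions
    num_sub_vectors=16,    # PQ: sub-vectors per code
    max_level=7,           # HNSW: maximum number of graph levels
    m=20,                  # HNSW: edges per node
    m_max=40,              # HNSW: upper bound on edges per node
    ef_construction=100,   # HNSW: candidates examined while building
)
```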
@@ -1369,12 +1369,13 @@ def create_index(
kwargs["metric_type"] = metric

index_type = index_type.upper()
if index_type not in ["IVF_PQ", "DISKANN"]:
if index_type not in ["IVF_PQ", "IVF_HNSW_PQ"]:
raise NotImplementedError(
f"Only IVF_PQ or DiskANN index_types supported. Got {index_type}"
f"Only [IVF_PQ, IVF_HNSW_PQ] index types supported. "
f"Got {index_type}"
)
if index_type == "IVF_PQ":
if num_partitions is None or num_sub_vectors is None:
if index_type.startswith("IVF"):
if num_partitions is None:
raise ValueError(
"num_partitions and num_sub_vectors are required for IVF_PQ"
)
@@ -1386,7 +1387,6 @@
f"num_partitions must be int, got {type(num_partitions)}"
)
kwargs["num_partitions"] = num_partitions
kwargs["num_sub_vectors"] = num_sub_vectors

if accelerator is not None and ivf_centroids is None:
# Use accelerator to train ivf centroids
@@ -1433,6 +1433,13 @@
)
kwargs["ivf_centroids"] = ivf_centroids_batch

if "PQ" in index_type:
if num_sub_vectors is None:
raise ValueError(
"num_partitions and num_sub_vectors are required for IVF_PQ"
)
kwargs["num_sub_vectors"] = num_sub_vectors

if pq_codebook is not None:
# User provided PQ codebook
if _check_for_numpy(pq_codebook) and isinstance(
12 changes: 12 additions & 0 deletions python/python/tests/test_vector_index.py
@@ -343,6 +343,18 @@ def test_create_dot_index(dataset, tmp_path):
assert ann_ds.has_index


def test_create_ivf_hnsw_pq_index(dataset, tmp_path):
assert not dataset.has_index
ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance")
ann_ds = ann_ds.create_index(
"vector",
index_type="IVF_HNSW_PQ",
num_partitions=4,
num_sub_vectors=16,
)
assert ann_ds.list_indices()[0]["fields"] == ["vector"]
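Beyond checking that the index exists, a hedged sketch of actually exercising it with a nearest-neighbor query; the query vector and k are illustrative, `to_table(nearest=...)` is Lance's existing ANN search entry point, and 128 matches the vector dimension used by fixtures in this file:

```python
import numpy as np

# Query the freshly built IVF_HNSW_PQ index.
q = np.random.randn(128).astype(np.float32)
tbl = ann_ds.to_table(nearest={"column": "vector", "q": q, "k": 10})
assert len(tbl) > 0
```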


def test_pre_populated_ivf_centroids(dataset, tmp_path: Path):
centroids = np.random.randn(5, 128).astype(np.float32) # IVF5
dataset_with_index = dataset.create_index(
239 changes: 139 additions & 100 deletions python/src/dataset.rs
@@ -38,6 +38,7 @@ use lance::index::{scalar::ScalarIndexParams, vector::VectorIndexParams};
use lance_arrow::as_fixed_size_list_array;
use lance_core::datatypes::Schema;
use lance_index::optimize::OptimizeOptions;
use lance_index::vector::hnsw::builder::HnswBuildParams;
use lance_index::{
vector::{ivf::IvfBuildParams, pq::PQBuildParams},
DatasetIndexExt, IndexParams, IndexType,
@@ -873,9 +874,10 @@ impl Dataset {
replace: Option<bool>,
kwargs: Option<&PyDict>,
) -> PyResult<()> {
let idx_type = match index_type.to_uppercase().as_str() {
let index_type = index_type.to_uppercase();
let idx_type = match index_type.as_str() {
"BTREE" => IndexType::Scalar,
"IVF_PQ" => IndexType::Vector,
"IVF_PQ" | "IVF_HNSW_PQ" => IndexType::Vector,
_ => {
return Err(PyValueError::new_err(format!(
"Index type '{index_type}' is not supported."
@@ -884,104 +886,10 @@
};

// BTREE takes scalar index params; every other supported type is a vector index.
let params: Box<dyn IndexParams> = match index_type.to_uppercase().as_str() {
"BTREE" => Box::<ScalarIndexParams>::default(),
"IVF_PQ" => {
let mut ivf_params = IvfBuildParams::default();
let mut pq_params = PQBuildParams::default();
let mut m_type = MetricType::L2;
if let Some(kwargs) = kwargs {
if let Some(mt) = kwargs.get_item("metric_type")? {
m_type = MetricType::try_from(mt.to_string().to_lowercase().as_str())
.map_err(|err| PyValueError::new_err(err.to_string()))?;
}

if let Some(n) = kwargs.get_item("num_partitions")? {
ivf_params.num_partitions = PyAny::downcast::<PyInt>(n)?.extract()?
};

if let Some(n) = kwargs.get_item("num_bits")? {
pq_params.num_bits = PyAny::downcast::<PyInt>(n)?.extract()?
};

if let Some(n) = kwargs.get_item("num_sub_vectors")? {
pq_params.num_sub_vectors = PyAny::downcast::<PyInt>(n)?.extract()?
};

if let Some(o) = kwargs.get_item("use_opq")? {
#[cfg(not(feature = "opq"))]
if PyAny::downcast::<PyBool>(o)?.extract()? {
return Err(PyValueError::new_err(
"Feature 'opq' is not installed.".to_string(),
));
}
pq_params.use_opq = PyAny::downcast::<PyBool>(o)?.extract()?
};

if let Some(c) = kwargs.get_item("pq_codebook")? {
let batch = RecordBatch::from_pyarrow(c)?;
if "_pq_codebook" != batch.schema().field(0).name() {
return Err(PyValueError::new_err(
"Expected '_pq_codebook' as the first column name.",
));
}
let codebook = as_fixed_size_list_array(batch.column(0));
pq_params.codebook = Some(codebook.values().clone())
};

if let Some(o) = kwargs.get_item("max_opq_iterations")? {
pq_params.max_opq_iters = PyAny::downcast::<PyInt>(o)?.extract()?
};

if let Some(c) = kwargs.get_item("ivf_centroids")? {
let batch = RecordBatch::from_pyarrow(c)?;
if "_ivf_centroids" != batch.schema().field(0).name() {
return Err(PyValueError::new_err(
"Expected '_ivf_centroids' as the first column name.",
));
}
let centroids = as_fixed_size_list_array(batch.column(0));
ivf_params.centroids = Some(Arc::new(centroids.clone()))
};

if let Some(f) = kwargs.get_item("precomputed_partitions_file")? {
ivf_params.precomputed_partitons_file = Some(f.to_string());
};

match (
kwargs.get_item("precomputed_shuffle_buffers")?,
kwargs.get_item("precomputed_shuffle_buffers_path")?
) {
(Some(l), Some(p)) => {
let path = Path::parse(p.to_string()).map_err(|e| {
PyValueError::new_err(format!(
"Failed to parse precomputed_shuffle_buffers_path: {}",
e
))
})?;
let list = PyAny::downcast::<PyList>(l)?
.iter()
.map(|f| f.to_string())
.collect();
ivf_params.precomputed_shuffle_buffers = Some((path, list));
},
(None, None) => {},
_ => {
return Err(PyValueError::new_err(
"precomputed_shuffle_buffers and precomputed_shuffle_buffers_path must be specified together."
))
}
}
}
Box::new(VectorIndexParams::with_ivf_pq_params(
m_type, ivf_params, pq_params,
))
}
_ => {
return Err(PyValueError::new_err(format!(
"Index type '{index_type}' is not supported."
)))
}
let params: Box<dyn IndexParams> = if index_type == "BTREE" {
Box::<ScalarIndexParams>::default()
} else {
prepare_vector_index_params(&index_type, kwargs)?
};

let replace = replace.unwrap_or(true);
@@ -1217,6 +1125,137 @@ pub fn get_write_params(options: &PyDict) -> PyResult<Option<WriteParams>> {
Ok(params)
}

fn prepare_vector_index_params(
index_type: &str,
kwargs: Option<&PyDict>,
) -> PyResult<Box<dyn IndexParams>> {
let mut m_type = MetricType::L2;
let mut ivf_params = IvfBuildParams::default();
let mut pq_params = PQBuildParams::default();
let mut hnsw_params = HnswBuildParams::default();

if let Some(kwargs) = kwargs {
// Parse metric type
if let Some(mt) = kwargs.get_item("metric_type")? {
m_type = MetricType::try_from(mt.to_string().to_lowercase().as_str())
.map_err(|err| PyValueError::new_err(err.to_string()))?;
}

// Parse IVF params
if let Some(n) = kwargs.get_item("num_partitions")? {
ivf_params.num_partitions = PyAny::downcast::<PyInt>(n)?.extract()?
};

if let Some(c) = kwargs.get_item("ivf_centroids")? {
let batch = RecordBatch::from_pyarrow(c)?;
if "_ivf_centroids" != batch.schema().field(0).name() {
return Err(PyValueError::new_err(
"Expected '_ivf_centroids' as the first column name.",
));
}
let centroids = as_fixed_size_list_array(batch.column(0));
ivf_params.centroids = Some(Arc::new(centroids.clone()))
};

if let Some(f) = kwargs.get_item("precomputed_partitions_file")? {
ivf_params.precomputed_partitons_file = Some(f.to_string());
};

match (
kwargs.get_item("precomputed_shuffle_buffers")?,
kwargs.get_item("precomputed_shuffle_buffers_path")?
) {
(Some(l), Some(p)) => {
let path = Path::parse(p.to_string()).map_err(|e| {
PyValueError::new_err(format!(
"Failed to parse precomputed_shuffle_buffers_path: {}",
e
))
})?;
let list = PyAny::downcast::<PyList>(l)?
.iter()
.map(|f| f.to_string())
.collect();
ivf_params.precomputed_shuffle_buffers = Some((path, list));
},
(None, None) => {},
_ => {
return Err(PyValueError::new_err(
"precomputed_shuffle_buffers and precomputed_shuffle_buffers_path must be specified together."
))
}
}

// Parse PQ params
if let Some(n) = kwargs.get_item("num_bits")? {
pq_params.num_bits = PyAny::downcast::<PyInt>(n)?.extract()?
};

if let Some(n) = kwargs.get_item("num_sub_vectors")? {
pq_params.num_sub_vectors = PyAny::downcast::<PyInt>(n)?.extract()?
};

if let Some(o) = kwargs.get_item("use_opq")? {
#[cfg(not(feature = "opq"))]
if PyAny::downcast::<PyBool>(o)?.extract()? {
return Err(PyValueError::new_err(
"Feature 'opq' is not installed.".to_string(),
));
}
pq_params.use_opq = PyAny::downcast::<PyBool>(o)?.extract()?
};

if let Some(c) = kwargs.get_item("pq_codebook")? {
let batch = RecordBatch::from_pyarrow(c)?;
if "_pq_codebook" != batch.schema().field(0).name() {
return Err(PyValueError::new_err(
"Expected '_pq_codebook' as the first column name.",
));
}
let codebook = as_fixed_size_list_array(batch.column(0));
pq_params.codebook = Some(codebook.values().clone())
};

if let Some(o) = kwargs.get_item("max_opq_iterations")? {
pq_params.max_opq_iters = PyAny::downcast::<PyInt>(o)?.extract()?
};

// Parse HNSW params
if let Some(max_level) = kwargs.get_item("max_level")? {
hnsw_params.max_level = PyAny::downcast::<PyInt>(max_level)?.extract()?;
}

if let Some(m) = kwargs.get_item("m")? {
hnsw_params.m = PyAny::downcast::<PyInt>(m)?.extract()?;
}

if let Some(m_max) = kwargs.get_item("m_max")? {
hnsw_params.m_max = PyAny::downcast::<PyInt>(m_max)?.extract()?;
}

if let Some(ef_c) = kwargs.get_item("ef_construction")? {
hnsw_params.ef_construction = PyAny::downcast::<PyInt>(ef_c)?.extract()?;
}
}

match index_type {
"IVF_PQ" => Ok(Box::new(VectorIndexParams::with_ivf_pq_params(
m_type, ivf_params, pq_params,
))),

"IVF_HNSW_PQ" => Ok(Box::new(VectorIndexParams::with_ivf_hnsw_pq_params(
m_type,
ivf_params,
hnsw_params,
pq_params,
))),

_ => Err(PyValueError::new_err(format!(
"Index type '{index_type}' is not supported."
))),
}
}
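Seen from the Python side, everything the parser above consumes arrives as one flat kwargs dict. A hedged summary with illustrative values; commented entries indicate the expected shape rather than working inputs:

```python
kwargs = {
    "metric_type": "cosine",  # parsed into MetricType
    # IVF
    "num_partitions": 256,
    # "ivf_centroids": RecordBatch whose first column is "_ivf_centroids",
    # "precomputed_partitions_file": "<path>",
    # PQ
    "num_bits": 8,
    "num_sub_vectors": 16,
    # "pq_codebook": RecordBatch whose first column is "_pq_codebook",
    # HNSW (consulted only when index_type == "IVF_HNSW_PQ")
    "max_level": 7,
    "m": 20,
    "m_max": 40,
    "ef_construction": 100,
}
```

Note that `precomputed_shuffle_buffers` and `precomputed_shuffle_buffers_path` must be supplied together, as the match arm above enforces.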

#[pyclass(name = "_FragmentWriteProgress", module = "_lib")]
#[derive(Debug)]
pub struct PyWriteProgress {
