diff --git a/.github/workflows/rust-benchmark.yml b/.github/workflows/rust-benchmark.yml index dfcc14b753..4eb0cbd21d 100644 --- a/.github/workflows/rust-benchmark.yml +++ b/.github/workflows/rust-benchmark.yml @@ -38,7 +38,7 @@ jobs: working-directory: ./rust/lance-linalg run: | # TODO: a few benchmarks are failing. Re-enable everything once they are fixed. - cargo bench --features "fp16kernels" --bench l2 --bench cosine --bench dot --bench kmeans --bench norm_l2 -- --output-format bencher | tee -a ../../output.txt + cargo bench --features "fp16kernels" -- --output-format bencher | tee -a ../../output.txt - name: Run index benchmarks working-directory: ./rust/lance-index run: | diff --git a/rust/lance-index/src/vector/pq.rs b/rust/lance-index/src/vector/pq.rs index a2132caf0e..6ecb678d27 100644 --- a/rust/lance-index/src/vector/pq.rs +++ b/rust/lance-index/src/vector/pq.rs @@ -13,7 +13,7 @@ use arrow_schema::DataType; use deepsize::DeepSizeOf; use lance_arrow::*; use lance_core::{Error, Result}; -use lance_linalg::distance::{dot_distance_batch, DistanceType, Dot, L2}; +use lance_linalg::distance::{dot_distance_batch, Cosine, DistanceType, Dot, L2}; use lance_linalg::kmeans::compute_partition; use num_traits::Float; use prost::Message; @@ -96,7 +96,7 @@ impl ProductQuantizer { #[instrument(name = "ProductQuantizer::transform", level = "debug", skip_all)] fn transform(&self, vectors: &dyn Array) -> Result where - T::Native: Float + L2 + Dot, + T::Native: Float + L2 + Dot + Cosine, { let fsl = vectors.as_fixed_size_list_opt().ok_or(Error::Index { message: format!( diff --git a/rust/lance-index/src/vector/residual.rs b/rust/lance-index/src/vector/residual.rs index 009415e00e..8652addde1 100644 --- a/rust/lance-index/src/vector/residual.rs +++ b/rust/lance-index/src/vector/residual.rs @@ -11,7 +11,7 @@ use arrow_array::{ use arrow_schema::DataType; use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt}; use lance_core::{Error, Result}; -use lance_linalg::distance::{DistanceType, Dot, L2}; +use lance_linalg::distance::{Cosine, DistanceType, Dot, L2}; use lance_linalg::kmeans::compute_partitions; use num_traits::Float; use snafu::{location, Location}; @@ -60,7 +60,7 @@ fn do_compute_residual( partitions: Option<&UInt32Array>, ) -> Result where - T::Native: Float + L2 + Dot, + T::Native: Float + L2 + Dot + Cosine, { let dimension = centroids.value_length() as usize; let centroids_slice = centroids.values().as_primitive::().values(); diff --git a/rust/lance-linalg/benches/compute_partition.rs b/rust/lance-linalg/benches/compute_partition.rs index 7b155a9aa5..e3943348c4 100644 --- a/rust/lance-linalg/benches/compute_partition.rs +++ b/rust/lance-linalg/benches/compute_partition.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use arrow_array::types::Float32Type; use criterion::{criterion_group, criterion_main, Criterion}; -use lance_linalg::{distance::MetricType, kmeans::compute_partitions}; +use lance_linalg::{distance::DistanceType, kmeans::compute_partitions}; use lance_testing::datagen::generate_random_array_with_seed; #[cfg(target_os = "linux")] use pprof::criterion::{Output, PProfProfiler}; @@ -28,7 +28,7 @@ fn bench_compute_partitions(c: &mut Criterion) { centroids.values(), input.values(), DIMENSION, - MetricType::L2, + DistanceType::L2, ) }) }); @@ -39,7 +39,18 @@ fn bench_compute_partitions(c: &mut Criterion) { centroids.values(), input.values(), DIMENSION, - MetricType::Cosine, + DistanceType::Cosine, + ) + }) + }); + + c.bench_function("compute_centroids(Dot)", |b| { + b.iter(|| { + compute_partitions( + centroids.values(), + input.values(), + DIMENSION, + DistanceType::Dot, ) }) }); diff --git a/rust/lance-linalg/src/kmeans.rs b/rust/lance-linalg/src/kmeans.rs index 57c8f16839..17eb0161f5 100644 --- a/rust/lance-linalg/src/kmeans.rs +++ b/rust/lance-linalg/src/kmeans.rs @@ -29,7 +29,7 @@ use rand::prelude::*; use rayon::prelude::*; use crate::distance::hamming::hamming; -use crate::distance::{dot_distance_batch, DistanceType}; +use crate::distance::{cosine_distance_batch, dot_distance_batch, Cosine, DistanceType}; use crate::kernels::{argmax, argmin_value_float}; use crate::{ distance::{ @@ -676,7 +676,7 @@ pub fn compute_partitions_arrow_array( /// Compute partition ID of each vector in the KMeans. /// /// If returns `None`, means the vector is not valid, i.e., all `NaN`. -pub fn compute_partitions( +pub fn compute_partitions( centroids: &[T], vectors: &[T], dimension: impl AsPrimitive, @@ -690,7 +690,7 @@ pub fn compute_partitions( } #[inline] -pub fn compute_partition( +pub fn compute_partition( centroids: &[T], vector: &[T], distance_type: DistanceType, @@ -702,6 +702,9 @@ pub fn compute_partition( DistanceType::Dot => { argmin_value_float(dot_distance_batch(vector, centroids, vector.len())).map(|c| c.0) } + DistanceType::Cosine => { + argmin_value_float(cosine_distance_batch(vector, centroids, vector.len())).map(|c| c.0) + } _ => { panic!( "KMeans::compute_partition: distance type {} is not supported",