Skip to content

Commit

Permalink
re-enable all benchmarks under linalg
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu committed Oct 1, 2024
1 parent 6e042ac commit a1e8327
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust-benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
working-directory: ./rust/lance-linalg
run: |
# TODO: a few benchmarks are failing. Re-enable everything once they are fixed.
cargo bench --features "fp16kernels" --bench l2 --bench cosine --bench dot --bench kmeans --bench norm_l2 -- --output-format bencher | tee -a ../../output.txt
cargo bench --features "fp16kernels" -- --output-format bencher | tee -a ../../output.txt
- name: Run index benchmarks
working-directory: ./rust/lance-index
run: |
Expand Down
4 changes: 2 additions & 2 deletions rust/lance-index/src/vector/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use arrow_schema::DataType;
use deepsize::DeepSizeOf;
use lance_arrow::*;
use lance_core::{Error, Result};
use lance_linalg::distance::{dot_distance_batch, DistanceType, Dot, L2};
use lance_linalg::distance::{dot_distance_batch, Cosine, DistanceType, Dot, L2};
use lance_linalg::kmeans::compute_partition;
use num_traits::Float;
use prost::Message;
Expand Down Expand Up @@ -96,7 +96,7 @@ impl ProductQuantizer {
#[instrument(name = "ProductQuantizer::transform", level = "debug", skip_all)]
fn transform<T: ArrowPrimitiveType>(&self, vectors: &dyn Array) -> Result<ArrayRef>
where
T::Native: Float + L2 + Dot,
T::Native: Float + L2 + Dot + Cosine,
{
let fsl = vectors.as_fixed_size_list_opt().ok_or(Error::Index {
message: format!(
Expand Down
4 changes: 2 additions & 2 deletions rust/lance-index/src/vector/residual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use arrow_array::{
use arrow_schema::DataType;
use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt};
use lance_core::{Error, Result};
use lance_linalg::distance::{DistanceType, Dot, L2};
use lance_linalg::distance::{Cosine, DistanceType, Dot, L2};
use lance_linalg::kmeans::compute_partitions;
use num_traits::Float;
use snafu::{location, Location};
Expand Down Expand Up @@ -60,7 +60,7 @@ fn do_compute_residual<T: ArrowPrimitiveType>(
partitions: Option<&UInt32Array>,
) -> Result<FixedSizeListArray>
where
T::Native: Float + L2 + Dot,
T::Native: Float + L2 + Dot + Cosine,
{
let dimension = centroids.value_length() as usize;
let centroids_slice = centroids.values().as_primitive::<T>().values();
Expand Down
17 changes: 14 additions & 3 deletions rust/lance-linalg/benches/compute_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::sync::Arc;

use arrow_array::types::Float32Type;
use criterion::{criterion_group, criterion_main, Criterion};
use lance_linalg::{distance::MetricType, kmeans::compute_partitions};
use lance_linalg::{distance::DistanceType, kmeans::compute_partitions};
use lance_testing::datagen::generate_random_array_with_seed;
#[cfg(target_os = "linux")]
use pprof::criterion::{Output, PProfProfiler};
Expand All @@ -28,7 +28,7 @@ fn bench_compute_partitions(c: &mut Criterion) {
centroids.values(),
input.values(),
DIMENSION,
MetricType::L2,
DistanceType::L2,
)
})
});
Expand All @@ -39,7 +39,18 @@ fn bench_compute_partitions(c: &mut Criterion) {
centroids.values(),
input.values(),
DIMENSION,
MetricType::Cosine,
DistanceType::Cosine,
)
})
});

c.bench_function("compute_centroids(Dot)", |b| {
b.iter(|| {
compute_partitions(
centroids.values(),
input.values(),
DIMENSION,
DistanceType::Dot,
)
})
});
Expand Down
9 changes: 6 additions & 3 deletions rust/lance-linalg/src/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use rand::prelude::*;
use rayon::prelude::*;

use crate::distance::hamming::hamming;
use crate::distance::{dot_distance_batch, DistanceType};
use crate::distance::{cosine_distance_batch, dot_distance_batch, Cosine, DistanceType};
use crate::kernels::{argmax, argmin_value_float};
use crate::{
distance::{
Expand Down Expand Up @@ -676,7 +676,7 @@ pub fn compute_partitions_arrow_array(
/// Compute partition ID of each vector in the KMeans.
///
/// If returns `None`, means the vector is not valid, i.e., all `NaN`.
pub fn compute_partitions<T: Float + L2 + Dot + Sync>(
pub fn compute_partitions<T: Float + L2 + Dot + Cosine + Sync>(
centroids: &[T],
vectors: &[T],
dimension: impl AsPrimitive<usize>,
Expand All @@ -690,7 +690,7 @@ pub fn compute_partitions<T: Float + L2 + Dot + Sync>(
}

#[inline]
pub fn compute_partition<T: Float + L2 + Dot>(
pub fn compute_partition<T: Float + L2 + Dot + Cosine>(
centroids: &[T],
vector: &[T],
distance_type: DistanceType,
Expand All @@ -702,6 +702,9 @@ pub fn compute_partition<T: Float + L2 + Dot>(
DistanceType::Dot => {
argmin_value_float(dot_distance_batch(vector, centroids, vector.len())).map(|c| c.0)
}
DistanceType::Cosine => {
argmin_value_float(cosine_distance_batch(vector, centroids, vector.len())).map(|c| c.0)
}
_ => {
panic!(
"KMeans::compute_partition: distance type {} is not supported",
Expand Down

0 comments on commit a1e8327

Please sign in to comment.