From 874bf9cf3f408ca38e63da70d79ac353639fa90f Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 1 Apr 2024 02:08:48 +0800 Subject: [PATCH] feat: add scalar quantizer (#2134) plan to support scalar quantization it's been without any optimization yet --------- Signed-off-by: BubbleCal --- rust/lance-arrow/src/floats.rs | 10 ++ rust/lance-index/src/vector.rs | 1 + rust/lance-index/src/vector/sq.rs | 253 ++++++++++++++++++++++++++++++ 3 files changed, 264 insertions(+) create mode 100644 rust/lance-index/src/vector/sq.rs diff --git a/rust/lance-arrow/src/floats.rs b/rust/lance-arrow/src/floats.rs index a9d9348fd9..3546f565f9 100644 --- a/rust/lance-arrow/src/floats.rs +++ b/rust/lance-arrow/src/floats.rs @@ -95,6 +95,8 @@ pub trait ArrowFloatType: Debug { + Display; const FLOAT_TYPE: FloatType; + const MIN: Self::Native; + const MAX: Self::Native; /// Arrow Float Array Type. type ArrayType: FloatArray; @@ -140,6 +142,8 @@ impl ArrowFloatType for BFloat16Type { type Native = bf16; const FLOAT_TYPE: FloatType = FloatType::BFloat16; + const MIN: Self::Native = bf16::MIN; + const MAX: Self::Native = bf16::MAX; type ArrayType = BFloat16Array; } @@ -148,6 +152,8 @@ impl ArrowFloatType for Float16Type { type Native = f16; const FLOAT_TYPE: FloatType = FloatType::Float16; + const MIN: Self::Native = f16::MIN; + const MAX: Self::Native = f16::MAX; type ArrayType = Float16Array; } @@ -156,6 +162,8 @@ impl ArrowFloatType for Float32Type { type Native = f32; const FLOAT_TYPE: FloatType = FloatType::Float32; + const MIN: Self::Native = f32::MIN; + const MAX: Self::Native = f32::MAX; type ArrayType = Float32Array; } @@ -164,6 +172,8 @@ impl ArrowFloatType for Float64Type { type Native = f64; const FLOAT_TYPE: FloatType = FloatType::Float64; + const MIN: Self::Native = f64::MIN; + const MAX: Self::Native = f64::MAX; type ArrayType = Float64Array; } diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index e05f96c238..24dec1a0d4 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -26,6 +26,7 @@ pub mod ivf; pub mod kmeans; pub mod pq; pub mod residual; +pub mod sq; pub mod transform; pub mod utils; diff --git a/rust/lance-index/src/vector/sq.rs b/rust/lance-index/src/vector/sq.rs new file mode 100644 index 0000000000..13195846ef --- /dev/null +++ b/rust/lance-index/src/vector/sq.rs @@ -0,0 +1,253 @@ +// Copyright 2024 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{any::Any, ops::Range, sync::Arc}; + +use arrow::array::AsArray; +use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array}; + +use lance_arrow::*; +use lance_core::{Error, Result}; +use lance_linalg::distance::{Dot, MetricType, L2}; +use num_traits::*; +use snafu::{location, Location}; + +#[async_trait::async_trait] +pub trait ScalarQuantizer: Send + Sync + std::fmt::Debug { + fn as_any(&self) -> &dyn Any; + + /// Transform a vector column to SQ code column. + /// + /// Parameters + /// ---------- + /// *data*: vector array, must be a `FixedSizeListArray` + /// + /// Returns + /// ------- + /// SQ code column + async fn transform(&self, data: &dyn Array) -> Result; + + /// Get the centroids for each dimension. + fn num_bits(&self) -> u16; + + /// Whether to use residual as input or not. + fn use_residual(&self) -> bool; +} + +/// Scalar Quantization, optimized for [Apache Arrow] buffer memory layout. +/// +// +// TODO: move this to be pub(crate) once we have a better way to test it. +#[derive(Debug)] +pub struct ScalarQuantizerImpl { + /// Number of bits for the centroids. + /// + /// Only support 8, as one of `u8` byte now. + pub num_bits: u16, + + /// Distance type. + pub metric_type: MetricType, + + pub bounds: Range, +} + +impl ScalarQuantizerImpl { + pub fn new(num_bits: u16, metric_type: MetricType) -> Self { + Self { + num_bits, + metric_type, + bounds: Range:: { + start: T::MAX, + end: T::MIN, + }, + } + } + + pub fn with_bounds(num_bits: u16, metric_type: MetricType, bounds: Range) -> Self { + let mut sq = Self::new(num_bits, metric_type); + sq.bounds = bounds; + sq + } + + pub fn update_bounds(&mut self, vectors: &FixedSizeListArray) -> Result> { + let data = vectors + .values() + .as_any() + .downcast_ref::() + .ok_or(Error::Index { + message: format!( + "Expect to be a float vector array, got: {:?}", + vectors.value_type() + ), + location: location!(), + })? + .as_slice(); + + self.bounds = data + .iter() + .fold(self.bounds.clone(), |f, v| f.start.min(*v)..f.end.max(*v)); + + Ok(self.bounds.clone()) + } +} + +#[async_trait::async_trait] +impl ScalarQuantizer for ScalarQuantizerImpl { + fn as_any(&self) -> &dyn Any { + self + } + + async fn transform(&self, data: &dyn Array) -> Result { + let fsl = data + .as_fixed_size_list_opt() + .ok_or(Error::Index { + message: format!( + "Expect to be a FixedSizeList vector array, got: {:?} array", + data.data_type() + ), + location: location!(), + })? + .clone(); + let data = fsl + .values() + .as_any() + .downcast_ref::() + .ok_or(Error::Index { + message: format!( + "Expect to be a float vector array, got: {:?}", + fsl.value_type() + ), + location: location!(), + })? + .as_slice(); + + // TODO: support SQ4 + let builder: Vec = data + .iter() + .map(|v| { + let range = self.bounds.end - self.bounds.start; + match *v { + v if v < self.bounds.start => 0, + v if v > self.bounds.end => 255, + _ => ((*v - self.bounds.start) * T::Native::from_u32(255).unwrap() / range) + .round() + .to_u8() + .unwrap(), + } + }) + .collect(); + + Ok(Arc::new(FixedSizeListArray::try_new_from_values( + UInt8Array::from(builder), + fsl.value_length(), + )?)) + } + + /// Get the centroids for each dimension. + fn num_bits(&self) -> u16 { + self.num_bits + } + + /// Whether to use residual as input or not. + fn use_residual(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; + use arrow_array::{Float16Array, Float32Array, Float64Array}; + use half::f16; + + use super::*; + + #[tokio::test] + async fn test_f16_sq8() { + let mut sq: ScalarQuantizerImpl = ScalarQuantizerImpl::new(8, MetricType::L2); + let float_values = Vec::from_iter((0..16).map(|v| f16::from_usize(v).unwrap())); + let float_array = Float16Array::from_iter_values(float_values.clone()); + let vectors = + FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32) + .unwrap(); + + sq.update_bounds(&vectors).unwrap(); + assert_eq!(sq.bounds.start, float_values[0]); + assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap()); + + let sq_code = sq.transform(&vectors).await.unwrap(); + let sq_values = sq_code + .as_fixed_size_list() + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + sq_values.values().iter().enumerate().for_each(|(i, v)| { + assert_eq!(*v, (i * 17) as u8); + }); + } + + #[tokio::test] + async fn test_f32_sq8() { + let mut sq: ScalarQuantizerImpl = ScalarQuantizerImpl::new(8, MetricType::L2); + let float_values = Vec::from_iter((0..16).map(|v| v as f32)); + let float_array = Float32Array::from_iter_values(float_values.clone()); + let vectors = + FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32) + .unwrap(); + + sq.update_bounds(&vectors).unwrap(); + assert_eq!(sq.bounds.start, float_values[0]); + assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap()); + + let sq_code = sq.transform(&vectors).await.unwrap(); + let sq_values = sq_code + .as_fixed_size_list() + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + sq_values.values().iter().enumerate().for_each(|(i, v)| { + assert_eq!(*v, (i * 17) as u8,); + }); + } + + #[tokio::test] + async fn test_f64_sq8() { + let mut sq: ScalarQuantizerImpl = ScalarQuantizerImpl::new(8, MetricType::L2); + let float_values = Vec::from_iter((0..16).map(|v| v as f64)); + let float_array = Float64Array::from_iter_values(float_values.clone()); + let vectors = + FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32) + .unwrap(); + + sq.update_bounds(&vectors).unwrap(); + assert_eq!(sq.bounds.start, float_values[0]); + assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap()); + + let sq_code = sq.transform(&vectors).await.unwrap(); + let sq_values = sq_code + .as_fixed_size_list() + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + sq_values.values().iter().enumerate().for_each(|(i, v)| { + assert_eq!(*v, (i * 17) as u8,); + }); + } +}