-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
plan to support scalar quantization it's been without any optimization yet --------- Signed-off-by: BubbleCal <[email protected]>
- Loading branch information
Showing
3 changed files
with
264 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
// Copyright 2024 Lance Developers. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
use std::{any::Any, ops::Range, sync::Arc}; | ||
|
||
use arrow::array::AsArray; | ||
use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array}; | ||
|
||
use lance_arrow::*; | ||
use lance_core::{Error, Result}; | ||
use lance_linalg::distance::{Dot, MetricType, L2}; | ||
use num_traits::*; | ||
use snafu::{location, Location}; | ||
|
||
#[async_trait::async_trait] | ||
pub trait ScalarQuantizer: Send + Sync + std::fmt::Debug { | ||
fn as_any(&self) -> &dyn Any; | ||
|
||
/// Transform a vector column to SQ code column. | ||
/// | ||
/// Parameters | ||
/// ---------- | ||
/// *data*: vector array, must be a `FixedSizeListArray` | ||
/// | ||
/// Returns | ||
/// ------- | ||
/// SQ code column | ||
async fn transform(&self, data: &dyn Array) -> Result<ArrayRef>; | ||
|
||
/// Get the centroids for each dimension. | ||
fn num_bits(&self) -> u16; | ||
|
||
/// Whether to use residual as input or not. | ||
fn use_residual(&self) -> bool; | ||
} | ||
|
||
/// Scalar Quantization, optimized for [Apache Arrow] buffer memory layout. | ||
/// | ||
// | ||
// TODO: move this to be pub(crate) once we have a better way to test it. | ||
#[derive(Debug)] | ||
pub struct ScalarQuantizerImpl<T: ArrowFloatType + Dot + L2> { | ||
/// Number of bits for the centroids. | ||
/// | ||
/// Only support 8, as one of `u8` byte now. | ||
pub num_bits: u16, | ||
|
||
/// Distance type. | ||
pub metric_type: MetricType, | ||
|
||
pub bounds: Range<T::Native>, | ||
} | ||
|
||
impl<T: ArrowFloatType + Dot + L2> ScalarQuantizerImpl<T> { | ||
pub fn new(num_bits: u16, metric_type: MetricType) -> Self { | ||
Self { | ||
num_bits, | ||
metric_type, | ||
bounds: Range::<T::Native> { | ||
start: T::MAX, | ||
end: T::MIN, | ||
}, | ||
} | ||
} | ||
|
||
pub fn with_bounds(num_bits: u16, metric_type: MetricType, bounds: Range<T::Native>) -> Self { | ||
let mut sq = Self::new(num_bits, metric_type); | ||
sq.bounds = bounds; | ||
sq | ||
} | ||
|
||
pub fn update_bounds(&mut self, vectors: &FixedSizeListArray) -> Result<Range<T::Native>> { | ||
let data = vectors | ||
.values() | ||
.as_any() | ||
.downcast_ref::<T::ArrayType>() | ||
.ok_or(Error::Index { | ||
message: format!( | ||
"Expect to be a float vector array, got: {:?}", | ||
vectors.value_type() | ||
), | ||
location: location!(), | ||
})? | ||
.as_slice(); | ||
|
||
self.bounds = data | ||
.iter() | ||
.fold(self.bounds.clone(), |f, v| f.start.min(*v)..f.end.max(*v)); | ||
|
||
Ok(self.bounds.clone()) | ||
} | ||
} | ||
|
||
#[async_trait::async_trait] | ||
impl<T: ArrowFloatType + Dot + L2 + 'static> ScalarQuantizer for ScalarQuantizerImpl<T> { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
|
||
async fn transform(&self, data: &dyn Array) -> Result<ArrayRef> { | ||
let fsl = data | ||
.as_fixed_size_list_opt() | ||
.ok_or(Error::Index { | ||
message: format!( | ||
"Expect to be a FixedSizeList<float> vector array, got: {:?} array", | ||
data.data_type() | ||
), | ||
location: location!(), | ||
})? | ||
.clone(); | ||
let data = fsl | ||
.values() | ||
.as_any() | ||
.downcast_ref::<T::ArrayType>() | ||
.ok_or(Error::Index { | ||
message: format!( | ||
"Expect to be a float vector array, got: {:?}", | ||
fsl.value_type() | ||
), | ||
location: location!(), | ||
})? | ||
.as_slice(); | ||
|
||
// TODO: support SQ4 | ||
let builder: Vec<u8> = data | ||
.iter() | ||
.map(|v| { | ||
let range = self.bounds.end - self.bounds.start; | ||
match *v { | ||
v if v < self.bounds.start => 0, | ||
v if v > self.bounds.end => 255, | ||
_ => ((*v - self.bounds.start) * T::Native::from_u32(255).unwrap() / range) | ||
.round() | ||
.to_u8() | ||
.unwrap(), | ||
} | ||
}) | ||
.collect(); | ||
|
||
Ok(Arc::new(FixedSizeListArray::try_new_from_values( | ||
UInt8Array::from(builder), | ||
fsl.value_length(), | ||
)?)) | ||
} | ||
|
||
/// Get the centroids for each dimension. | ||
fn num_bits(&self) -> u16 { | ||
self.num_bits | ||
} | ||
|
||
/// Whether to use residual as input or not. | ||
fn use_residual(&self) -> bool { | ||
true | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; | ||
use arrow_array::{Float16Array, Float32Array, Float64Array}; | ||
use half::f16; | ||
|
||
use super::*; | ||
|
||
#[tokio::test] | ||
async fn test_f16_sq8() { | ||
let mut sq: ScalarQuantizerImpl<Float16Type> = ScalarQuantizerImpl::new(8, MetricType::L2); | ||
let float_values = Vec::from_iter((0..16).map(|v| f16::from_usize(v).unwrap())); | ||
let float_array = Float16Array::from_iter_values(float_values.clone()); | ||
let vectors = | ||
FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32) | ||
.unwrap(); | ||
|
||
sq.update_bounds(&vectors).unwrap(); | ||
assert_eq!(sq.bounds.start, float_values[0]); | ||
assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap()); | ||
|
||
let sq_code = sq.transform(&vectors).await.unwrap(); | ||
let sq_values = sq_code | ||
.as_fixed_size_list() | ||
.values() | ||
.as_any() | ||
.downcast_ref::<UInt8Array>() | ||
.unwrap(); | ||
|
||
sq_values.values().iter().enumerate().for_each(|(i, v)| { | ||
assert_eq!(*v, (i * 17) as u8); | ||
}); | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_f32_sq8() { | ||
let mut sq: ScalarQuantizerImpl<Float32Type> = ScalarQuantizerImpl::new(8, MetricType::L2); | ||
let float_values = Vec::from_iter((0..16).map(|v| v as f32)); | ||
let float_array = Float32Array::from_iter_values(float_values.clone()); | ||
let vectors = | ||
FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32) | ||
.unwrap(); | ||
|
||
sq.update_bounds(&vectors).unwrap(); | ||
assert_eq!(sq.bounds.start, float_values[0]); | ||
assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap()); | ||
|
||
let sq_code = sq.transform(&vectors).await.unwrap(); | ||
let sq_values = sq_code | ||
.as_fixed_size_list() | ||
.values() | ||
.as_any() | ||
.downcast_ref::<UInt8Array>() | ||
.unwrap(); | ||
|
||
sq_values.values().iter().enumerate().for_each(|(i, v)| { | ||
assert_eq!(*v, (i * 17) as u8,); | ||
}); | ||
} | ||
|
||
#[tokio::test] | ||
async fn test_f64_sq8() { | ||
let mut sq: ScalarQuantizerImpl<Float64Type> = ScalarQuantizerImpl::new(8, MetricType::L2); | ||
let float_values = Vec::from_iter((0..16).map(|v| v as f64)); | ||
let float_array = Float64Array::from_iter_values(float_values.clone()); | ||
let vectors = | ||
FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32) | ||
.unwrap(); | ||
|
||
sq.update_bounds(&vectors).unwrap(); | ||
assert_eq!(sq.bounds.start, float_values[0]); | ||
assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap()); | ||
|
||
let sq_code = sq.transform(&vectors).await.unwrap(); | ||
let sq_values = sq_code | ||
.as_fixed_size_list() | ||
.values() | ||
.as_any() | ||
.downcast_ref::<UInt8Array>() | ||
.unwrap(); | ||
|
||
sq_values.values().iter().enumerate().for_each(|(i, v)| { | ||
assert_eq!(*v, (i * 17) as u8,); | ||
}); | ||
} | ||
} |