Skip to content

Commit

Permalink
feat: add sq storage and transformer (#2135)
Browse files Browse the repository at this point in the history
Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal authored Apr 1, 2024
1 parent 2f4ed88 commit 29e8a55
Show file tree
Hide file tree
Showing 4 changed files with 412 additions and 72 deletions.
1 change: 1 addition & 0 deletions rust/lance-index/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub mod utils;

// TODO: Make these crate private once the migration from lance to lance-index is done.
pub const PQ_CODE_COLUMN: &str = "__pq_code";
pub const SQ_CODE_COLUMN: &str = "__sq_code";
pub const PART_ID_COLUMN: &str = "__ivf_part_id";
pub const DIST_COL: &str = "_distance";

Expand Down
132 changes: 60 additions & 72 deletions rust/lance-index/src/vector/sq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,45 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{any::Any, ops::Range, sync::Arc};
use std::{ops::Range, sync::Arc};

use arrow::array::AsArray;
use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};

use itertools::Itertools;
use lance_arrow::*;
use lance_core::{Error, Result};
use lance_linalg::distance::{Dot, MetricType, L2};
use lance_linalg::distance::MetricType;
use num_traits::*;
use snafu::{location, Location};

/// Interface for scalar quantizers that turn a vector column into an SQ code column.
#[async_trait::async_trait]
pub trait ScalarQuantizer: Send + Sync + std::fmt::Debug {
/// Return this quantizer as `&dyn Any`, e.g. for downcasting to the concrete implementation.
fn as_any(&self) -> &dyn Any;

/// Transform a vector column to SQ code column.
///
/// Parameters
/// ----------
/// *data*: vector array, must be a `FixedSizeListArray`
///
/// Returns
/// -------
/// SQ code column
async fn transform(&self, data: &dyn Array) -> Result<ArrayRef>;

/// Number of bits used for each quantized value.
fn num_bits(&self) -> u16;

/// Whether to use residual as input or not.
fn use_residual(&self) -> bool;
}
pub mod storage;
pub mod transform;

/// Scalar Quantization, optimized for [Apache Arrow] buffer memory layout.
///
//
// TODO: move this to be pub(crate) once we have a better way to test it.
#[derive(Debug)]
pub struct ScalarQuantizerImpl<T: ArrowFloatType + Dot + L2> {
pub struct ScalarQuantizer {
/// Number of bits for the centroids.
///
/// Only support 8, as one of `u8` byte now.
Expand All @@ -59,28 +41,31 @@ pub struct ScalarQuantizerImpl<T: ArrowFloatType + Dot + L2> {
/// Distance type.
pub metric_type: MetricType,

pub bounds: Range<T::Native>,
pub bounds: Range<f64>,
}

impl<T: ArrowFloatType + Dot + L2> ScalarQuantizerImpl<T> {
impl ScalarQuantizer {
/// Create a quantizer that has not observed any data yet.
///
/// `bounds` starts as an inverted range (`start` at the maximum value, `end` at the
/// minimum value) so that the min/max fold in `update_bounds` replaces both ends with
/// the values it actually sees.
pub fn new(num_bits: u16, metric_type: MetricType) -> Self {
Self {
num_bits,
metric_type,
bounds: Range::<T::Native> {
start: T::MAX,
end: T::MIN,
bounds: Range::<f64> {
start: f64::MAX,
end: f64::MIN,
},
}
}

/// Create a quantizer whose value range is already known.
///
/// Builds the quantizer with `new` and then overwrites its `bounds` with the given range.
pub fn with_bounds(num_bits: u16, metric_type: MetricType, bounds: Range<T::Native>) -> Self {
pub fn with_bounds(num_bits: u16, metric_type: MetricType, bounds: Range<f64>) -> Self {
let mut sq = Self::new(num_bits, metric_type);
sq.bounds = bounds;
sq
}

pub fn update_bounds(&mut self, vectors: &FixedSizeListArray) -> Result<Range<T::Native>> {
pub fn update_bounds<T: ArrowFloatType>(
&mut self,
vectors: &FixedSizeListArray,
) -> Result<Range<f64>> {
let data = vectors
.values()
.as_any()
Expand All @@ -94,21 +79,14 @@ impl<T: ArrowFloatType + Dot + L2> ScalarQuantizerImpl<T> {
})?
.as_slice();

self.bounds = data
.iter()
.fold(self.bounds.clone(), |f, v| f.start.min(*v)..f.end.max(*v));
self.bounds = data.iter().fold(self.bounds.clone(), |f, v| {
f.start.min(v.to_f64().unwrap())..f.end.max(v.to_f64().unwrap())
});

Ok(self.bounds.clone())
}
}

#[async_trait::async_trait]
impl<T: ArrowFloatType + Dot + L2 + 'static> ScalarQuantizer for ScalarQuantizerImpl<T> {
fn as_any(&self) -> &dyn Any {
self
}

async fn transform(&self, data: &dyn Array) -> Result<ArrayRef> {
async fn transform<T: ArrowFloatType>(&self, data: &dyn Array) -> Result<ArrayRef> {
let fsl = data
.as_fixed_size_list_opt()
.ok_or(Error::Index {
Expand All @@ -133,20 +111,7 @@ impl<T: ArrowFloatType + Dot + L2 + 'static> ScalarQuantizer for ScalarQuantizer
.as_slice();

// TODO: support SQ4
let builder: Vec<u8> = data
.iter()
.map(|v| {
let range = self.bounds.end - self.bounds.start;
match *v {
v if v < self.bounds.start => 0,
v if v > self.bounds.end => 255,
_ => ((*v - self.bounds.start) * T::Native::from_u32(255).unwrap() / range)
.round()
.to_u8()
.unwrap(),
}
})
.collect();
let builder: Vec<u8> = scale_to_u8::<T>(data, self.bounds.clone());

Ok(Arc::new(FixedSizeListArray::try_new_from_values(
UInt8Array::from(builder),
Expand All @@ -155,16 +120,33 @@ impl<T: ArrowFloatType + Dot + L2 + 'static> ScalarQuantizer for ScalarQuantizer
}

/// Number of bits used for each quantized value (returns the `num_bits` field).
fn num_bits(&self) -> u16 {
pub fn num_bits(&self) -> u16 {
self.num_bits
}

/// Whether to use residual as input or not.
///
/// The answer is a hard-coded constant; it does not depend on the quantizer's state.
fn use_residual(&self) -> bool {
true
pub fn use_residual(&self) -> bool {
false
}
}

/// Linearly map floating-point values onto the `u8` code range `0..=255`.
///
/// `bounds.start` maps to `0` and `bounds.end` maps to `255`; values in between are
/// scaled proportionally and rounded to the nearest integer.
///
/// Out-of-range and degenerate inputs never panic:
/// * values below `bounds.start` become `0`, values above `bounds.end` become `255`;
/// * `NaN` values become `0`;
/// * when `bounds.start == bounds.end` (zero-width range), a value equal to the
///   bound becomes `0`.
///
/// Returns one code per input value, in the same order.
pub(crate) fn scale_to_u8<T: ArrowFloatType>(values: &[T::Native], bounds: Range<f64>) -> Vec<u8> {
    let range = bounds.end - bounds.start;
    values
        .iter()
        .map(|&v| {
            let v = v.to_f64().unwrap();
            if v < bounds.start {
                0
            } else if v > bounds.end {
                255
            } else {
                // A zero-width range (0.0 / 0.0) or a NaN input produces NaN here.
                // Float-to-integer `as` casts saturate and map NaN to 0, so this
                // cannot panic the way `.to_u8().unwrap()` did on NaN.
                ((v - bounds.start) * 255.0 / range).round() as u8
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
Expand All @@ -175,18 +157,21 @@ mod tests {

#[tokio::test]
async fn test_f16_sq8() {
let mut sq: ScalarQuantizerImpl<Float16Type> = ScalarQuantizerImpl::new(8, MetricType::L2);
let mut sq = ScalarQuantizer::new(8, MetricType::L2);
let float_values = Vec::from_iter((0..16).map(|v| f16::from_usize(v).unwrap()));
let float_array = Float16Array::from_iter_values(float_values.clone());
let vectors =
FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
.unwrap();

sq.update_bounds(&vectors).unwrap();
assert_eq!(sq.bounds.start, float_values[0]);
assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap());
sq.update_bounds::<Float16Type>(&vectors).unwrap();
assert_eq!(sq.bounds.start, float_values[0].to_f64());
assert_eq!(
sq.bounds.end,
float_values.last().cloned().unwrap().to_f64()
);

let sq_code = sq.transform(&vectors).await.unwrap();
let sq_code = sq.transform::<Float16Type>(&vectors).await.unwrap();
let sq_values = sq_code
.as_fixed_size_list()
.values()
Expand All @@ -201,18 +186,21 @@ mod tests {

#[tokio::test]
async fn test_f32_sq8() {
let mut sq: ScalarQuantizerImpl<Float32Type> = ScalarQuantizerImpl::new(8, MetricType::L2);
let mut sq = ScalarQuantizer::new(8, MetricType::L2);
let float_values = Vec::from_iter((0..16).map(|v| v as f32));
let float_array = Float32Array::from_iter_values(float_values.clone());
let vectors =
FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
.unwrap();

sq.update_bounds(&vectors).unwrap();
assert_eq!(sq.bounds.start, float_values[0]);
assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap());
sq.update_bounds::<Float32Type>(&vectors).unwrap();
assert_eq!(sq.bounds.start, float_values[0].to_f64().unwrap());
assert_eq!(
sq.bounds.end,
float_values.last().cloned().unwrap().to_f64().unwrap()
);

let sq_code = sq.transform(&vectors).await.unwrap();
let sq_code = sq.transform::<Float32Type>(&vectors).await.unwrap();
let sq_values = sq_code
.as_fixed_size_list()
.values()
Expand All @@ -227,18 +215,18 @@ mod tests {

#[tokio::test]
async fn test_f64_sq8() {
let mut sq: ScalarQuantizerImpl<Float64Type> = ScalarQuantizerImpl::new(8, MetricType::L2);
let mut sq = ScalarQuantizer::new(8, MetricType::L2);
let float_values = Vec::from_iter((0..16).map(|v| v as f64));
let float_array = Float64Array::from_iter_values(float_values.clone());
let vectors =
FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
.unwrap();

sq.update_bounds(&vectors).unwrap();
sq.update_bounds::<Float64Type>(&vectors).unwrap();
assert_eq!(sq.bounds.start, float_values[0]);
assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap());

let sq_code = sq.transform(&vectors).await.unwrap();
let sq_code = sq.transform::<Float64Type>(&vectors).await.unwrap();
let sq_values = sq_code
.as_fixed_size_list()
.values()
Expand Down
Loading

0 comments on commit 29e8a55

Please sign in to comment.