Skip to content

Commit

Permalink
feat: add scalar quantizer (#2134)
Browse files Browse the repository at this point in the history
plan to support scalar quantization
it's been without any optimization yet

---------

Signed-off-by: BubbleCal <[email protected]>
  • Loading branch information
BubbleCal authored Mar 31, 2024
1 parent da1d236 commit 874bf9c
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 0 deletions.
10 changes: 10 additions & 0 deletions rust/lance-arrow/src/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ pub trait ArrowFloatType: Debug {
+ Display;

const FLOAT_TYPE: FloatType;
const MIN: Self::Native;
const MAX: Self::Native;

/// Arrow Float Array Type.
type ArrayType: FloatArray<Self>;
Expand Down Expand Up @@ -140,6 +142,8 @@ impl ArrowFloatType for BFloat16Type {
type Native = bf16;

const FLOAT_TYPE: FloatType = FloatType::BFloat16;
const MIN: Self::Native = bf16::MIN;
const MAX: Self::Native = bf16::MAX;

type ArrayType = BFloat16Array;
}
Expand All @@ -148,6 +152,8 @@ impl ArrowFloatType for Float16Type {
type Native = f16;

const FLOAT_TYPE: FloatType = FloatType::Float16;
const MIN: Self::Native = f16::MIN;
const MAX: Self::Native = f16::MAX;

type ArrayType = Float16Array;
}
Expand All @@ -156,6 +162,8 @@ impl ArrowFloatType for Float32Type {
type Native = f32;

const FLOAT_TYPE: FloatType = FloatType::Float32;
const MIN: Self::Native = f32::MIN;
const MAX: Self::Native = f32::MAX;

type ArrayType = Float32Array;
}
Expand All @@ -164,6 +172,8 @@ impl ArrowFloatType for Float64Type {
type Native = f64;

const FLOAT_TYPE: FloatType = FloatType::Float64;
const MIN: Self::Native = f64::MIN;
const MAX: Self::Native = f64::MAX;

type ArrayType = Float64Array;
}
Expand Down
1 change: 1 addition & 0 deletions rust/lance-index/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub mod ivf;
pub mod kmeans;
pub mod pq;
pub mod residual;
pub mod sq;
pub mod transform;
pub mod utils;

Expand Down
253 changes: 253 additions & 0 deletions rust/lance-index/src/vector/sq.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{any::Any, ops::Range, sync::Arc};

use arrow::array::AsArray;
use arrow_array::{Array, ArrayRef, FixedSizeListArray, UInt8Array};

use lance_arrow::*;
use lance_core::{Error, Result};
use lance_linalg::distance::{Dot, MetricType, L2};
use num_traits::*;
use snafu::{location, Location};

#[async_trait::async_trait]
pub trait ScalarQuantizer: Send + Sync + std::fmt::Debug {
fn as_any(&self) -> &dyn Any;

/// Transform a vector column to SQ code column.
///
/// Parameters
/// ----------
/// *data*: vector array, must be a `FixedSizeListArray`
///
/// Returns
/// -------
/// SQ code column
async fn transform(&self, data: &dyn Array) -> Result<ArrayRef>;

/// Get the centroids for each dimension.
fn num_bits(&self) -> u16;

/// Whether to use residual as input or not.
fn use_residual(&self) -> bool;
}

/// Scalar Quantization, optimized for [Apache Arrow] buffer memory layout.
///
//
// TODO: move this to be pub(crate) once we have a better way to test it.
#[derive(Debug)]
pub struct ScalarQuantizerImpl<T: ArrowFloatType + Dot + L2> {
/// Number of bits for the centroids.
///
/// Only support 8, as one of `u8` byte now.
pub num_bits: u16,

/// Distance type.
pub metric_type: MetricType,

pub bounds: Range<T::Native>,
}

impl<T: ArrowFloatType + Dot + L2> ScalarQuantizerImpl<T> {
pub fn new(num_bits: u16, metric_type: MetricType) -> Self {
Self {
num_bits,
metric_type,
bounds: Range::<T::Native> {
start: T::MAX,
end: T::MIN,
},
}
}

pub fn with_bounds(num_bits: u16, metric_type: MetricType, bounds: Range<T::Native>) -> Self {
let mut sq = Self::new(num_bits, metric_type);
sq.bounds = bounds;
sq
}

pub fn update_bounds(&mut self, vectors: &FixedSizeListArray) -> Result<Range<T::Native>> {
let data = vectors
.values()
.as_any()
.downcast_ref::<T::ArrayType>()
.ok_or(Error::Index {
message: format!(
"Expect to be a float vector array, got: {:?}",
vectors.value_type()
),
location: location!(),
})?
.as_slice();

self.bounds = data
.iter()
.fold(self.bounds.clone(), |f, v| f.start.min(*v)..f.end.max(*v));

Ok(self.bounds.clone())
}
}

#[async_trait::async_trait]
impl<T: ArrowFloatType + Dot + L2 + 'static> ScalarQuantizer for ScalarQuantizerImpl<T> {
fn as_any(&self) -> &dyn Any {
self
}

async fn transform(&self, data: &dyn Array) -> Result<ArrayRef> {
let fsl = data
.as_fixed_size_list_opt()
.ok_or(Error::Index {
message: format!(
"Expect to be a FixedSizeList<float> vector array, got: {:?} array",
data.data_type()
),
location: location!(),
})?
.clone();
let data = fsl
.values()
.as_any()
.downcast_ref::<T::ArrayType>()
.ok_or(Error::Index {
message: format!(
"Expect to be a float vector array, got: {:?}",
fsl.value_type()
),
location: location!(),
})?
.as_slice();

// TODO: support SQ4
let builder: Vec<u8> = data
.iter()
.map(|v| {
let range = self.bounds.end - self.bounds.start;
match *v {
v if v < self.bounds.start => 0,
v if v > self.bounds.end => 255,
_ => ((*v - self.bounds.start) * T::Native::from_u32(255).unwrap() / range)
.round()
.to_u8()
.unwrap(),
}
})
.collect();

Ok(Arc::new(FixedSizeListArray::try_new_from_values(
UInt8Array::from(builder),
fsl.value_length(),
)?))
}

/// Get the centroids for each dimension.
fn num_bits(&self) -> u16 {
self.num_bits
}

/// Whether to use residual as input or not.
fn use_residual(&self) -> bool {
true
}
}

#[cfg(test)]
mod tests {
use arrow::datatypes::{Float16Type, Float32Type, Float64Type};
use arrow_array::{Float16Array, Float32Array, Float64Array};
use half::f16;

use super::*;

#[tokio::test]
async fn test_f16_sq8() {
let mut sq: ScalarQuantizerImpl<Float16Type> = ScalarQuantizerImpl::new(8, MetricType::L2);
let float_values = Vec::from_iter((0..16).map(|v| f16::from_usize(v).unwrap()));
let float_array = Float16Array::from_iter_values(float_values.clone());
let vectors =
FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
.unwrap();

sq.update_bounds(&vectors).unwrap();
assert_eq!(sq.bounds.start, float_values[0]);
assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap());

let sq_code = sq.transform(&vectors).await.unwrap();
let sq_values = sq_code
.as_fixed_size_list()
.values()
.as_any()
.downcast_ref::<UInt8Array>()
.unwrap();

sq_values.values().iter().enumerate().for_each(|(i, v)| {
assert_eq!(*v, (i * 17) as u8);
});
}

#[tokio::test]
async fn test_f32_sq8() {
let mut sq: ScalarQuantizerImpl<Float32Type> = ScalarQuantizerImpl::new(8, MetricType::L2);
let float_values = Vec::from_iter((0..16).map(|v| v as f32));
let float_array = Float32Array::from_iter_values(float_values.clone());
let vectors =
FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
.unwrap();

sq.update_bounds(&vectors).unwrap();
assert_eq!(sq.bounds.start, float_values[0]);
assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap());

let sq_code = sq.transform(&vectors).await.unwrap();
let sq_values = sq_code
.as_fixed_size_list()
.values()
.as_any()
.downcast_ref::<UInt8Array>()
.unwrap();

sq_values.values().iter().enumerate().for_each(|(i, v)| {
assert_eq!(*v, (i * 17) as u8,);
});
}

#[tokio::test]
async fn test_f64_sq8() {
let mut sq: ScalarQuantizerImpl<Float64Type> = ScalarQuantizerImpl::new(8, MetricType::L2);
let float_values = Vec::from_iter((0..16).map(|v| v as f64));
let float_array = Float64Array::from_iter_values(float_values.clone());
let vectors =
FixedSizeListArray::try_new_from_values(float_array, float_values.len() as i32)
.unwrap();

sq.update_bounds(&vectors).unwrap();
assert_eq!(sq.bounds.start, float_values[0]);
assert_eq!(sq.bounds.end, float_values.last().cloned().unwrap());

let sq_code = sq.transform(&vectors).await.unwrap();
let sq_values = sq_code
.as_fixed_size_list()
.values()
.as_any()
.downcast_ref::<UInt8Array>()
.unwrap();

sq_values.values().iter().enumerate().for_each(|(i, v)| {
assert_eq!(*v, (i * 17) as u8,);
});
}
}

0 comments on commit 874bf9c

Please sign in to comment.