diff --git a/demos/demo_qinco.py b/demos/demo_qinco.py new file mode 100644 index 0000000000..6679e000fb --- /dev/null +++ b/demos/demo_qinco.py @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This demonstrates how to reproduce the QINCo paper results using the Faiss +QINCo implementation. The code loads the reference model because training +is not implemented in Faiss. + +Prepare the data with + +cd /tmp + +# get the reference qinco code +git clone https://github.com/facebookresearch/Qinco.git + +# get the data +wget https://dl.fbaipublicfiles.com/QINCo/datasets/bigann/bigann1M.bvecs + +# get the model +wget https://dl.fbaipublicfiles.com/QINCo/models/bigann_8x8_L2.pt + +""" + +import numpy as np +from faiss.contrib.vecs_io import bvecs_mmap +import sys +import time +import torch +import faiss + +# make sure pickle deserialization will work +sys.path.append("/tmp/Qinco") +import model_qinco + +with torch.no_grad(): + + qinco = torch.load("/tmp/bigann_8x8_L2.pt") + qinco.eval() + # print(qinco) + if True: + torch.set_num_threads(1) + faiss.omp_set_num_threads(1) + + x_base = bvecs_mmap("/tmp/bigann1M.bvecs")[:1000].astype('float32') + x_scaled = torch.from_numpy(x_base) / qinco.db_scale + + t0 = time.time() + codes, _ = qinco.encode(x_scaled) + x_decoded_scaled = qinco.decode(codes) + print(f"Pytorch encode {time.time() - t0:.3f} s") + # multi-thread: 1.13s, single-thread: 7.744 + + x_decoded = x_decoded_scaled.numpy() * qinco.db_scale + + err = ((x_decoded - x_base) ** 2).sum(1).mean() + print("MSE=", err) # = 14211.956, near the L=2 result in Fig 4 of the paper + + qinco2 = faiss.QINCo(qinco) + t0 = time.time() + codes2 = qinco2.encode(faiss.Tensor2D(x_scaled)) + x_decoded2 = qinco2.decode(codes2).numpy() * qinco.db_scale + print(f"Faiss encode {time.time() - t0:.3f} s") + # multi-thread: 3.2s, single thread: 7.019 + + # these tests don't work because there are outlier encodings + # np.testing.assert_array_equal(codes.numpy(), codes2.numpy()) + # np.testing.assert_allclose(x_decoded, x_decoded2) + + ndiff = (codes.numpy() != codes2.numpy()).sum() / codes.numel() + assert ndiff < 0.01 + ndiff = (((x_decoded - x_decoded2) ** 2).sum(1) > 1e-5).sum() + assert ndiff / len(x_base) < 0.01 + + err = ((x_decoded2 - x_base) ** 2).sum(1).mean() + print("MSE=", err) # = 14213.551 diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 1b0860f3fb..c33c020008 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -46,6 +46,7 @@ set(FAISS_SRC IndexScalarQuantizer.cpp IndexShards.cpp IndexShardsIVF.cpp + IndexNeuralNetCodec.cpp MatrixStats.cpp MetaIndexes.cpp VectorTransform.cpp @@ -81,6 +82,7 @@ set(FAISS_SRC invlists/InvertedLists.cpp invlists/InvertedListsIOHook.cpp utils/Heap.cpp + utils/NeuralNet.cpp utils/WorkerThread.cpp utils/distances.cpp utils/distances_simd.cpp diff --git a/faiss/IndexNeuralNetCodec.cpp b/faiss/IndexNeuralNetCodec.cpp new file mode 100644 index 0000000000..3109dce7dc --- /dev/null +++ b/faiss/IndexNeuralNetCodec.cpp @@ -0,0 +1,56 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include <faiss/IndexNeuralNetCodec.h>
+#include <faiss/impl/FaissAssert.h>
+#include <faiss/utils/hamming.h>
+
+namespace faiss {
+
+/*********************************************************
+ * IndexNeuralNetCodec implementation
+ *********************************************************/
+
+IndexNeuralNetCodec::IndexNeuralNetCodec(
+        int d,
+        int M,
+        int nbits,
+        MetricType metric)
+        : IndexFlatCodes((M * nbits + 7) / 8, d, metric), M(M), nbits(nbits) {
+    is_trained = false;
+}
+
+void IndexNeuralNetCodec::train(idx_t n, const float* x) {
+    FAISS_THROW_MSG("Training not implemented in C++, use Pytorch");
+}
+
+void IndexNeuralNetCodec::sa_encode(idx_t n, const float* x, uint8_t* codes)
+        const {
+    nn::Tensor2D x_tensor(n, d, x);
+    nn::Int32Tensor2D codes_tensor = net->encode(x_tensor);
+    pack_bitstrings(n, M, nbits, codes_tensor.data(), codes, code_size);
+}
+
+void IndexNeuralNetCodec::sa_decode(idx_t n, const uint8_t* codes, float* x)
+        const {
+    nn::Int32Tensor2D codes_tensor(n, M);
+    unpack_bitstrings(n, M, nbits, codes, code_size, codes_tensor.data());
+    nn::Tensor2D x_tensor = net->decode(codes_tensor);
+    memcpy(x, x_tensor.data(), d * n * sizeof(float));
+}
+
+/*********************************************************
+ * IndexQINCo implementation
+ *********************************************************/
+
+IndexQINCo::IndexQINCo(int d, int M, int nbits, int L, int h, MetricType metric)
+        : IndexNeuralNetCodec(d, M, nbits, metric),
+          qinco(d, 1 << nbits, L, M, h) {
+    net = &qinco;
+}
+
+} // namespace faiss
diff --git a/faiss/IndexNeuralNetCodec.h b/faiss/IndexNeuralNetCodec.h
new file mode 100644
index 0000000000..7c581e80e8
--- /dev/null
+++ b/faiss/IndexNeuralNetCodec.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <vector>
+
+#include <faiss/IndexFlatCodes.h>
+#include <faiss/utils/NeuralNet.h>
+
+namespace faiss {
+
+struct IndexNeuralNetCodec : IndexFlatCodes {
+    NeuralNetCodec* net = nullptr;
+    size_t M, nbits;
+
+    explicit IndexNeuralNetCodec(
+            int d = 0,
+            int M = 0,
+            int nbits = 0,
+            MetricType metric = METRIC_L2);
+
+    void train(idx_t n, const float* x) override;
+
+    void sa_encode(idx_t n, const float* x, uint8_t* codes) const override;
+    void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;
+
+    ~IndexNeuralNetCodec() {}
+};
+
+struct IndexQINCo : IndexNeuralNetCodec {
+    QINCo qinco;
+
+    IndexQINCo(
+            int d,
+            int M,
+            int nbits,
+            int L,
+            int h,
+            MetricType metric = METRIC_L2);
+
+    ~IndexQINCo() {}
+};
+
+} // namespace faiss
diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h
index 511af10f79..3116eb24df 100644
--- a/faiss/impl/ResultHandler.h
+++ b/faiss/impl/ResultHandler.h
@@ -16,6 +16,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py
index ce4b42c618..3742fa5372 100644
--- a/faiss/python/__init__.py
+++ b/faiss/python/__init__.py
@@ -44,6 +44,14 @@
 class_wrappers.handle_IDSelectorSubset(IDSelectorBitmap, class_owns=False, force_int64=False)
 class_wrappers.handle_CodeSet(CodeSet)
+class_wrappers.handle_Tensor2D(Tensor2D)
+class_wrappers.handle_Tensor2D(Int32Tensor2D)
+class_wrappers.handle_Embedding(Embedding)
+class_wrappers.handle_Linear(Linear)
+class_wrappers.handle_QINCo(QINCo)
+class_wrappers.handle_QINCoStep(QINCoStep)
+
+
 this_module = sys.modules[__name__]
 
 # handle sub-classes
diff --git a/faiss/python/class_wrappers.py b/faiss/python/class_wrappers.py
index 4af2345009..208065351a 100644
--- a/faiss/python/class_wrappers.py
+++ b/faiss/python/class_wrappers.py
@@ -1247,3 +1247,152 @@ def replacement_insert(self, codes, inserted=None):
         return inserted
 
     replace_method(the_class, 'insert', replacement_insert)
+
+######################################################
+# Syntactic sugar for NeuralNet classes
+######################################################
+
+
+def handle_Tensor2D(the_class):
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) == 1:
+            array, = args
+            n, d = array.shape
+            self.original_init(n, d)
+            faiss.copy_array_to_vector(
+                np.ascontiguousarray(array).ravel(), self.v)
+        else:
+            self.original_init(*args)
+
+    def numpy(self):
+        shape = np.zeros(2, dtype=np.int64)
+        faiss.memcpy(faiss.swig_ptr(shape), self.shape, shape.nbytes)
+        return faiss.vector_to_array(self.v).reshape(shape[0], shape[1])
+
+    the_class.__init__ = replacement_init
+    the_class.numpy = numpy
+
+
+def handle_Embedding(the_class):
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+        # assume it's a torch.Embedding
+        emb = args[0]
+        self.original_init(emb.num_embeddings, emb.embedding_dim)
+        self.from_torch(emb)
+
+    def from_torch(self, emb):
+        """ copy weights from torch.Embedding """
+        assert emb.weight.shape == (self.num_embeddings, self.embedding_dim)
+        faiss.copy_array_to_vector(
+            np.ascontiguousarray(emb.weight.data).ravel(), self.weight)
+
+    def from_array(self, array):
+        """ copy weights from numpy array """
+        assert array.shape == (self.num_embeddings, self.embedding_dim)
+        faiss.copy_array_to_vector(
+            np.ascontiguousarray(array).ravel(), self.weight)
+
+    the_class.from_array = from_array
+    the_class.from_torch = from_torch
+    the_class.__init__ = replacement_init
+
+
+def handle_Linear(the_class):
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+        # assume it's a torch.Linear
+        linear = args[0]
+        bias = linear.bias is not None
+        self.original_init(linear.in_features, linear.out_features, bias)
+        self.from_torch(linear)
+
+    def from_torch(self, linear):
+        """ copy weights from torch.Linear """
+        assert linear.weight.shape == (self.out_features, self.in_features)
+        faiss.copy_array_to_vector(
+            linear.weight.data.numpy().ravel(), self.weight)
+        if linear.bias is not None:
+            assert linear.bias.shape == (self.out_features,)
+            faiss.copy_array_to_vector(linear.bias.data.numpy(), self.bias)
+
+    def from_array(self, array, bias=None):
+        """ copy weights from numpy array """
+        assert array.shape == (self.out_features, self.in_features)
+        faiss.copy_array_to_vector(
+            np.ascontiguousarray(array).ravel(), self.weight)
+        if bias is not None:
+            assert bias.shape == (self.out_features,)
+            faiss.copy_array_to_vector(bias, self.bias)
+
+    the_class.__init__ = replacement_init
+    the_class.from_array = from_array
+    the_class.from_torch = from_torch
+
+######################################################
+# Syntactic sugar for QINCo and QINCoStep
+######################################################
+
+def handle_QINCoStep(the_class):
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+        step = args[0]
+        # assume it's a Torch QINCoStep
+        self.original_init(step.d, step.K, step.L, step.h)
+        self.from_torch(step)
+
+    def from_torch(self, step):
+        """ copy weights from torch.QINCoStep """
+        assert (step.d, step.K, step.L, step.h) == (self.d, self.K, self.L, self.h)
+        self.codebook.from_torch(step.codebook)
+        self.MLPconcat.from_torch(step.MLPconcat)
+
+        for l in range(step.L):
+            src = step.residual_blocks[l]
+            dest = self.get_residual_block(l)
+            dest.linear1.from_torch(src[0])
+            dest.linear2.from_torch(src[2])
+
+    the_class.__init__ = replacement_init
+    the_class.from_torch = from_torch
+
+
+def handle_QINCo(the_class):
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+
+        # assume it's a Torch QINCo
+        qinco = args[0]
+        self.original_init(qinco.d, qinco.K, qinco.L, qinco.M, qinco.h)
+        self.from_torch(qinco)
+
+    def from_torch(self, qinco):
+        """ copy weights from torch.QINCo """
+        assert (
+            (qinco.d, qinco.K, qinco.L, qinco.M, qinco.h) ==
+            (self.d, self.K, self.L, self.M, self.h)
+        )
+        self.codebook0.from_torch(qinco.codebook0)
+        for m in range(qinco.M - 1):
+            self.get_step(m).from_torch(qinco.steps[m])
+
+    the_class.__init__ = replacement_init
+    the_class.from_torch = from_torch
diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig
index 74a371f6cd..f63d76dc0e 100644
--- a/faiss/python/swigfaiss.swig
+++ b/faiss/python/swigfaiss.swig
@@ -145,6 +145,7 @@ typedef uint64_t size_t;
 #include
 #include
 #include
+#include <faiss/IndexNeuralNetCodec.h>
 #include
@@ -164,6 +165,7 @@ typedef uint64_t size_t;
 #include
 #include
+#include <faiss/utils/NeuralNet.h>
 
 %}
@@ -257,7 +259,6 @@ namespace std {
 %template(ClusteringIterationStatsVector) std::vector<faiss::ClusteringIterationStats>;
 %template(ParameterRangeVector) std::vector<faiss::ParameterRange>;
-
 #ifndef SWIGWIN
 %template(OnDiskOneListVector) std::vector<faiss::OnDiskOneList>;
 #endif // !SWIGWIN
@@ -530,6 +531,12 @@
 struct faiss::simd16uint16 {};
 %include
+%include <faiss/utils/NeuralNet.h>
+%template(Tensor2D) faiss::nn::Tensor2DTemplate<float>;
+%template(Int32Tensor2D) faiss::nn::Tensor2DTemplate<int32_t>;
+
+%include <faiss/IndexNeuralNetCodec.h>
+
 %ignore faiss::BufferList::Buffer;
 %ignore faiss::RangeSearchPartialResult::QueryResult;
diff --git a/faiss/utils/NeuralNet.cpp b/faiss/utils/NeuralNet.cpp
new file mode 100644
index 0000000000..9d5465bae8
--- /dev/null
+++ b/faiss/utils/NeuralNet.cpp
@@ -0,0 +1,342 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <faiss/utils/NeuralNet.h>
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+
+#include <faiss/impl/FaissAssert.h>
+#include <faiss/utils/distances.h>
+
+/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
+
+extern "C" {
+
+int sgemm_(
+        const char* transa,
+        const char* transb,
+        FINTEGER* m,
+        FINTEGER* n,
+        FINTEGER* k,
+        const float* alpha,
+        const float* a,
+        FINTEGER* lda,
+        const float* b,
+        FINTEGER* ldb,
+        float* beta,
+        float* c,
+        FINTEGER* ldc);
+}
+
+namespace faiss {
+
+namespace nn {
+
+/*************************************************************
+ * Tensor2D implementation
+ *************************************************************/
+
+template <typename T>
+Tensor2DTemplate<T>::Tensor2DTemplate(size_t n0, size_t n1, const T* data_in)
+        : shape{n0, n1}, v(n0 * n1) {
+    if (data_in) {
+        memcpy(data(), data_in, n0 * n1 * sizeof(T));
+    }
+}
+
+template <typename T>
+Tensor2DTemplate<T>& Tensor2DTemplate<T>::operator+=(
+        const Tensor2DTemplate<T>& other) {
+    FAISS_THROW_IF_NOT(shape[0] == other.shape[0]);
+    FAISS_THROW_IF_NOT(shape[1] == other.shape[1]);
+    for (size_t i = 0; i < numel(); i++) {
+        v[i] += other.v[i];
+    }
+    return *this;
+}
+
+template <typename T>
+Tensor2DTemplate<T> Tensor2DTemplate<T>::column(size_t j) const {
+    size_t n = shape[0], d = shape[1];
+    Tensor2DTemplate out(n, 1);
+    for (size_t i = 0; i < n; i++) {
+        out.v[i] = v[i * d + j];
+    }
+    return out;
+}
+
+// explicit template instantiation
+template struct Tensor2DTemplate<float>;
+template struct Tensor2DTemplate<int32_t>;
+
+/*************************************************************
+ * Layers implementation
+ *************************************************************/
+
+Linear::Linear(size_t in_features, size_t out_features, bool bias)
+        : in_features(in_features),
+          out_features(out_features),
+          weight(in_features * out_features) {
+    if (bias) {
+        this->bias.resize(out_features);
+    }
+}
+
+Tensor2D Linear::operator()(const Tensor2D& x) const {
+    FAISS_THROW_IF_NOT(x.shape[1] == in_features);
+    size_t n = x.shape[0];
+    Tensor2D output(n, out_features);
+
+    float one = 1, zero = 0;
+    FINTEGER nbiti = out_features, ni = n, di = in_features;
+
+    sgemm_("Transposed",
+           "Not transposed",
+           &nbiti,
+           &ni,
+           &di,
+           &one,
+           weight.data(),
+           &di,
+           x.data(),
+           &di,
+           &zero,
+           output.data(),
+           &nbiti);
+
+    if (bias.size() > 0) {
+        FAISS_THROW_IF_NOT(bias.size() == out_features);
+        for (size_t i = 0; i < n; i++) {
+            for (size_t j = 0; j < out_features; j++) {
+                output.v[i * out_features + j] += bias[j];
+            }
+        }
+    }
+
+    return output;
+}
+
+Embedding::Embedding(size_t num_embeddings, size_t embedding_dim)
+        : num_embeddings(num_embeddings), embedding_dim(embedding_dim) {
+    weight.resize(num_embeddings * embedding_dim);
+}
+
+Tensor2D Embedding::operator()(const Int32Tensor2D& code) const {
+    FAISS_THROW_IF_NOT(code.shape[1] == 1);
+    size_t n = code.shape[0];
+    Tensor2D output(n, embedding_dim);
+    for (size_t i = 0; i < n; ++i) {
+        size_t ci = code.v[i];
+        FAISS_THROW_IF_NOT(ci < num_embeddings);
memcpy(output.data() + i * embedding_dim, + weight.data() + ci * embedding_dim, + sizeof(float) * embedding_dim); + } + return output; // TODO figure out how std::move works +} + +namespace { + +void inplace_relu(Tensor2D& x) { + for (size_t i = 0; i < x.numel(); i++) { + x.v[i] = std::max(0.0f, x.v[i]); + } +} + +Tensor2D concatenate_rows(const Tensor2D& x, const Tensor2D& y) { + size_t n = x.shape[0], d1 = x.shape[1], d2 = y.shape[1]; + FAISS_THROW_IF_NOT(n == y.shape[0]); + Tensor2D out(n, d1 + d2); + for (size_t i = 0; i < n; i++) { + memcpy(out.data() + i * (d1 + d2), + x.data() + i * d1, + sizeof(float) * d1); + memcpy(out.data() + i * (d1 + d2) + d1, + y.data() + i * d2, + sizeof(float) * d2); + } + return out; +} + +} // anonymous namespace + +Tensor2D FFN::operator()(const Tensor2D& x_in) const { + Tensor2D x = linear1(x_in); + inplace_relu(x); + return linear2(x); +} + +} // namespace nn + +/************************************************************* + * QINCoStep implementation + *************************************************************/ + +using namespace nn; + +QINCoStep::QINCoStep(int d, int K, int L, int h) + : d(d), K(K), L(L), h(h), codebook(K, d), MLPconcat(2 * d, d) { + for (int i = 0; i < L; i++) { + residual_blocks.emplace_back(d, h); + } +} + +nn::Tensor2D QINCoStep::decode( + const nn::Tensor2D& xhat, + const nn::Int32Tensor2D& codes) const { + size_t n = xhat.shape[0]; + FAISS_THROW_IF_NOT(n == codes.shape[0]); + Tensor2D zqs = codebook(codes); + Tensor2D cc = concatenate_rows(zqs, xhat); + zqs += MLPconcat(cc); + for (int i = 0; i < L; i++) { + zqs += residual_blocks[i](zqs); + } + return zqs; +} + +nn::Int32Tensor2D QINCoStep::encode( + const nn::Tensor2D& xhat, + const nn::Tensor2D& x, + nn::Tensor2D* residuals) const { + size_t n = xhat.shape[0]; + FAISS_THROW_IF_NOT( + n == x.shape[0] && xhat.shape[1] == d && x.shape[1] == d); + + // repeated codebook + Tensor2D zqs_r(n * K, d); // size n, K, d + Tensor2D cc(n * K, d * 2); // size n, K, d * 2 + size_t d = this->d; + + auto copy_row = [d](Tensor2D& t, size_t i, size_t j, const float* data) { + assert(i <= t.shape[0] && j <= t.shape[1]); + memcpy(t.data() + i * t.shape[1] + j, data, sizeof(float) * d); + }; + + // manual broadcasting + for (size_t i = 0; i < n; i++) { + for (size_t j = 0; j < K; j++) { + copy_row(zqs_r, i * K + j, 0, codebook.data() + j * d); + copy_row(cc, i * K + j, 0, codebook.data() + j * d); + copy_row(cc, i * K + j, d, xhat.data() + i * d); + } + } + + zqs_r += MLPconcat(cc); + + // residual blocks + for (int i = 0; i < L; i++) { + zqs_r += residual_blocks[i](zqs_r); + } + + // add the xhat + for (size_t i = 0; i < n; i++) { + float* zqs_r_row = zqs_r.data() + i * K * d; + const float* xhat_row = xhat.data() + i * d; + for (size_t l = 0; l < K; l++) { + for (size_t j = 0; j < d; j++) { + zqs_r_row[j] += xhat_row[j]; + } + zqs_r_row += d; + } + } + + // perform assignment, finding the nearest + nn::Int32Tensor2D codes(n, 1); + float* res = nullptr; + if (residuals) { + FAISS_THROW_IF_NOT( + residuals->shape[0] == n && residuals->shape[1] == d); + res = residuals->data(); + } + + for (size_t i = 0; i < n; i++) { + const float* q = x.data() + i * d; + const float* db = zqs_r.data() + i * K * d; + float dis_min = HUGE_VALF; + int64_t idx = -1; + for (size_t j = 0; j < K; j++) { + float dis = fvec_L2sqr(q, db, d); + if (dis < dis_min) { + dis_min = dis; + idx = j; + } + db += d; + } + codes.v[i] = idx; + if (res) { + const float* xhat_row = xhat.data() + i * d; + const float* xhat_next_row 
= zqs_r.data() + (i * K + idx) * d;
+            for (size_t j = 0; j < d; j++) {
+                res[j] = xhat_next_row[j] - xhat_row[j];
+            }
+            res += d;
+        }
+    }
+    return codes;
+}
+
+/*************************************************************
+ * QINCo implementation
+ *************************************************************/
+
+QINCo::QINCo(int d, int K, int L, int M, int h)
+        : NeuralNetCodec(d, M), K(K), L(L), h(h), codebook0(K, d) {
+    for (int i = 1; i < M; i++) {
+        steps.emplace_back(d, K, L, h);
+    }
+}
+
+nn::Tensor2D QINCo::decode(const nn::Int32Tensor2D& codes) const {
+    FAISS_THROW_IF_NOT(codes.shape[1] == M);
+    Tensor2D xhat = codebook0(codes.column(0));
+    for (int i = 1; i < M; i++) {
+        xhat += steps[i - 1].decode(xhat, codes.column(i));
+    }
+    return xhat;
+}
+
+nn::Int32Tensor2D QINCo::encode(const nn::Tensor2D& x) const {
+    FAISS_THROW_IF_NOT(x.shape[1] == d);
+    size_t n = x.shape[0];
+    Int32Tensor2D codes(n, M);
+    Tensor2D xhat(n, d);
+    {
+        // assign to first codebook as a batch
+        std::vector<float> dis(n);
+        std::vector<idx_t> codes64(n);
+        knn_L2sqr(
+                x.data(),
+                codebook0.data(),
+                d,
+                n,
+                K,
+                1,
+                dis.data(),
+                codes64.data());
+        for (size_t i = 0; i < n; i++) {
+            codes.v[i * M] = codes64[i];
+            memcpy(xhat.data() + i * d,
+                   codebook0.data() + codes64[i] * d,
+                   sizeof(float) * d);
+        }
+    }
+
+    Tensor2D toadd(n, d);
+    for (int i = 1; i < M; i++) {
+        Int32Tensor2D ci = steps[i - 1].encode(xhat, x, &toadd);
+        for (size_t j = 0; j < n; j++) {
+            codes.v[j * M + i] = ci.v[j];
+        }
+        xhat += toadd;
+    }
+    return codes;
+}
+
+} // namespace faiss
diff --git a/faiss/utils/NeuralNet.h b/faiss/utils/NeuralNet.h
new file mode 100644
index 0000000000..928fa96bab
--- /dev/null
+++ b/faiss/utils/NeuralNet.h
@@ -0,0 +1,147 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/** Implements a few neural net layers, mainly to support QINCo */
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+namespace faiss {
+
+// the names are based on the Pytorch names (more or less)
+namespace nn {
+
+// container for intermediate steps of the neural net
+template <typename T>
+struct Tensor2DTemplate {
+    size_t shape[2];
+    std::vector<T> v;
+
+    Tensor2DTemplate(size_t n0, size_t n1, const T* data = nullptr);
+
+    Tensor2DTemplate& operator+=(const Tensor2DTemplate&);
+
+    /// get column #j as a 1-column Tensor2D
+    Tensor2DTemplate column(size_t j) const;
+
+    size_t numel() const {
+        return shape[0] * shape[1];
+    }
+    T* data() {
+        return v.data();
+    }
+    const T* data() const {
+        return v.data();
+    }
+};
+
+using Tensor2D = Tensor2DTemplate<float>;
+using Int32Tensor2D = Tensor2DTemplate<int32_t>;
+
+/// minimal translation of nn.Linear
+struct Linear {
+    size_t in_features, out_features;
+    std::vector<float> weight;
+    std::vector<float> bias;
+
+    Linear(size_t in_features, size_t out_features, bool bias = true);
+
+    Tensor2D operator()(const Tensor2D& x) const;
+};
+
+/// minimal translation of nn.Embedding
+struct Embedding {
+    size_t num_embeddings, embedding_dim;
+    std::vector<float> weight;
+
+    Embedding(size_t num_embeddings, size_t embedding_dim);
+
+    Tensor2D operator()(const Int32Tensor2D&) const;
+
+    float* data() {
+        return weight.data();
+    }
+
+    const float* data() const {
+        return weight.data();
+    }
+};
+
+/// Feed forward layer that expands to a hidden dimension, applies a ReLU
+/// non-linearity and maps back to the original dimension
+struct FFN {
+    Linear linear1, linear2;
+
+    FFN(int d, int h) : linear1(d, h, false), linear2(h, d, false) {}
+
+    Tensor2D operator()(const Tensor2D& x) const;
+};
+
+} // namespace nn
+
+// Translation of the QINCo implementation from
+// https://github.com/facebookresearch/Qinco/blob/main/model_qinco.py
+
+struct QINCoStep {
+    /// d: input dim, K: codebook size, L: # of residual blocks, h: hidden dim
+    int d, K, L, h;
+
+    QINCoStep(int d, int K, int L, int h);
+
+    nn::Embedding codebook;
+    nn::Linear MLPconcat;
+    std::vector<nn::FFN> residual_blocks;
+
+    nn::FFN& get_residual_block(int i) {
+        return residual_blocks[i];
+    }
+
+    /** encode a set of vectors x with initial estimate xhat. Optionally return
+     * the delta to be added to xhat to form the new xhat */
+    nn::Int32Tensor2D encode(
+            const nn::Tensor2D& xhat,
+            const nn::Tensor2D& x,
+            nn::Tensor2D* residuals = nullptr) const;
+
+    nn::Tensor2D decode(
+            const nn::Tensor2D& xhat,
+            const nn::Int32Tensor2D& codes) const;
+};
+
+struct NeuralNetCodec {
+    int d, M;
+
+    NeuralNetCodec(int d, int M) : d(d), M(M) {}
+
+    virtual nn::Tensor2D decode(const nn::Int32Tensor2D& codes) const = 0;
+    virtual nn::Int32Tensor2D encode(const nn::Tensor2D& x) const = 0;
+
+    virtual ~NeuralNetCodec() {}
+};
+
+struct QINCo : NeuralNetCodec {
+    int K, L, h;
+    nn::Embedding codebook0;
+    std::vector<QINCoStep> steps;
+
+    QINCo(int d, int K, int L, int M, int h);
+
+    QINCoStep& get_step(int i) {
+        return steps[i];
+    }
+
+    nn::Tensor2D decode(const nn::Int32Tensor2D& codes) const override;
+
+    nn::Int32Tensor2D encode(const nn::Tensor2D& x) const override;
+
+    virtual ~QINCo() {}
+};
+
+} // namespace faiss
diff --git a/tests/torch_test_neural_net.py b/tests/torch_test_neural_net.py
new file mode 100644
index 0000000000..4bab6d1ccc
--- /dev/null
+++ b/tests/torch_test_neural_net.py
@@ -0,0 +1,373 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import faiss
+import torch
+from torch import nn
+import unittest
+import numpy as np
+
+from faiss.contrib import datasets
+from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks
+
+
+class TestLayer(unittest.TestCase):
+
+    @torch.no_grad()
+    def test_Embedding(self):
+        """ verify that the Faiss Embedding works the same as in Pytorch """
+        torch.manual_seed(123)
+
+        emb = nn.Embedding(40, 50)
+        idx = torch.randint(40, (25, ))
+        ref_batch = emb(idx)
+
+        emb2 = faiss.Embedding(emb)
+        idx2 = faiss.Int32Tensor2D(idx[:, None].to(dtype=torch.int32))
+        new_batch = emb2(idx2)
+
+        new_batch = new_batch.numpy()
+        np.testing.assert_allclose(ref_batch.numpy(), new_batch, atol=2e-6)
+
+    @torch.no_grad()
+    def do_test_Linear(self, bias):
+        """ verify that the Faiss Linear works the same as in Pytorch """
+        torch.manual_seed(123)
+        linear = nn.Linear(50, 40, bias=bias)
+        x = torch.randn(25, 50)
+        ref_y = linear(x)
+
+        linear2 = faiss.Linear(linear)
+        x2 = faiss.Tensor2D(x)
+        y = linear2(x2)
+        np.testing.assert_allclose(ref_y.numpy(), y.numpy(), atol=2e-6)
+
+    def test_Linear(self):
+        self.do_test_Linear(True)
+
+    def test_Linear_nobias(self):
+        self.do_test_Linear(False)
+
+######################################################
+# QINCo Pytorch implementation copied from
+# https://github.com/facebookresearch/Qinco/blob/main/model_qinco.py
+#
+# The implementation is copied here to avoid introducing an additional
+# dependency.
+######################################################
+
+
+def pairwise_distances(a, b):
+    anorms = (a**2).sum(-1)
+    bnorms = (b**2).sum(-1)
+    return anorms[:, None] + bnorms - 2 * a @ b.T
+
+
+def compute_batch_distances(a, b):
+    anorms = (a**2).sum(-1)
+    bnorms = (b**2).sum(-1)
+    return (
+        anorms.unsqueeze(-1) + bnorms.unsqueeze(1) - 2 * torch.bmm(a, b.transpose(2, 1))
+    )
+
+
+def assign_batch_multiple(x, zqs):
+    bs, d = x.shape
+    bs, K, d = zqs.shape
+
+    L2distances = compute_batch_distances(x.unsqueeze(1), zqs).squeeze(1)  # [bs x K]
+    idx = torch.argmin(L2distances, dim=1).unsqueeze(1)  # [bs x 1]
+    quantized = torch.gather(zqs, dim=1, index=idx.unsqueeze(-1).repeat(1, 1, d))
+    return idx.squeeze(1), quantized.squeeze(1)
+
+
+def assign_to_codebook(x, c, bs=16384):
+    nq, d = x.shape
+    nb, d2 = c.shape
+    assert d == d2
+    if nq * nb < bs * bs:
+        # small enough to represent the whole distance table
+        dis = pairwise_distances(x, c)
+        return dis.argmin(1)
+
+    # otherwise tile computation to avoid OOM
+    res = torch.empty((nq,), dtype=torch.int64, device=x.device)
+    cnorms = (c**2).sum(1)
+    for i in range(0, nq, bs):
+        xnorms = (x[i : i + bs] ** 2).sum(1, keepdim=True)
+        for j in range(0, nb, bs):
+            dis = xnorms + cnorms[j : j + bs] - 2 * x[i : i + bs] @ c[j : j + bs].T
+            dmini, imini = dis.min(1)
+            if j == 0:
+                dmin = dmini
+                imin = imini
+            else:
+                (mask,) = torch.where(dmini < dmin)
+                dmin[mask] = dmini[mask]
+                imin[mask] = imini[mask] + j
+        res[i : i + bs] = imin
+    return res
+
+
+class QINCoStep(nn.Module):
+    """
+    One quantization step for QINCo.
+ Contains the codebook, concatenation block, and residual blocks + """ + + def __init__(self, d, K, L, h): + nn.Module.__init__(self) + + self.d, self.K, self.L, self.h = d, K, L, h + + self.codebook = nn.Embedding(K, d) + self.MLPconcat = nn.Linear(2 * d, d) + + self.residual_blocks = [] + for l in range(L): + residual_block = nn.Sequential( + nn.Linear(d, h, bias=False), nn.ReLU(), nn.Linear(h, d, bias=False) + ) + self.add_module(f"residual_block{l}", residual_block) + self.residual_blocks.append(residual_block) + + def decode(self, xhat, codes): + zqs = self.codebook(codes) + cc = torch.concatenate((zqs, xhat), 1) + zqs = zqs + self.MLPconcat(cc) + + for residual_block in self.residual_blocks: + zqs = zqs + residual_block(zqs) + + return zqs + + def encode(self, xhat, x): + # we are trying out the whole codebook + zqs = self.codebook.weight + K, d = zqs.shape + bs, d = xhat.shape + + # repeat so that they are of size bs * K + zqs_r = zqs.repeat(bs, 1, 1).reshape(bs * K, d) + xhat_r = xhat.reshape(bs, 1, d).repeat(1, K, 1).reshape(bs * K, d) + + # pass on batch of size bs * K + cc = torch.concatenate((zqs_r, xhat_r), 1) + zqs_r = zqs_r + self.MLPconcat(cc) + + for residual_block in self.residual_blocks: + zqs_r = zqs_r + residual_block(zqs_r) + + # possible next steps + zqs_r = zqs_r.reshape(bs, K, d) + xhat.reshape(bs, 1, d) + codes, xhat_next = assign_batch_multiple(x, zqs_r) + + return codes, xhat_next - xhat + + +class QINCo(nn.Module): + """ + QINCo quantizer, built from a chain of residual quantization steps + """ + + def __init__(self, d, K, L, M, h): + nn.Module.__init__(self) + + self.d, self.K, self.L, self.M, self.h = d, K, L, M, h + + self.codebook0 = nn.Embedding(K, d) + + self.steps = [] + for m in range(1, M): + step = QINCoStep(d, K, L, h) + self.add_module(f"step{m}", step) + self.steps.append(step) + + def decode(self, codes): + xhat = self.codebook0(codes[:, 0]) + for i, step in enumerate(self.steps): + xhat = xhat + step.decode(xhat, codes[:, i + 1]) + return xhat + + def encode(self, x, code0=None): + """ + Encode a batch of vectors x to codes of length M. + If this function is called from IVF-QINCo, codes are 1 index longer, + due to the first index being the IVF index, and codebook0 is the IVF codebook. 
+ """ + M = len(self.steps) + 1 + bs, d = x.shape + codes = torch.zeros(bs, M, dtype=int, device=x.device) + + if code0 is None: + # at IVF training time, the code0 is fixed (and precomputed) + code0 = assign_to_codebook(x, self.codebook0.weight) + + codes[:, 0] = code0 + xhat = self.codebook0.weight[code0] + + for i, step in enumerate(self.steps): + codes[:, i + 1], toadd = step.encode(xhat, x) + xhat = xhat + toadd + + return codes, xhat + + +###################################################### +# QINCo tests +###################################################### + +def copy_QINCoStep(step): + step2 = faiss.QINCoStep(step.d, step.K, step.L, step.h) + step2.codebook.from_torch(step.codebook) + step2.MLPconcat.from_torch(step.MLPconcat) + + for l in range(step.L): + src = step.residual_blocks[l] + dest = step2.get_residual_block(l) + dest.linear1.from_torch(src[0]) + dest.linear2.from_torch(src[2]) + return step2 + + +class TestQINCoStep(unittest.TestCase): + @torch.no_grad() + def test_decode(self): + torch.manual_seed(123) + step = QINCoStep(d=16, K=20, L=2, h=8) + + codes = torch.randint(0, 20, (10, )) + xhat = torch.randn(10, 16) + ref_decode = step.decode(xhat, codes) + + # step2 = copy_QINCoStep(step) + step2 = faiss.QINCoStep(step) + codes2 = faiss.Int32Tensor2D(codes[:, None].to(dtype=torch.int32)) + + np.testing.assert_array_equal( + step.codebook(codes).numpy(), + step2.codebook(codes2).numpy() + ) + + xhat2 = faiss.Tensor2D(xhat) + # xhat2 = faiss.Tensor2D(len(codes), step2.d) + + new_decode = step2.decode(xhat2, codes2) + + np.testing.assert_allclose( + ref_decode.numpy(), + new_decode.numpy(), + atol=2e-6 + ) + + @torch.no_grad() + def test_encode(self): + torch.manual_seed(123) + step = QINCoStep(d=16, K=20, L=2, h=8) + + # create plausible x for testing starting from actual codes + codes = torch.randint(0, 20, (10, )) + xhat = torch.zeros(10, 16) + x = step.decode(xhat, codes) + del codes + ref_codes, toadd = step.encode(xhat, x) + + step2 = copy_QINCoStep(step) + xhat2 = faiss.Tensor2D(xhat) + x2 = faiss.Tensor2D(x) + toadd2 = faiss.Tensor2D(10, 16) + + new_codes = step2.encode(xhat2, x2, toadd2) + + np.testing.assert_allclose( + ref_codes.numpy(), + new_codes.numpy().ravel(), + atol=2e-6 + ) + np.testing.assert_allclose(toadd.numpy(), toadd2.numpy(), atol=2e-6) + + + +class TestQINCo(unittest.TestCase): + + @torch.no_grad() + def test_decode(self): + torch.manual_seed(123) + qinco = QINCo(d=16, K=20, L=2, M=3, h=8) + codes = torch.randint(0, 20, (10, 3)) + x_ref = qinco.decode(codes) + + qinco2 = faiss.QINCo(qinco) + codes2 = faiss.Int32Tensor2D(codes.to(dtype=torch.int32)) + x_new = qinco2.decode(codes2) + + np.testing.assert_allclose(x_ref.numpy(), x_new.numpy(), atol=2e-6) + + @torch.no_grad() + def test_encode(self): + torch.manual_seed(123) + qinco = QINCo(d=16, K=20, L=2, M=3, h=8) + codes = torch.randint(0, 20, (10, 3)) + x = qinco.decode(codes) + del codes + + ref_codes, _ = qinco.encode(x) + + qinco2 = faiss.QINCo(qinco) + x2 = faiss.Tensor2D(x) + + new_codes = qinco2.encode(x2) + + np.testing.assert_allclose(ref_codes.numpy(), new_codes.numpy(), atol=2e-6) + + +###################################################### +# Test index +###################################################### + +class TestIndexQINCo(unittest.TestCase): + + def test_search(self): + """ + We can't train qinco with just Faiss so we just train a RQ and use the + codebooks in QINCo with L = 0 residual blocks + """ + ds = datasets.SyntheticDataset(32, 1000, 100, 0) + + # prepare reference 
quantizer + M = 5 + index_ref = faiss.index_factory(ds.d, "RQ5x4") + rq = index_ref.rq + # rq = faiss.ResidualQuantizer(ds.d, M, 4) + rq.train_type = faiss.ResidualQuantizer.Train_default + rq.max_beam_size = 1 # beam search not implemented for QINCo (yet) + index_ref.train(ds.get_train()) + codebooks = get_additive_quantizer_codebooks(rq) + + # convert to QINCo index + qinco_index = faiss.IndexQINCo(ds.d, M, 4, 0, ds.d) + qinco = qinco_index.qinco + qinco.codebook0.from_array(codebooks[0]) + for i in range(1, qinco.M): + step = qinco.get_step(i - 1) + step.codebook.from_array(codebooks[i]) + # MLPConcat left at zero -- it's added to the backbone + qinco_index.is_trained = True + + # verify that the encoding gives the same results + ref_codes = rq.compute_codes(ds.get_database()) + ref_decoded = rq.decode(ref_codes) + new_decoded = qinco_index.sa_decode(ref_codes) + np.testing.assert_allclose(ref_decoded, new_decoded, atol=2e-6) + + new_codes = qinco_index.sa_encode(ds.get_database()) + np.testing.assert_array_equal(ref_codes, new_codes) + + # verify that search gives the same results + Dref, Iref = index_ref.search(ds.get_queries(), 5) + Dnew, Inew = qinco_index.search(ds.get_queries(), 5) + + np.testing.assert_array_equal(Iref, Inew) + np.testing.assert_allclose(Dref, Dnew, atol=2e-6)
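
# ---------------------------------------------------------------------------
# End-to-end sketch (not part of the diff above) of how the new classes fit
# together into a searchable index. It assumes a torch QINCo checkpoint
# trained with the reference Qinco code, e.g. the /tmp/bigann_8x8_L2.pt model
# from the demo (d=128, M=8, K=256, hence nbits=8); the database xb is a
# random placeholder. Training itself stays in Pytorch.

import sys
import numpy as np
import torch
import faiss

sys.path.append("/tmp/Qinco")  # so that pickle can resolve model_qinco
import model_qinco  # noqa: F401

with torch.no_grad():
    qinco_torch = torch.load("/tmp/bigann_8x8_L2.pt")
    qinco_torch.eval()

    # build an index with matching hyperparameters; K = 2**nbits codewords
    index = faiss.IndexQINCo(
        qinco_torch.d, qinco_torch.M, 8, qinco_torch.L, qinco_torch.h)
    # copy the trained weights through the from_torch wrappers defined above
    index.qinco.from_torch(qinco_torch)
    index.is_trained = True

    # vectors must be scaled the same way as in the demo
    xb = np.random.rand(1000, qinco_torch.d).astype('float32')
    xb /= float(qinco_torch.db_scale)
    index.add(xb)
    D, I = index.search(xb[:5], 10)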