diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6fa1b5830e..eadb9b2b98 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -337,6 +337,9 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/ivf_flat_search_float_int64_t.cu src/neighbors/ivf_flat_search_int8_t_int64_t.cu src/neighbors/ivf_flat_search_uint8_t_int64_t.cu + src/neighbors/ivf_flat_reconstruct_float_int64_t.cu + src/neighbors/ivf_flat_reconstruct_int8_t_int64_t.cu + src/neighbors/ivf_flat_reconstruct_uint8_t_int64_t.cu src/neighbors/ivfpq_build_float_int64_t.cu src/neighbors/ivfpq_build_int8_t_int64_t.cu src/neighbors/ivfpq_build_uint8_t_int64_t.cu diff --git a/cpp/bench/ann/src/faiss/faiss_benchmark.cu b/cpp/bench/ann/src/faiss/faiss_benchmark.cu index 0aa4e76103..0bad86905b 100644 --- a/cpp/bench/ann/src/faiss/faiss_benchmark.cu +++ b/cpp/bench/ann/src/faiss/faiss_benchmark.cu @@ -104,10 +104,10 @@ std::unique_ptr> create_algo(const std::string& algo, // stop compiler warning; not all algorithms support multi-GPU so it may not be used (void)dev_list; - raft::bench::ann::Metric metric = parse_metric(distance); std::unique_ptr> ann; if constexpr (std::is_same_v) { + raft::bench::ann::Metric metric = parse_metric(distance); if (algo == "faiss_gpu_ivf_flat") { ann = make_algo(metric, dim, conf, dev_list); } else if (algo == "faiss_gpu_ivf_pq") { diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index 7c2fa05bfe..92c8c8ed81 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -17,11 +17,13 @@ #pragma once #include +#include #include #include #include #include #include +#include #include #include #include @@ -35,6 +37,8 @@ #include +#include + #include namespace raft::neighbors::ivf_flat::detail { @@ -416,4 +420,151 @@ inline void fill_refinement_index(raft::resources const& handle, refinement_index->veclen()); RAFT_CUDA_TRY(cudaPeekAtLastError()); } + +template +__global__ void get_data_ptr_kernel(const uint32_t* list_sizes, + const T* const* list_data_ptrs, + const IdxT* const* list_indices_ptrs, + uint32_t dim, + uint32_t veclen, + uint32_t n_list, + IdxT max_indice, + T** ptrs_to_data) +{ + const IdxT list_id = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; + if (list_id >= n_list) { return; } + const IdxT inlist_id = IdxT(blockDim.y) * IdxT(blockIdx.y) + threadIdx.y; + const uint32_t list_size = list_sizes[list_id]; + if (inlist_id >= list_size) { return; } + + auto* list_indices = list_indices_ptrs[list_id]; + IdxT id = list_indices[inlist_id]; + if (id > max_indice) { return; } + + using interleaved_group = Pow2; + auto group_offset = interleaved_group::roundDown(inlist_id); + auto ingroup_id = interleaved_group::mod(inlist_id) * veclen; + + auto* list_data = list_data_ptrs[list_id]; + const T* ptr = list_data + (group_offset * dim) + ingroup_id; + ptrs_to_data[id] = (T*)ptr; +} + +template +__global__ void reconstruct_batch_kernel(const IdxT* vector_ids, + const T** ptrs_to_data, + uint32_t dim, + uint32_t veclen, + IdxT n_rows, + T* reconstr) +{ + const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; + if (i >= n_rows) { return; } + + const IdxT vector_id = vector_ids[i]; + const T* src = ptrs_to_data[vector_id]; + if (!src) { return; } + + reconstr += i * dim; + for (uint32_t l = 0; l < dim; l += veclen) { + for (uint32_t j = 0; j < veclen; j++) { + reconstr[l + j] = src[l * kIndexGroupSize + j]; + } + } +} + +template +void reconstruct_batch(raft::resources const& handle, + const index& index, + raft::device_vector_view vector_ids, + raft::device_matrix_view vector_out) +{ + auto stream = raft::resource::get_cuda_stream(handle); + auto exec_policy = raft::resource::get_thrust_policy(handle); + + thrust::device_ptr vector_ids_ptr = + thrust::device_pointer_cast(vector_ids.data_handle()); + IdxT max_indice = + *thrust::max_element(exec_policy, vector_ids_ptr, vector_ids_ptr + vector_ids.extent(0)); + + rmm::device_uvector ptrs_to_data(max_indice + 1, stream); + utils::memzero(ptrs_to_data.data(), ptrs_to_data.size(), stream); + + thrust::device_ptr list_sizes_ptr = + thrust::device_pointer_cast(index.list_sizes().data_handle()); + uint32_t max_list_size = *thrust::max_element( + exec_policy, list_sizes_ptr, list_sizes_ptr + index.list_sizes().extent(0)); + + const dim3 block_dim1(16, 16); + const dim3 grid_dim1(raft::ceildiv(index.n_lists(), block_dim1.x), + raft::ceildiv(max_list_size, block_dim1.y)); + get_data_ptr_kernel<<>>(index.list_sizes().data_handle(), + index.data_ptrs().data_handle(), + index.inds_ptrs().data_handle(), + index.dim(), + index.veclen(), + index.n_lists(), + max_indice, + ptrs_to_data.data()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + auto n_reconstruction = vector_ids.extent(0); + const dim3 block_dim2(256); + const dim3 grid_dim2(raft::ceildiv(n_reconstruction, block_dim2.x)); + reconstruct_batch_kernel<<>>(vector_ids.data_handle(), + (const T**)ptrs_to_data.data(), + index.dim(), + index.veclen(), + n_reconstruction, + vector_out.data_handle()); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +__global__ void reconstruct_list_data_kernel(T* out_vectors, + T* in_list_data, + std::variant offset_or_indices, + IdxT len, + size_t veclen, + IdxT dim) +{ + for (IdxT ix = threadIdx.x + blockDim.x * blockIdx.x; ix < len; ix += blockDim.x) { + const IdxT src_ix = std::holds_alternative(offset_or_indices) + ? std::get(offset_or_indices) + ix + : std::get(offset_or_indices)[ix]; + + using group_align = Pow2; + const IdxT group_ix = group_align::div(src_ix); + const IdxT ingroup_ix = group_align::mod(src_ix) * veclen; + + for (IdxT l = 0; l < dim; l += veclen) { + for (IdxT j = 0; j < veclen; j++) { + out_vectors[ix * dim + l + j] = in_list_data[l * kIndexGroupSize + ingroup_ix + j]; + } + } + } +} + +/** Decode the list data; see the public interface for the api and usage. */ +template +void reconstruct_list_data(raft::resources const& handle, + const index& index, + device_matrix_view out_vectors, + uint32_t label, + uint32_t offset) +{ + auto stream = raft::resource::get_cuda_stream(handle); + + IdxT len = out_vectors.extent(0); + const dim3 block_dim(256); + const dim3 grid_dim(raft::div_rounding_up_safe(len, block_dim.x)); + reconstruct_list_data_kernel + <<>>((T*)out_vectors.data_handle(), + (T*)index.lists()[label]->data.data_handle(), + (IdxT)offset, + (IdxT)len, + (size_t)index.veclen(), + (IdxT)index.dim()); +} + } // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh index 848703c9b5..2bcb361242 100644 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -114,6 +114,13 @@ void search(raft::resources const& handle, raft::device_matrix_view neighbors, raft::device_matrix_view distances) RAFT_EXPLICIT; +template +void reconstruct_list_data(raft::resources const& handle, + const index& index, + device_matrix_view out_vectors, + uint32_t label, + uint32_t offset) RAFT_EXPLICIT; + } // namespace raft::neighbors::ivf_flat #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -204,3 +211,17 @@ instantiate_raft_neighbors_ivf_flat_search(int8_t, int64_t); instantiate_raft_neighbors_ivf_flat_search(uint8_t, int64_t); #undef instantiate_raft_neighbors_ivf_flat_search + +#define instantiate_raft_neighbors_ivf_flat_reconstruct(T, IdxT) \ + extern template void raft::neighbors::ivf_flat::reconstruct_list_data( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view out_vectors, \ + uint32_t label, \ + uint32_t offset); + +instantiate_raft_neighbors_ivf_flat_reconstruct(float, int64_t); +instantiate_raft_neighbors_ivf_flat_reconstruct(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_reconstruct(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_reconstruct diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index a18ee065bf..c28e8e1f46 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -591,6 +591,30 @@ void search(raft::resources const& handle, raft::neighbors::filtering::none_ivf_sample_filter()); } +/** + * @brief Reconstruct vectors of a given cluster + * + * + * @tparam T data element type + * @tparam IdxT type of the indices + * + * @param[in] handle + * @param[in] index ivf-flat constructed index + * @param[out] out_vectors matrix with the vectors contained in the cluster + * @param[in] label cluster index + * @param[in] offset offset + */ +template +void reconstruct_list_data(raft::resources const& handle, + const index& index, + device_matrix_view out_vectors, + uint32_t label, + uint32_t offset) +{ + return raft::neighbors::ivf_flat::detail::reconstruct_list_data( + handle, index, out_vectors, label, offset); +} + /** @} */ } // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft_runtime/neighbors/ivf_flat.hpp b/cpp/include/raft_runtime/neighbors/ivf_flat.hpp index 5b8918ec7f..7b6daad427 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_flat.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_flat.hpp @@ -46,6 +46,12 @@ namespace raft::runtime::neighbors::ivf_flat { std::optional> new_indices, \ raft::neighbors::ivf_flat::index* idx); \ \ + void reconstruct_list_data(raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index& idx, \ + device_matrix_view out_vectors, \ + uint32_t label, \ + uint32_t offset); \ + \ void serialize_file(raft::resources const& handle, \ const std::string& filename, \ const raft::neighbors::ivf_flat::index& index); \ diff --git a/cpp/src/neighbors/ivf_flat_reconstruct_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_reconstruct_float_int64_t.cu new file mode 100644 index 0000000000..d13d2b5829 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_reconstruct_float_int64_t.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_reconstruct(T, IdxT) \ + template void raft::neighbors::ivf_flat::reconstruct_list_data( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index& idx, \ + raft::device_matrix_view out_vectors, \ + uint32_t label, \ + uint32_t offset); + +instantiate_raft_neighbors_ivf_flat_reconstruct(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_reconstruct diff --git a/cpp/src/neighbors/ivf_flat_reconstruct_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_reconstruct_int8_t_int64_t.cu new file mode 100644 index 0000000000..8940c18209 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_reconstruct_int8_t_int64_t.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_reconstruct(T, IdxT) \ + template void raft::neighbors::ivf_flat::reconstruct_list_data( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index& idx, \ + raft::device_matrix_view out_vectors, \ + uint32_t label, \ + uint32_t offset); + +instantiate_raft_neighbors_ivf_flat_reconstruct(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_reconstruct diff --git a/cpp/src/neighbors/ivf_flat_reconstruct_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_reconstruct_uint8_t_int64_t.cu new file mode 100644 index 0000000000..8329fe15ff --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_reconstruct_uint8_t_int64_t.cu @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_reconstruct(T, IdxT) \ + template void raft::neighbors::ivf_flat::reconstruct_list_data( \ + raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index& idx, \ + raft::device_matrix_view out_vectors, \ + uint32_t label, \ + uint32_t offset); + +instantiate_raft_neighbors_ivf_flat_reconstruct(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_reconstruct diff --git a/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu index 7fccb95411..68cf13e704 100644 --- a/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu @@ -51,6 +51,16 @@ namespace raft::runtime::neighbors::ivf_flat { raft::neighbors::ivf_flat::index* idx) \ { \ raft::neighbors::ivf_flat::extend(handle, new_vectors, new_indices, idx); \ + } \ + \ + void reconstruct_list_data(raft::resources const& handle, \ + const raft::neighbors::ivf_flat::index& idx, \ + device_matrix_view out_vectors, \ + uint32_t label, \ + uint32_t offset) \ + { \ + raft::neighbors::ivf_flat::reconstruct_list_data( \ + handle, idx, out_vectors, label, offset); \ } RAFT_INST_BUILD_EXTEND(float, int64_t); diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index a252b26600..2a216328ca 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -19,12 +19,14 @@ #include "ann_utils.cuh" #include #include +#include #include #include #include #include +#include #include #include #include @@ -262,6 +264,73 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { 0.001, min_recall)); } + + { + raft::spatial::knn::ivf_flat::index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 0.5; + + auto vectors_data = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + auto index = ivf_flat::build(handle_, index_params, vectors_data); + + rmm::device_uvector vecs_ids(ps.num_db_vecs, stream_); + thrust::sequence(resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(vecs_ids.data()), + thrust::device_pointer_cast(vecs_ids.data() + ps.num_db_vecs)); + resource::sync_stream(handle_); + + auto vectors_ids = + raft::make_device_vector_view(vecs_ids.data(), ps.num_db_vecs); + auto vectors_out = + raft::make_device_matrix(handle_, ps.num_db_vecs, ps.dim); + ivf_flat::detail::reconstruct_batch(handle_, index, vectors_ids, vectors_out.view()); + + resource::sync_stream(handle_); + + ASSERT_TRUE(raft::devArrMatch(vectors_data.data_handle(), + vectors_out.data_handle(), + ps.num_db_vecs * ps.dim, + raft::Compare(), + stream_)); + } + + { + raft::spatial::knn::ivf_flat::index_params index_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 0.5; + + auto vectors_data = + raft::make_device_matrix_view(database.data(), ps.num_db_vecs, ps.dim); + auto index = ivf_flat::build(handle_, index_params, vectors_data); + + uint32_t cluster = 0; + uint32_t offset = 0; + uint32_t n_rows = index.lists()[cluster]->size; + if (n_rows > 0 && n_rows <= 30) { + auto vectors_out = + raft::make_device_matrix(handle_, n_rows, ps.dim); + ivf_flat::reconstruct_list_data(handle_, index, vectors_out.view(), cluster, offset); + + std::vector h_indices(n_rows); + raft::update_host( + h_indices.data(), index.lists()[cluster]->indices.data_handle(), n_rows, stream_); + for (IdxT i = 0; i < n_rows; ++i) { + IdxT idx = h_indices[i]; + ASSERT_TRUE(raft::devArrMatch(&(vectors_data.data_handle()[idx * ps.dim]), + &(vectors_out.data_handle()[i * ps.dim]), + ps.dim, + raft::Compare(), + stream_)); + } + } + } } void SetUp() override