From 09014e8aae9d0223027c52ee9b8c565fe3b63ca9 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 11 Mar 2024 20:40:32 +0000 Subject: [PATCH 1/7] allow individual row lengths for select_k in order to account for massively padded data --- .../raft/matrix/detail/select_k-ext.cuh | 8 ++-- .../raft/matrix/detail/select_k-inl.cuh | 16 +++++--- .../raft/matrix/detail/select_radix.cuh | 37 +++++++++++++++++-- .../raft/neighbors/detail/ivf_common.cuh | 9 +++-- .../neighbors/detail/ivf_flat_search-inl.cuh | 9 +++-- .../raft/neighbors/detail/ivf_pq_search.cuh | 16 +++++--- .../matrix/detail/select_k_double_int64_t.cu | 3 +- .../matrix/detail/select_k_double_uint32_t.cu | 3 +- cpp/src/matrix/detail/select_k_float_int32.cu | 3 +- .../matrix/detail/select_k_float_int64_t.cu | 3 +- .../matrix/detail/select_k_float_uint32_t.cu | 3 +- .../matrix/detail/select_k_half_int64_t.cu | 3 +- .../matrix/detail/select_k_half_uint32_t.cu | 3 +- 13 files changed, 83 insertions(+), 33 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh index 6a7847d8a0..506cbffcb9 100644 --- a/cpp/include/raft/matrix/detail/select_k-ext.cuh +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -41,8 +41,9 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) RAFT_EXPLICIT; + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) RAFT_EXPLICIT; } // namespace raft::matrix::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -58,7 +59,8 @@ void select_k(raft::resources const& handle, IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, uint32_t); instantiate_raft_matrix_detail_select_k(__half, int64_t); instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh index 8f40e6ae00..93d233152b 100644 --- a/cpp/include/raft/matrix/detail/select_k-inl.cuh +++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh @@ -229,6 +229,9 @@ void segmented_sort_by_key(raft::resources const& handle, * whether to make sure selected pairs are sorted by value * @param[in] algo * the selection algorithm to use + * @param[in] len_i + * array of size (batch_size) providing lengths for each individual row + * only radix select-k supported */ template void select_k(raft::resources const& handle, @@ -240,8 +243,9 @@ void select_k(raft::resources const& handle, T* out_val, IdxT* out_idx, bool select_min, - bool sorted = false, - SelectAlgo algo = SelectAlgo::kAuto) + bool sorted = false, + SelectAlgo algo = SelectAlgo::kAuto, + const IdxT* len_i = nullptr) { common::nvtx::range fun_scope( "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); @@ -262,9 +266,8 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - true // fused_last_filter - ); - + true, // fused_last_filter + len_i); } else { bool fused_last_filter = algo == SelectAlgo::kRadix11bits; detail::select::radix::select_k(handle, @@ -276,7 +279,8 @@ void select_k(raft::resources const& handle, out_val, out_idx, select_min, - fused_last_filter); + fused_last_filter, + len_i); } if (sorted) { auto offsets = make_device_mdarray( diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 82983b7cd2..3d256c39f9 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -450,6 +450,7 @@ RAFT_KERNEL last_filter_kernel(const T* in, T* out, IdxT* out_idx, const IdxT len, + const IdxT* len_i, const IdxT k, Counter* counters, const bool select_min) @@ -557,6 +558,7 @@ RAFT_KERNEL radix_kernel(const T* in, Counter* counters, IdxT* histograms, const IdxT len, + const IdxT* len_i, const IdxT k, const bool select_min, const int pass) @@ -598,6 +600,14 @@ RAFT_KERNEL radix_kernel(const T* in, in_buf += batch_id * buf_len; in_idx_buf += batch_id * buf_len; } + + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(len_i[batch_id], k); + if (max_len < previous_len) previous_len = len_i[batch_id]; + } + // "current_len > buf_len" means current pass will skip writing buffer if (pass == 0 || current_len > buf_len) { out_buf = nullptr; @@ -829,6 +839,7 @@ void radix_topk(const T* in, IdxT* out_idx, bool select_min, bool fused_last_filter, + const IdxT* len_i, unsigned grid_dim, int sm_cnt, rmm::cuda_stream_view stream, @@ -868,6 +879,7 @@ void radix_topk(const T* in, const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; T* chunk_out = out + offset * k; IdxT* chunk_out_idx = out_idx + offset * k; + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; const T* in_buf = nullptr; const IdxT* in_idx_buf = nullptr; @@ -905,6 +917,7 @@ void radix_topk(const T* in, counters.data(), histograms.data(), len, + chunk_len_i, k, select_min, pass); @@ -919,6 +932,7 @@ void radix_topk(const T* in, chunk_out, chunk_out_idx, len, + chunk_len_i, k, counters.data(), select_min); @@ -1007,6 +1021,7 @@ template RAFT_KERNEL radix_topk_one_block_kernel(const T* in, const IdxT* in_idx, const IdxT len, + const IdxT* len_i, const IdxT k, T* out, IdxT* out_idx, @@ -1057,6 +1072,13 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, out_idx_buf = nullptr; } + // in case we have individual len for each query defined we want to make sure + // that we only iterate valid elements. + if (len_i != nullptr) { + const IdxT max_len = max(len_i[batch_id], k); + if (max_len < previous_len) previous_len = len_i[batch_id]; + } + filter_and_histogram_for_one_block(in_buf, in_idx_buf, out_buf, @@ -1106,6 +1128,7 @@ void radix_topk_one_block(const T* in, T* out, IdxT* out_idx, bool select_min, + const IdxT* len_i, int sm_cnt, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -1121,10 +1144,12 @@ void radix_topk_one_block(const T* in, max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { - int chunk_size = std::min(max_chunk_size, batch_size - offset); + int chunk_size = std::min(max_chunk_size, batch_size - offset); + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; kernel<<>>(in + offset * len, in_idx ? (in_idx + offset * len) : nullptr, len, + chunk_len_i, k, out + offset * k, out_idx + offset * k, @@ -1188,6 +1213,8 @@ void radix_topk_one_block(const T* in, * blocks is called. The later case is preferable when leading bits of input data are almost the * same. That is, when the value range of input data is narrow. In such case, there could be a * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. + * @param len_i + * optional array of size (batch_size) providing lengths for each individual row */ template void select_k(raft::resources const& res, @@ -1199,7 +1226,8 @@ void select_k(raft::resources const& res, T* out, IdxT* out_idx, bool select_min, - bool fused_last_filter) + bool fused_last_filter, + const IdxT* len_i) { auto stream = resource::get_cuda_stream(res); auto mr = resource::get_workspace_resource(res); @@ -1223,13 +1251,13 @@ void select_k(raft::resources const& res, if (len <= BlockSize * items_per_thread) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { impl::radix_topk_one_block( - in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { impl::radix_topk(in, in_idx, @@ -1240,6 +1268,7 @@ void select_k(raft::resources const& res, out_idx, select_min, fused_last_filter, + len_i, grid_dim, sm_cnt, stream, diff --git a/cpp/include/raft/neighbors/detail/ivf_common.cuh b/cpp/include/raft/neighbors/detail/ivf_common.cuh index ef7ae7c804..4574208929 100644 --- a/cpp/include/raft/neighbors/detail/ivf_common.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_common.cuh @@ -44,13 +44,13 @@ struct dummy_block_sort_t { * in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total * number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples. */ -template +template __launch_bounds__(BlockDim) RAFT_KERNEL calc_chunk_indices_kernel(uint32_t n_probes, const uint32_t* cluster_sizes, // [n_clusters] const uint32_t* clusters_to_probe, // [n_queries, n_probes] uint32_t* chunk_indices, // [n_queries, n_probes] - uint32_t* n_samples // [n_queries] + IdxT* n_samples // [n_queries] ) { using block_scan = cub::BlockScan; @@ -75,6 +75,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; } } +template struct calc_chunk_indices { public: struct configured { @@ -86,7 +87,7 @@ struct calc_chunk_indices { inline void operator()(const uint32_t* cluster_sizes, const uint32_t* clusters_to_probe, uint32_t* chunk_indices, - uint32_t* n_samples, + IdxT* n_samples, rmm::cuda_stream_view stream) { void* args[] = // NOLINT @@ -107,7 +108,7 @@ struct calc_chunk_indices { if constexpr (BlockDim >= WarpSize * 2) { if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); } } - return {reinterpret_cast(calc_chunk_indices_kernel), + return {reinterpret_cast(calc_chunk_indices_kernel), dim3(BlockDim, 1, 1), dim3(n_queries, 1, 1), n_probes}; diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 98bdeda42f..d68fb349a7 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -70,7 +70,7 @@ void search_impl(raft::resources const& handle, // The topk index of candidate vectors from each cluster(list) rmm::device_uvector indices_tmp_dev(0, stream, search_mr); // Number of samples for each query - rmm::device_uvector num_samples(0, stream, search_mr); + rmm::device_uvector num_samples(0, stream, search_mr); // Offsets per probe for each query rmm::device_uvector chunk_index(0, stream, search_mr); @@ -184,7 +184,7 @@ void search_impl(raft::resources const& handle, num_samples.resize(n_queries, stream); chunk_index.resize(n_queries_probes, stream); - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( index.list_sizes().data_handle(), coarse_indices_dev.data(), chunk_index.data(), @@ -232,7 +232,10 @@ void search_impl(raft::resources const& handle, k, distances, neighbors, - select_min); + select_min, + false, + matrix::SelectAlgo::kAuto, + num_samples.data()); if (!manage_local_topk) { // post process distances && neighbor IDs diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index d445f909e5..0f20ab12a4 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -306,11 +306,12 @@ void ivfpq_search_worker(raft::resources const& handle, neighbors_uint32 = neighbors_uint32_buf.data(); } - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), - clusters_to_probe, - chunk_index.data(), - num_samples.data(), - stream); + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( + index.list_sizes().data_handle(), + clusters_to_probe, + chunk_index.data(), + num_samples.data(), + stream); auto coresidency = expected_probe_coresidency(index.n_lists(), n_probes, n_queries); @@ -447,7 +448,10 @@ void ivfpq_search_worker(raft::resources const& handle, topK, topk_dists.data(), neighbors_uint32, - true); + true, + false, + matrix::SelectAlgo::kAuto, + num_samples.data()); // Postprocessing ivf::detail::postprocess_distances( diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu index e32b4ef6f0..bf234aacbf 100644 --- a/cpp/src/matrix/detail/select_k_double_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(double, int64_t); diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu index 21c954ca46..7f0511a76a 100644 --- a/cpp/src/matrix/detail/select_k_double_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -29,7 +29,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(double, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_float_int32.cu b/cpp/src/matrix/detail/select_k_float_int32.cu index 7f163a0b0d..e68b1e32df 100644 --- a/cpp/src/matrix/detail/select_k_float_int32.cu +++ b/cpp/src/matrix/detail/select_k_float_int32.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, int); diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu index 87b6525356..5aa40d8c9d 100644 --- a/cpp/src/matrix/detail/select_k_float_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, int64_t); diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu index e698f811d8..9aba147edf 100644 --- a/cpp/src/matrix/detail/select_k_float_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(float, uint32_t); diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu index 0eee20b1fa..bc513e4aeb 100644 --- a/cpp/src/matrix/detail/select_k_half_int64_t.cu +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, int64_t); diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu index f4e6bae21f..e46c7d46bb 100644 --- a/cpp/src/matrix/detail/select_k_half_uint32_t.cu +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -27,7 +27,8 @@ IdxT* out_idx, \ bool select_min, \ bool sorted, \ - raft::matrix::SelectAlgo algo) + raft::matrix::SelectAlgo algo, \ + const IdxT* len_i) instantiate_raft_matrix_detail_select_k(__half, uint32_t); From 1f6f6e30c7f721c338b852169d7bbacfeda588e0 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Tue, 12 Mar 2024 10:02:15 +0000 Subject: [PATCH 2/7] IVF-flat neighbor ids forced to uint32 during compute --- .../detail/ivf_flat_interleaved_scan-ext.cuh | 4 +- .../detail/ivf_flat_interleaved_scan-inl.cuh | 17 ++-- .../neighbors/detail/ivf_flat_search-inl.cuh | 97 +++++++++++-------- .../raft/neighbors/detail/refine_device.cuh | 39 +++++++- ...at_interleaved_scan_float_float_int64_t.cu | 2 +- ...flat_interleaved_scan_half_half_int64_t.cu | 2 +- ...interleaved_scan_int8_t_int32_t_int64_t.cu | 2 +- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 2 +- 8 files changed, 104 insertions(+), 61 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh index 7c2d1d2157..140a9f17c8 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -45,7 +45,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) RAFT_EXPLICIT; @@ -70,7 +70,7 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& i const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 6fc528e26b..283f56fe76 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -700,7 +700,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t* chunk_indices, const uint32_t dim, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances) { extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; @@ -752,11 +752,9 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 uint32_t sample_offset = 0; - if constexpr (!kManageLocalTopK) { - if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } - assert(list_length == chunk_indices[probe_id] - sample_offset); - assert(sample_offset + list_length <= max_samples); - } + if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } + assert(list_length == chunk_indices[probe_id] - sample_offset); + assert(sample_offset + list_length <= max_samples); constexpr int kUnroll = WarpSize / Veclen; constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; @@ -806,8 +804,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) // Enqueue one element per thread const float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; if constexpr (kManageLocalTopK) { - const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; - queue.add(val, idx); + queue.add(val, sample_offset + vec_id); } else { if (vec_id < list_length) distances[sample_offset + vec_id] = val; } @@ -873,7 +870,7 @@ void launch_kernel(Lambda lambda, const uint32_t max_samples, const uint32_t* chunk_indices, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) @@ -1161,7 +1158,7 @@ void ivfflat_interleaved_scan(const index& index, const uint32_t* chunk_indices, const bool select_min, IvfSampleFilterT sample_filter, - IdxT* neighbors, + uint32_t* neighbors, float* distances, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index d68fb349a7..36b3c0f91b 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -67,13 +67,16 @@ void search_impl(raft::resources const& handle, // Optional structures if postprocessing is required // The topk distance value of candidate vectors from each cluster(list) rmm::device_uvector distances_tmp_dev(0, stream, search_mr); - // The topk index of candidate vectors from each cluster(list) - rmm::device_uvector indices_tmp_dev(0, stream, search_mr); // Number of samples for each query - rmm::device_uvector num_samples(0, stream, search_mr); + rmm::device_uvector num_samples(0, stream, search_mr); // Offsets per probe for each query rmm::device_uvector chunk_index(0, stream, search_mr); + // The topk index of candidate vectors from each cluster(list), local index offset + // also we might need additional storage for select_k + rmm::device_uvector indices_tmp_dev(0, stream, search_mr); + rmm::device_uvector neighbors_uint32_buf(0, stream, search_mr); + size_t float_query_size; if constexpr (std::is_integral_v) { float_query_size = n_queries * index.dim(); @@ -175,23 +178,30 @@ void search_impl(raft::resources const& handle, grid_dim_x = 1; } + num_samples.resize(n_queries, stream); + chunk_index.resize(n_queries_probes, stream); + + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( + index.list_sizes().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + num_samples.data(), + stream); + auto distances_dev_ptr = distances; - auto indices_dev_ptr = neighbors; + + uint32_t* neighbors_uint32 = nullptr; + if constexpr (sizeof(IdxT) == sizeof(uint32_t)) { + neighbors_uint32 = reinterpret_cast(neighbors); + } else { + neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), stream); + neighbors_uint32 = neighbors_uint32_buf.data(); + } + + uint32_t* indices_dev_ptr = nullptr; bool manage_local_topk = is_local_topk_feasible(k); if (!manage_local_topk || grid_dim_x > 1) { - if (!manage_local_topk) { - num_samples.resize(n_queries, stream); - chunk_index.resize(n_queries_probes, stream); - - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( - index.list_sizes().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - num_samples.data(), - stream); - } - auto target_size = std::size_t(n_queries) * (manage_local_topk ? grid_dim_x * k : max_samples); distances_tmp_dev.resize(target_size, stream); @@ -199,6 +209,8 @@ void search_impl(raft::resources const& handle, distances_dev_ptr = distances_tmp_dev.data(); indices_dev_ptr = indices_tmp_dev.data(); + } else { + indices_dev_ptr = neighbors_uint32; } ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( @@ -224,34 +236,33 @@ void search_impl(raft::resources const& handle, // Merge topk values from different blocks if (!manage_local_topk || grid_dim_x > 1) { - matrix::detail::select_k(handle, - distances_tmp_dev.data(), - indices_tmp_dev.data(), - n_queries, - manage_local_topk ? (k * grid_dim_x) : max_samples, - k, - distances, - neighbors, - select_min, - false, - matrix::SelectAlgo::kAuto, - num_samples.data()); - - if (!manage_local_topk) { - // post process distances && neighbor IDs - ivf::detail::postprocess_distances( - distances, distances, index.metric(), n_queries, k, 1.0, false, stream); - ivf::detail::postprocess_neighbors(neighbors, - neighbors, - index.inds_ptrs().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - n_queries, - n_probes, - k, - stream); - } + matrix::detail::select_k(handle, + distances_tmp_dev.data(), + indices_tmp_dev.data(), + n_queries, + manage_local_topk ? (k * grid_dim_x) : max_samples, + k, + distances, + neighbors_uint32, + select_min, + false, + matrix::SelectAlgo::kAuto, + num_samples.data()); + } + if (!manage_local_topk) { + // post process distances && neighbor IDs + ivf::detail::postprocess_distances( + distances, distances, index.metric(), n_queries, k, 1.0, false, stream); } + ivf::detail::postprocess_neighbors(neighbors, + neighbors_uint32, + index.inds_ptrs().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + n_queries, + n_probes, + k, + stream); } /** See raft::neighbors::ivf_flat::search docs */ diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh index e76e52657b..a2c4c21ef5 100644 --- a/cpp/include/raft/neighbors/detail/refine_device.cuh +++ b/cpp/include/raft/neighbors/detail/refine_device.cuh @@ -88,6 +88,30 @@ void refine_device(raft::resources const& handle, n_queries, n_candidates); uint32_t grid_dim_x = 1; + + // the neighbor ids will be computed in uint32_t as offset + rmm::device_uvector neighbors_uint32_buf(0, resource::get_cuda_stream(handle)); + // Number of samples for each query + rmm::device_uvector num_samples(n_queries, resource::get_cuda_stream(handle)); + // Offsets per probe for each query + rmm::device_uvector chunk_index(n_queries, resource::get_cuda_stream(handle)); + + ivf::detail::calc_chunk_indices::configure(1, n_queries)( + refinement_index.list_sizes().data_handle(), + fake_coarse_idx.data(), + chunk_index.data(), + num_samples.data(), + resource::get_cuda_stream(handle)); + + uint32_t* neighbors_uint32 = nullptr; + if constexpr (sizeof(idx_t) == sizeof(uint32_t)) { + neighbors_uint32 = reinterpret_cast(indices.data_handle()); + } else { + neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), + resource::get_cuda_stream(handle)); + neighbors_uint32 = neighbors_uint32_buf.data(); + } + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, typename raft::spatial::knn::detail::utils::config::value_t, @@ -100,13 +124,24 @@ void refine_device(raft::resources const& handle, 1, k, 0, - nullptr, + chunk_index.data(), raft::distance::is_min_close(metric), raft::neighbors::filtering::none_ivf_sample_filter(), - indices.data_handle(), + neighbors_uint32, distances.data_handle(), grid_dim_x, resource::get_cuda_stream(handle)); + + // postprocessing -- neighbors from position to actual id + ivf::detail::postprocess_neighbors(indices.data_handle(), + neighbors_uint32, + refinement_index.inds_ptrs().data_handle(), + fake_coarse_idx.data(), + chunk_index.data(), + n_queries, + 1, + k, + resource::get_cuda_stream(handle)); } } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index def33e493e..5ac820e0dd 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu index e96600ee02..4d847cdeb1 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu @@ -35,7 +35,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 13c9d2e283..8a0e8f0118 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 51f02343fc..7cad992e2b 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -33,7 +33,7 @@ const uint32_t* chunk_indices, \ const bool select_min, \ IvfSampleFilterT sample_filter, \ - IdxT* neighbors, \ + uint32_t* neighbors, \ float* distances, \ uint32_t& grid_dim_x, \ rmm::cuda_stream_view stream) From b3f20038410b5f73aa536f2e538183895b86b4c1 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Fri, 15 Mar 2024 21:10:11 +0000 Subject: [PATCH 3/7] bugfix & cleanup --- .../raft/matrix/detail/select_radix.cuh | 2 -- .../raft/neighbors/detail/ivf_common.cuh | 29 +++++++++---------- .../detail/ivf_flat_interleaved_scan-inl.cuh | 6 ++-- .../neighbors/detail/ivf_flat_search-inl.cuh | 11 ++++--- .../raft/neighbors/detail/ivf_pq_search.cuh | 11 ++++--- .../raft/neighbors/detail/refine_device.cuh | 2 +- 6 files changed, 27 insertions(+), 34 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 3d256c39f9..d23472a249 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -450,7 +450,6 @@ RAFT_KERNEL last_filter_kernel(const T* in, T* out, IdxT* out_idx, const IdxT len, - const IdxT* len_i, const IdxT k, Counter* counters, const bool select_min) @@ -932,7 +931,6 @@ void radix_topk(const T* in, chunk_out, chunk_out_idx, len, - chunk_len_i, k, counters.data(), select_min); diff --git a/cpp/include/raft/neighbors/detail/ivf_common.cuh b/cpp/include/raft/neighbors/detail/ivf_common.cuh index 4574208929..df0319e181 100644 --- a/cpp/include/raft/neighbors/detail/ivf_common.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_common.cuh @@ -44,13 +44,13 @@ struct dummy_block_sort_t { * in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total * number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples. */ -template +template __launch_bounds__(BlockDim) RAFT_KERNEL calc_chunk_indices_kernel(uint32_t n_probes, const uint32_t* cluster_sizes, // [n_clusters] const uint32_t* clusters_to_probe, // [n_queries, n_probes] uint32_t* chunk_indices, // [n_queries, n_probes] - IdxT* n_samples // [n_queries] + uint32_t* n_samples // [n_queries] ) { using block_scan = cub::BlockScan; @@ -75,7 +75,6 @@ __launch_bounds__(BlockDim) RAFT_KERNEL if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; } } -template struct calc_chunk_indices { public: struct configured { @@ -87,7 +86,7 @@ struct calc_chunk_indices { inline void operator()(const uint32_t* cluster_sizes, const uint32_t* clusters_to_probe, uint32_t* chunk_indices, - IdxT* n_samples, + uint32_t* n_samples, rmm::cuda_stream_view stream) { void* args[] = // NOLINT @@ -108,7 +107,7 @@ struct calc_chunk_indices { if constexpr (BlockDim >= WarpSize * 2) { if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); } } - return {reinterpret_cast(calc_chunk_indices_kernel), + return {reinterpret_cast(calc_chunk_indices_kernel), dim3(BlockDim, 1, 1), dim3(n_queries, 1, 1), n_probes}; @@ -148,11 +147,11 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT return ix_min; } -template +template __launch_bounds__(BlockDim) RAFT_KERNEL - postprocess_neighbors_kernel(IdxT1* neighbors_out, // [n_queries, topk] - const IdxT2* neighbors_in, // [n_queries, topk] - const IdxT1* const* db_indices, // [n_clusters][..] + postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk] + const uint32_t* neighbors_in, // [n_queries, topk] + const IdxT* const* db_indices, // [n_clusters][..] const uint32_t* clusters_to_probe, // [n_queries, n_probes] const uint32_t* chunk_indices, // [n_queries, n_probes] uint32_t n_queries, @@ -171,7 +170,7 @@ __launch_bounds__(BlockDim) RAFT_KERNEL const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices); const bool valid = chunk_ix < n_probes; neighbors_out[k] = - valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; + valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; } /** @@ -181,10 +180,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL * probed clusters / defined by the `chunk_indices`. * We assume the searched sample sizes (for a single query) fit into `uint32_t`. */ -template -void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, topk] - const IdxT2* neighbors_in, // [n_queries, topk] - const IdxT1* const* db_indices, // [n_clusters][..] +template +void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk] + const uint32_t* neighbors_in, // [n_queries, topk] + const IdxT* const* db_indices, // [n_clusters][..] const uint32_t* clusters_to_probe, // [n_queries, n_probes] const uint32_t* chunk_indices, // [n_queries, n_probes] uint32_t n_queries, @@ -194,7 +193,7 @@ void postprocess_neighbors(IdxT1* neighbors_out, // [n_queries, to { constexpr int kPNThreads = 256; const int pn_blocks = raft::div_rounding_up_unsafe(n_queries * topk, kPNThreads); - postprocess_neighbors_kernel + postprocess_neighbors_kernel <<>>(neighbors_out, neighbors_in, db_indices, diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 283f56fe76..4ddc708ada 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -690,7 +690,6 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t query_smem_elems, const T* query, const uint32_t* coarse_index, - const IdxT* const* list_indices_ptrs, const T* const* list_data_ptrs, const uint32_t* list_sizes, const uint32_t queries_offset, @@ -719,8 +718,8 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) distances += query_id * k * gridDim.x + blockIdx.x * k; } else { distances += query_id * uint64_t(max_samples); - chunk_indices += (n_probes * query_id); } + chunk_indices += (n_probes * query_id); coarse_index += query_id * n_probes; } @@ -728,7 +727,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); __syncthreads(); - using local_topk_t = block_sort_t; + using local_topk_t = block_sort_t; local_topk_t queue(k); { using align_warp = Pow2; @@ -924,7 +923,6 @@ void launch_kernel(Lambda lambda, query_smem_elems, queries, coarse_index, - index.inds_ptrs().data_handle(), index.data_ptrs().data_handle(), index.list_sizes().data_handle(), queries_offset + query_offset, diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 36b3c0f91b..6dcb77fb5c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -181,12 +181,11 @@ void search_impl(raft::resources const& handle, num_samples.resize(n_queries, stream); chunk_index.resize(n_queries_probes, stream); - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( - index.list_sizes().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - num_samples.data(), - stream); + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), + coarse_indices_dev.data(), + chunk_index.data(), + num_samples.data(), + stream); auto distances_dev_ptr = distances; diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 0f20ab12a4..7e81d8e28c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -306,12 +306,11 @@ void ivfpq_search_worker(raft::resources const& handle, neighbors_uint32 = neighbors_uint32_buf.data(); } - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)( - index.list_sizes().data_handle(), - clusters_to_probe, - chunk_index.data(), - num_samples.data(), - stream); + ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), + clusters_to_probe, + chunk_index.data(), + num_samples.data(), + stream); auto coresidency = expected_probe_coresidency(index.n_lists(), n_probes, n_queries); diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh index a2c4c21ef5..b3c4559b9a 100644 --- a/cpp/include/raft/neighbors/detail/refine_device.cuh +++ b/cpp/include/raft/neighbors/detail/refine_device.cuh @@ -96,7 +96,7 @@ void refine_device(raft::resources const& handle, // Offsets per probe for each query rmm::device_uvector chunk_index(n_queries, resource::get_cuda_stream(handle)); - ivf::detail::calc_chunk_indices::configure(1, n_queries)( + ivf::detail::calc_chunk_indices::configure(1, n_queries)( refinement_index.list_sizes().data_handle(), fake_coarse_idx.data(), chunk_index.data(), From 364b5d5430abc5c4b0665ea54c4e9eaabbffb783 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Mon, 18 Mar 2024 19:44:57 +0000 Subject: [PATCH 4/7] 2 more bugfixes & hardened tests --- .../raft/matrix/detail/select_radix.cuh | 4 +-- .../detail/ivf_flat_interleaved_scan-inl.cuh | 2 +- .../neighbors/detail/ivf_flat_search-inl.cuh | 2 +- .../raft/neighbors/detail/ivf_pq_search.cuh | 2 +- cpp/test/neighbors/ann_utils.cuh | 31 +++++++++++++++++-- 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index d23472a249..36a346fda3 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -604,7 +604,7 @@ RAFT_KERNEL radix_kernel(const T* in, // that we only iterate valid elements. if (len_i != nullptr) { const IdxT max_len = max(len_i[batch_id], k); - if (max_len < previous_len) previous_len = len_i[batch_id]; + if (max_len < previous_len) previous_len = max_len; } // "current_len > buf_len" means current pass will skip writing buffer @@ -1074,7 +1074,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // that we only iterate valid elements. if (len_i != nullptr) { const IdxT max_len = max(len_i[batch_id], k); - if (max_len < previous_len) previous_len = len_i[batch_id]; + if (max_len < previous_len) previous_len = max_len; } filter_and_histogram_for_one_block(in_buf, diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh index 4ddc708ada..9cd8b70148 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -940,8 +940,8 @@ void launch_kernel(Lambda lambda, distances += grid_dim_y * grid_dim_x * k; } else { distances += grid_dim_y * max_samples; - chunk_indices += grid_dim_y * n_probes; } + chunk_indices += grid_dim_y * n_probes; coarse_index += grid_dim_y * n_probes; } } diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 6dcb77fb5c..441fb76b2f 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -246,7 +246,7 @@ void search_impl(raft::resources const& handle, select_min, false, matrix::SelectAlgo::kAuto, - num_samples.data()); + manage_local_topk ? nullptr : num_samples.data()); } if (!manage_local_topk) { // post process distances && neighbor IDs diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 7e81d8e28c..4c5da38092 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -450,7 +450,7 @@ void ivfpq_search_worker(raft::resources const& handle, true, false, matrix::SelectAlgo::kAuto, - num_samples.data()); + manage_local_topk ? nullptr : num_samples.data()); // Postprocessing ivf::detail::postprocess_distances( diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index afd083d512..91017bb4e0 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -35,6 +35,7 @@ #include #include +#include namespace raft::neighbors { @@ -153,6 +154,32 @@ auto calc_recall(const std::vector& expected_idx, static_cast(match_count) / static_cast(total_count), match_count, total_count); } +/** check uniqueness of indices + */ +template +auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t cols) +{ + size_t max_count; + std::set unique_indices; + for (size_t i = 0; i < rows; ++i) { + unique_indices.clear(); + max_count = 0; + for (size_t k = 0; k < cols; ++k) { + size_t idx_k = i * cols + k; // row major assumption! + auto act_idx = actual_idx[idx_k]; + if (act_idx == std::numeric_limits::max()) { + max_count++; + } else if (unique_indices.find(act_idx) == unique_indices.end()) { + unique_indices.insert(act_idx); + } else { + return testing::AssertionFailure() + << "Duplicated index " << act_idx << " for query " << i << "! "; + } + } + } + return testing::AssertionSuccess(); +} + template auto eval_recall(const std::vector& expected_idx, const std::vector& actual_idx, @@ -176,7 +203,7 @@ auto eval_recall(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return testing::AssertionSuccess(); + return check_unique_indices(actual_idx, rows, cols); } /** Overload of calc_recall to account for distances @@ -241,7 +268,7 @@ auto eval_neighbours(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return testing::AssertionSuccess(); + return check_unique_indices(actual_idx, rows, cols); } template From 4454d87dfde08ea8f0b855658076b686fd4653ec Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 20 Mar 2024 08:18:56 +0000 Subject: [PATCH 5/7] disable consistency check for cagra filter test --- .../raft/neighbors/detail/refine_device.cuh | 15 ++++++--------- cpp/test/neighbors/ann_cagra.cuh | 4 +++- cpp/test/neighbors/ann_utils.cuh | 18 +++++++++++++----- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh index b3c4559b9a..bdc29ca121 100644 --- a/cpp/include/raft/neighbors/detail/refine_device.cuh +++ b/cpp/include/raft/neighbors/detail/refine_device.cuh @@ -91,17 +91,14 @@ void refine_device(raft::resources const& handle, // the neighbor ids will be computed in uint32_t as offset rmm::device_uvector neighbors_uint32_buf(0, resource::get_cuda_stream(handle)); - // Number of samples for each query - rmm::device_uvector num_samples(n_queries, resource::get_cuda_stream(handle)); - // Offsets per probe for each query + // Offsets per probe for each query [n_queries] as n_probes = 1 rmm::device_uvector chunk_index(n_queries, resource::get_cuda_stream(handle)); - ivf::detail::calc_chunk_indices::configure(1, n_queries)( - refinement_index.list_sizes().data_handle(), - fake_coarse_idx.data(), - chunk_index.data(), - num_samples.data(), - resource::get_cuda_stream(handle)); + // we know that each cluster has exactly n_candidates entries + thrust::fill(resource::get_thrust_policy(handle), + chunk_index.data(), + chunk_index.data() + n_queries, + uint32_t(n_candidates)); uint32_t* neighbors_uint32 = nullptr; if constexpr (sizeof(idx_t) == sizeof(uint32_t)) { diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index a111de0762..cdd9d69241 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -549,6 +549,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { EXPECT_FALSE(unacceptable_node); double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -556,7 +557,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall)); + min_recall, + false)); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 91017bb4e0..6be2ac7fc7 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -173,7 +173,7 @@ auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t unique_indices.insert(act_idx); } else { return testing::AssertionFailure() - << "Duplicated index " << act_idx << " for query " << i << "! "; + << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; } } } @@ -186,7 +186,8 @@ auto eval_recall(const std::vector& expected_idx, size_t rows, size_t cols, double eps, - double min_recall) -> testing::AssertionResult + double min_recall, + bool test_unique = true) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, rows, cols); @@ -203,7 +204,10 @@ auto eval_recall(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return check_unique_indices(actual_idx, rows, cols); + if (test_unique) + return check_unique_indices(actual_idx, rows, cols); + else + return testing::AssertionSuccess(); } /** Overload of calc_recall to account for distances @@ -251,7 +255,8 @@ auto eval_neighbours(const std::vector& expected_idx, size_t rows, size_t cols, double eps, - double min_recall) -> testing::AssertionResult + double min_recall, + bool test_unique = true) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); @@ -268,7 +273,10 @@ auto eval_neighbours(const std::vector& expected_idx, << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" << min_recall << "); eps = " << eps << ". "; } - return check_unique_indices(actual_idx, rows, cols); + if (test_unique) + return check_unique_indices(actual_idx, rows, cols); + else + return testing::AssertionSuccess(); } template From ae087349b7b5087c0ae9fe6c9baca8e34a8d4777 Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 20 Mar 2024 08:34:42 +0000 Subject: [PATCH 6/7] deactivate 2nd check --- cpp/test/neighbors/ann_cagra.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index cdd9d69241..2caaf7c01e 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -670,6 +670,7 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { } double min_recall = ps.min_recall; + // TODO(mfoerster): re-enable uniquenes test EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, @@ -677,7 +678,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall)); + min_recall), + false); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), From f14d944c26816e61e60a265e740d24c01feea0ad Mon Sep 17 00:00:00 2001 From: Malte Foerster Date: Wed, 20 Mar 2024 08:52:34 +0000 Subject: [PATCH 7/7] fix typo --- cpp/test/neighbors/ann_cagra.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 2caaf7c01e..7278f71a24 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -678,8 +678,8 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { ps.n_queries, ps.k, 0.003, - min_recall), - false); + min_recall, + false)); EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(),