diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 40dcf68e68..59bfb92862 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -93,11 +93,26 @@ void build_knn_graph(raft::resources const& res, search_params->lut_dtype = CUDA_R_8U; search_params->internal_distance_dtype = CUDA_R_32F; } - const auto top_k = node_degree + 1; - uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f); - gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); - const auto num_queries = dataset.extent(0); - const auto max_batch_size = 1024; + const auto top_k = node_degree + 1; + uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f); + gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); + const auto num_queries = dataset.extent(0); + rmm::device_async_resource_ref workspace_mr = raft::resource::get_workspace_resource(res); + + constexpr size_t kDefaultBatchSize = 1024; + constexpr size_t kMaxBatchSize = 4096; // No more perf beyond this + constexpr size_t kMinBatchSize = 128; // Too slow if smaller + // Heuristic: how much of the workspace we can spare for the queries. + // The rest is going to be used by ivf_pq::search + const auto workspace_queries_bytes = raft::resource::get_workspace_free_bytes(res) / 5; + auto max_batch_size = + std::min(workspace_queries_bytes / sizeof(DataT) / dataset.extent(1), kMaxBatchSize); + // Heuristic: if the workspace is too small for a decent batch size, switch to use the large + // resource with a default batch size. + if (max_batch_size < kMinBatchSize) { + max_batch_size = kDefaultBatchSize; + workspace_mr = raft::resource::get_large_workspace_resource(res); + } RAFT_LOG_DEBUG( "IVF-PQ search node_degree: %d, top_k: %d, gpu_top_k: %d, max_batch_size:: %d, n_probes: %u", node_degree, @@ -106,12 +121,17 @@ void build_knn_graph(raft::resources const& res, max_batch_size, search_params->n_probes); - auto distances = raft::make_device_matrix(res, max_batch_size, gpu_top_k); - auto neighbors = raft::make_device_matrix(res, max_batch_size, gpu_top_k); - auto refined_distances = raft::make_device_matrix(res, max_batch_size, top_k); - auto refined_neighbors = raft::make_device_matrix(res, max_batch_size, top_k); - auto neighbors_host = raft::make_host_matrix(max_batch_size, gpu_top_k); - auto queries_host = raft::make_host_matrix(max_batch_size, dataset.extent(1)); + rmm::device_async_resource_ref large_mr = raft::resource::get_large_workspace_resource(res); + auto distances = raft::make_device_mdarray( + res, large_mr, raft::make_extents(max_batch_size, gpu_top_k)); + auto neighbors = raft::make_device_mdarray( + res, large_mr, raft::make_extents(max_batch_size, gpu_top_k)); + auto refined_distances = raft::make_device_mdarray( + res, large_mr, raft::make_extents(max_batch_size, top_k)); + auto refined_neighbors = raft::make_device_mdarray( + res, large_mr, raft::make_extents(max_batch_size, top_k)); + auto neighbors_host = raft::make_host_matrix(max_batch_size, gpu_top_k); + auto queries_host = raft::make_host_matrix(max_batch_size, dataset.extent(1)); auto refined_neighbors_host = raft::make_host_matrix(max_batch_size, top_k); auto refined_distances_host = raft::make_host_matrix(max_batch_size, top_k); @@ -120,15 +140,13 @@ void build_knn_graph(raft::resources const& res, bool first = true; const auto start_clock = std::chrono::system_clock::now(); - rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(res); - raft::spatial::knn::detail::utils::batch_load_iterator vec_batches( dataset.data_handle(), dataset.extent(0), dataset.extent(1), max_batch_size, resource::get_cuda_stream(res), - device_memory); + workspace_mr); size_t next_report_offset = 0; size_t d_report_offset = dataset.extent(0) / 100; // Report progress in 1% steps. diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index e4e3ea3512..f307e8b149 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -246,7 +246,10 @@ void sort_knn_graph(raft::resources const& res, const uint32_t input_graph_degree = knn_graph.extent(1); IdxT* const input_graph_ptr = knn_graph.data_handle(); - auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); + auto large_tmp_mr = resource::get_large_workspace_resource(res); + + auto d_input_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, input_graph_degree)); // // Sorting kNN graph @@ -254,7 +257,8 @@ void sort_knn_graph(raft::resources const& res, const double time_sort_start = cur_time(); RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs "); - auto d_dataset = raft::make_device_matrix(res, dataset_size, dataset_dim); + auto d_dataset = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(dataset_size, dataset_dim)); raft::copy(d_dataset.data_handle(), dataset_ptr, dataset_size * dataset_dim, @@ -325,6 +329,7 @@ void optimize(raft::resources const& res, { RAFT_LOG_DEBUG( "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); + auto large_tmp_mr = resource::get_large_workspace_resource(res); RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), "Each input array is expected to have the same number of rows"); @@ -340,15 +345,16 @@ void optimize(raft::resources const& res, // // Prune kNN graph // - auto d_detour_count = - raft::make_device_matrix(res, graph_size, input_graph_degree); + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, input_graph_degree)); RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), 0xff, graph_size * input_graph_degree * sizeof(uint8_t), resource::get_cuda_stream(res))); - auto d_num_no_detour_edges = raft::make_device_vector(res, graph_size); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size)); RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), 0x00, graph_size * sizeof(uint32_t), @@ -468,14 +474,16 @@ void optimize(raft::resources const& res, graph_size * output_graph_degree * sizeof(IdxT), resource::get_cuda_stream(res))); - auto d_rev_graph_count = raft::make_device_vector(res, graph_size); + auto d_rev_graph_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size)); RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), 0x00, graph_size * sizeof(uint32_t), resource::get_cuda_stream(res))); - auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = raft::make_device_vector(res, graph_size); + auto dest_nodes = raft::make_host_vector(graph_size); + auto d_dest_nodes = + raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); for (uint64_t k = 0; k < output_graph_degree; k++) { #pragma omp parallel for diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp index ece95a7cb7..38de11abf4 100644 --- a/cpp/include/raft/neighbors/detail/cagra/utils.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/utils.hpp @@ -184,8 +184,11 @@ class device_matrix_view_from_host { device_ptr = reinterpret_cast(attr.devicePointer); if (device_ptr == NULL) { // allocate memory and copy over - device_mem_.emplace( - raft::make_device_matrix(res, host_view.extent(0), host_view.extent(1))); + // NB: We use the temporary "large" workspace resource here; this structure is supposed to + // live on stack and not returned to a user. + // The user may opt to set this resource to managed memory to allow large allocations. + device_mem_.emplace(make_device_mdarray( + res, resource::get_large_workspace_resource(res), host_view.extents())); raft::copy(device_mem_->data_handle(), host_view.data_handle(), host_view.extent(0) * host_view.extent(1), diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh index 55184cc615..643d1e8102 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -180,7 +181,8 @@ void extend(raft::resources const& handle, RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, "You must pass data indices when the index is non-empty."); - auto new_labels = raft::make_device_vector(handle, n_rows); + auto new_labels = raft::make_device_mdarray( + handle, resource::get_large_workspace_resource(handle), raft::make_extents(n_rows)); raft::cluster::kmeans_balanced_params kmeans_params; kmeans_params.metric = index->metric(); auto orig_centroids_view = @@ -211,7 +213,8 @@ void extend(raft::resources const& handle, } auto* list_sizes_ptr = index->list_sizes().data_handle(); - auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); + auto old_list_sizes_dev = raft::make_device_mdarray( + handle, resource::get_workspace_resource(handle), raft::make_extents(n_lists)); copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); // Calculate the centers and sizes on the new data, starting from the original values @@ -367,7 +370,8 @@ inline auto build(raft::resources const& handle, auto trainset_ratio = std::max( 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); auto n_rows_train = n_rows / trainset_ratio; - rmm::device_uvector trainset(n_rows_train * index.dim(), stream); + rmm::device_uvector trainset( + n_rows_train * index.dim(), stream, raft::resource::get_large_workspace_resource(handle)); // TODO: a proper sampling RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), sizeof(T) * index.dim(), @@ -427,7 +431,8 @@ inline void fill_refinement_index(raft::resources const& handle, common::nvtx::range fun_scope( "ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries)); - rmm::device_uvector new_labels(n_queries * n_candidates, stream); + rmm::device_uvector new_labels( + n_queries * n_candidates, stream, raft::resource::get_workspace_resource(handle)); auto new_labels_view = raft::make_device_vector_view(new_labels.data(), n_queries * n_candidates); linalg::map_offset( diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index c14b0e810f..43ff5e43e5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -39,8 +39,8 @@ void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::device_async_resource_ref mr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; + std::optional mr = std::nullopt, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; } // namespace raft::neighbors::ivf_flat::detail @@ -56,7 +56,7 @@ void search(raft::resources const& handle, uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::device_async_resource_ref mr, \ + std::optional mr, \ IvfSampleFilterT sample_filter) instantiate_raft_neighbors_ivf_flat_detail_search( diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 388dd60f14..9b6513d7fd 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -18,6 +18,7 @@ #include // RAFT_LOG_TRACE #include +#include // workspace resource #include // raft::resources #include // is_min_close, DistanceType #include // raft::linalg::gemm @@ -276,8 +277,8 @@ inline void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource(), - IvfSampleFilterT sample_filter = IvfSampleFilterT()) + std::optional mr = std::nullopt, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { common::nvtx::range fun_scope( "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); @@ -297,15 +298,18 @@ inline void search(raft::resources const& handle, } // a batch size heuristic: try to keep the workspace within the specified size - constexpr uint64_t kExpectedWsSize = 1024 * 1024 * 1024; - uint64_t max_ws_size = std::min(resource::get_workspace_free_bytes(handle), kExpectedWsSize); + uint64_t expected_ws_size = 1024 * 1024 * 1024ull; + if (!mr.has_value()) { + mr.emplace(resource::get_workspace_resource(handle)); + expected_ws_size = resource::get_workspace_free_bytes(handle); + } uint64_t ws_size_per_query = 4ull * (2 * n_probes + index.n_lists() + index.dim() + 1) + (manage_local_topk ? ((sizeof(IdxT) + 4) * n_probes * k) : (4ull * (max_samples + n_probes + 1))); const uint32_t max_queries = - std::min(n_queries, raft::div_rounding_up_safe(max_ws_size, ws_size_per_query)); + std::min(n_queries, raft::div_rounding_up_safe(expected_ws_size, ws_size_per_query)); for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { uint32_t queries_batch = min(max_queries, n_queries - offset_q); @@ -321,7 +325,7 @@ inline void search(raft::resources const& handle, raft::distance::is_min_close(index.metric()), neighbors + offset_q * k, distances + offset_q * k, - mr, + mr.value(), sample_filter); } } diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index 24574642ef..1fe6dd899c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -48,6 +48,7 @@ #include #include +#include #include #include @@ -1518,6 +1519,8 @@ void extend(raft::resources const& handle, "Unsupported data type"); rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(handle); + rmm::device_async_resource_ref large_memory = + raft::resource::get_large_workspace_resource(handle); // The spec defines how the clusters look like auto spec = list_spec{ @@ -1532,12 +1535,22 @@ void extend(raft::resources const& handle, n_rows + (kIndexGroupSize - 1) * std::min(n_clusters, n_rows)); // Available device memory - size_t free_mem, total_mem; - RAFT_CUDA_TRY(cudaMemGetInfo(&free_mem, &total_mem)); - + size_t free_mem = raft::resource::get_workspace_free_bytes(handle); + + // We try to use the workspace memory by default here. + // If the workspace limit is too small, we change the resource for batch data to the + // `large_workspace_resource`, which does not have the explicit allocation limit. The user may opt + // to populate the `large_workspace_resource` memory resource with managed memory for easier + // scaling. + rmm::device_async_resource_ref labels_mr = device_memory; + rmm::device_async_resource_ref batches_mr = device_memory; + if (n_rows * (index->dim() * sizeof(T) + index->pq_dim() + sizeof(IdxT) + sizeof(uint32_t)) > + free_mem) { + labels_mr = large_memory; + } // Allocate a buffer for the new labels (classifying the new data) - rmm::device_uvector new_data_labels(n_rows, stream, device_memory); - free_mem -= sizeof(uint32_t) * n_rows; + rmm::device_uvector new_data_labels(n_rows, stream, labels_mr); + if (labels_mr == device_memory) { free_mem -= sizeof(uint32_t) * n_rows; } // Calculate the batch size for the input data if it's not accessible directly from the device constexpr size_t kReasonableMaxBatchSize = 65536; @@ -1566,13 +1579,19 @@ void extend(raft::resources const& handle, while (size_factor * max_batch_size > free_mem && max_batch_size > 128) { max_batch_size >>= 1; } - // If we're keeping the batches in device memory, update the available mem tracker. - free_mem -= size_factor * max_batch_size; + if (size_factor * max_batch_size > free_mem) { + // if that still doesn't fit, resort to the UVM + batches_mr = large_memory; + max_batch_size = kReasonableMaxBatchSize; + } else { + // If we're keeping the batches in device memory, update the available mem tracker. + free_mem -= size_factor * max_batch_size; + } } // Predict the cluster labels for the new data, in batches if necessary utils::batch_load_iterator vec_batches( - new_vectors, n_rows, index->dim(), max_batch_size, stream, device_memory); + new_vectors, n_rows, index->dim(), max_batch_size, stream, batches_mr); // Release the placeholder memory, because we don't intend to allocate any more long-living // temporary buffers before we allocate the index data. // This memory could potentially speed up UVM accesses, if any. @@ -1645,7 +1664,7 @@ void extend(raft::resources const& handle, // By this point, the index state is updated and valid except it doesn't contain the new data // Fill the extended index with the new data (possibly, in batches) utils::batch_load_iterator idx_batches( - new_indices, n_rows, 1, max_batch_size, stream, device_memory); + new_indices, n_rows, 1, max_batch_size, stream, batches_mr); for (const auto& vec_batch : vec_batches) { const auto& idx_batch = *idx_batches++; process_and_fill_codes(handle, @@ -1656,7 +1675,7 @@ void extend(raft::resources const& handle, : std::variant(IdxT(idx_batch.offset())), new_data_labels.data() + vec_batch.offset(), IdxT(vec_batch.size()), - device_memory); + batches_mr); } } @@ -1709,11 +1728,21 @@ auto build(raft::resources const& handle, size_t n_rows_train = n_rows / trainset_ratio; auto* device_memory = resource::get_workspace_resource(handle); - rmm::mr::managed_memory_resource managed_memory_upstream; + rmm::mr::managed_memory_resource managed_memory; + + // If the trainset is small enough to comfortably fit into device memory, put it there. + // Otherwise, use the managed memory. + constexpr size_t kTolerableRatio = 4; + rmm::device_async_resource_ref big_memory_resource = + resource::get_large_workspace_resource(handle); + if (sizeof(float) * n_rows_train * index.dim() * kTolerableRatio < + resource::get_workspace_free_bytes(handle)) { + big_memory_resource = device_memory; + } // Besides just sampling, we transform the input dataset into floats to make it easier // to use gemm operations from cublas. - rmm::device_uvector trainset(n_rows_train * index.dim(), stream, device_memory); + rmm::device_uvector trainset(n_rows_train * index.dim(), stream, big_memory_resource); // TODO: a proper sampling if constexpr (std::is_same_v) { RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), @@ -1782,7 +1811,7 @@ auto build(raft::resources const& handle, handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); // Trainset labels are needed for training PQ codebooks - rmm::device_uvector labels(n_rows_train, stream, device_memory); + rmm::device_uvector labels(n_rows_train, stream, big_memory_resource); auto centers_const_view = raft::make_device_matrix_view( cluster_centers, index.n_lists(), index.dim()); auto labels_view = @@ -1812,7 +1841,7 @@ auto build(raft::resources const& handle, trainset.data(), labels.data(), params.kmeans_n_iters, - &managed_memory_upstream); + &managed_memory); break; case codebook_gen::PER_CLUSTER: train_per_cluster(handle, @@ -1821,7 +1850,7 @@ auto build(raft::resources const& handle, trainset.data(), labels.data(), params.kmeans_n_iters, - &managed_memory_upstream); + &managed_memory); break; default: RAFT_FAIL("Unreachable code"); } diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh index 12ab0dc3a6..7d5e357307 100644 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -109,7 +109,7 @@ void search_with_filtering(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::device_async_resource_ref mr, + std::optional mr = std::nullopt, IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; template @@ -121,7 +121,7 @@ void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::device_async_resource_ref mr) RAFT_EXPLICIT; + std::optional mr = std::nullopt) RAFT_EXPLICIT; template void search_with_filtering(raft::resources const& handle, @@ -240,7 +240,7 @@ instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::device_async_resource_ref mr); \ + std::optional mr); \ \ extern template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index ea7cff7060..439110946f 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -462,8 +462,8 @@ void search_with_filtering(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::device_async_resource_ref mr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) + std::optional mr = std::nullopt, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { raft::neighbors::ivf_flat::detail::search( handle, params, index, queries, n_queries, k, neighbors, distances, mr, sample_filter); @@ -520,7 +520,7 @@ void search(raft::resources const& handle, uint32_t k, IdxT* neighbors, float* distances, - rmm::device_async_resource_ref mr) + std::optional mr = std::nullopt) { raft::neighbors::ivf_flat::detail::search(handle, params, diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu index 336bea19b6..054042cee9 100644 --- a/cpp/src/neighbors/detail/ivf_flat_search.cu +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -29,7 +29,7 @@ uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::device_async_resource_ref mr, \ + std::optional mr, \ IvfSampleFilterT sample_filter) instantiate_raft_neighbors_ivf_flat_detail_search( diff --git a/cpp/src/neighbors/ivf_flat_00_generate.py b/cpp/src/neighbors/ivf_flat_00_generate.py index 7b55cad4de..8abd72d737 100644 --- a/cpp/src/neighbors/ivf_flat_00_generate.py +++ b/cpp/src/neighbors/ivf_flat_00_generate.py @@ -114,7 +114,7 @@ template auto raft::neighbors::ivf_flat::extend( \\ const raft::resources& handle, \\ raft::host_matrix_view new_vectors, \\ - std::optional> new_indices, \\ + std::optional> new_indices, \\ const raft::neighbors::ivf_flat::index& idx) \\ -> raft::neighbors::ivf_flat::index; \\ \\ @@ -122,7 +122,7 @@ raft::resources const& handle, \\ raft::host_matrix_view new_vectors, \\ std::optional> new_indices, \\ - raft::neighbors::ivf_flat::index* index); + raft::neighbors::ivf_flat::index* index); """ search_macro = """ @@ -136,7 +136,7 @@ uint32_t k, \\ IdxT* neighbors, \\ float* distances, \\ - rmm::device_async_resource_ref mr); \\ + std::optional mr); \\ \\ template void raft::neighbors::ivf_flat::search( \\ raft::resources const& handle, \\ diff --git a/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu index e5cfe14e3f..b435583bae 100644 --- a/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu @@ -25,8 +25,6 @@ #include -#include - #define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ @@ -37,7 +35,7 @@ uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::device_async_resource_ref mr); \ + std::optional mr); \ \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ diff --git a/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu index 35792a78a8..fb2d1bcb43 100644 --- a/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu @@ -25,8 +25,6 @@ #include -#include - #define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ @@ -37,7 +35,7 @@ uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::device_async_resource_ref mr); \ + std::optional mr); \ \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ diff --git a/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu index 663e52cb99..0a0a148f84 100644 --- a/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu @@ -25,8 +25,6 @@ #include -#include - #define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \ @@ -37,7 +35,7 @@ uint32_t k, \ IdxT* neighbors, \ float* distances, \ - rmm::device_async_resource_ref mr); \ + std::optional mr); \ \ template void raft::neighbors::ivf_flat::search( \ raft::resources const& handle, \