Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Outdated] Scaling workspace resources #2194

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
4e5d842
[Discussion] Scaling workspace resources
achirkin Feb 22, 2024
952c6b9
Avoid using cudaMemGetInfo and adjust the default workspace size
achirkin Feb 23, 2024
5bf0a76
Add the new resource to the cagra_build.cuh
achirkin Feb 23, 2024
26ae6fc
Use the memory workspaces everywhere across ANN methods
achirkin Feb 23, 2024
9ed2314
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Feb 27, 2024
b68faf2
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Feb 28, 2024
7dad403
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Mar 1, 2024
71a3530
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Mar 6, 2024
494cc6f
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Mar 6, 2024
2ec49fb
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Mar 13, 2024
e0b45c0
Fix style
achirkin Mar 13, 2024
cf7cbd3
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Mar 13, 2024
3f57a63
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Apr 4, 2024
3abcdda
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin Apr 4, 2024
d11ef67
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
cjnolet Apr 11, 2024
bf088d1
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin Apr 16, 2024
070a9b6
Merge remote-tracking branch 'rapidsai/branch-24.06' into fea-scaled-…
achirkin Apr 26, 2024
2516692
Style & naming fixes
achirkin Apr 26, 2024
600bf5c
Style & naming fixes
achirkin Apr 26, 2024
9b858f7
Make sure the default resource is not accidentally used instead of th…
achirkin Apr 26, 2024
ddad5fc
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin Apr 29, 2024
d7569aa
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
cjnolet Apr 30, 2024
2db0322
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin May 3, 2024
d475161
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin May 8, 2024
ec9469a
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin May 15, 2024
e048ea0
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin May 16, 2024
d463188
Fix an error coming from automatic merge
achirkin May 16, 2024
94893bc
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin Jun 6, 2024
5d5674c
Merge branch 'branch-24.08' into fea-scaled-workspace-resource
achirkin Jun 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,26 @@ void build_knn_graph(raft::resources const& res,
search_params->lut_dtype = CUDA_R_8U;
search_params->internal_distance_dtype = CUDA_R_32F;
}
const auto top_k = node_degree + 1;
uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f);
gpu_top_k = std::min<IdxT>(std::max(gpu_top_k, top_k), dataset.extent(0));
const auto num_queries = dataset.extent(0);
const auto max_batch_size = 1024;
const auto top_k = node_degree + 1;
uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f);
gpu_top_k = std::min<IdxT>(std::max(gpu_top_k, top_k), dataset.extent(0));
const auto num_queries = dataset.extent(0);
rmm::device_async_resource_ref workspace_mr = raft::resource::get_workspace_resource(res);

constexpr size_t kDefaultBatchSize = 1024;
constexpr size_t kMaxBatchSize = 4096; // No more perf beyond this
constexpr size_t kMinBatchSize = 128; // Too slow if smaller
// Heuristic: how much of the workspace we can spare for the queries.
// The rest is going to be used by ivf_pq::search
const auto workspace_queries_bytes = raft::resource::get_workspace_free_bytes(res) / 5;
auto max_batch_size =
std::min<size_t>(workspace_queries_bytes / sizeof(DataT) / dataset.extent(1), kMaxBatchSize);
// Heuristic: if the workspace is too small for a decent batch size, switch to use the large
// resource with a default batch size.
if (max_batch_size < kMinBatchSize) {
max_batch_size = kDefaultBatchSize;
workspace_mr = raft::resource::get_large_workspace_resource(res);
}
RAFT_LOG_DEBUG(
"IVF-PQ search node_degree: %d, top_k: %d, gpu_top_k: %d, max_batch_size:: %d, n_probes: %u",
node_degree,
Expand All @@ -106,12 +121,17 @@ void build_knn_graph(raft::resources const& res,
max_batch_size,
search_params->n_probes);

auto distances = raft::make_device_matrix<float, int64_t>(res, max_batch_size, gpu_top_k);
auto neighbors = raft::make_device_matrix<int64_t, int64_t>(res, max_batch_size, gpu_top_k);
auto refined_distances = raft::make_device_matrix<float, int64_t>(res, max_batch_size, top_k);
auto refined_neighbors = raft::make_device_matrix<int64_t, int64_t>(res, max_batch_size, top_k);
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_batch_size, gpu_top_k);
auto queries_host = raft::make_host_matrix<DataT, int64_t>(max_batch_size, dataset.extent(1));
rmm::device_async_resource_ref large_mr = raft::resource::get_large_workspace_resource(res);
auto distances = raft::make_device_mdarray<float>(
res, large_mr, raft::make_extents<int64_t>(max_batch_size, gpu_top_k));
auto neighbors = raft::make_device_mdarray<int64_t>(
res, large_mr, raft::make_extents<int64_t>(max_batch_size, gpu_top_k));
auto refined_distances = raft::make_device_mdarray<float>(
res, large_mr, raft::make_extents<int64_t>(max_batch_size, top_k));
auto refined_neighbors = raft::make_device_mdarray<int64_t>(
res, large_mr, raft::make_extents<int64_t>(max_batch_size, top_k));
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_batch_size, gpu_top_k);
auto queries_host = raft::make_host_matrix<DataT, int64_t>(max_batch_size, dataset.extent(1));
auto refined_neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_batch_size, top_k);
auto refined_distances_host = raft::make_host_matrix<float, int64_t>(max_batch_size, top_k);

Expand All @@ -120,15 +140,13 @@ void build_knn_graph(raft::resources const& res,
bool first = true;
const auto start_clock = std::chrono::system_clock::now();

rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(res);

raft::spatial::knn::detail::utils::batch_load_iterator<DataT> vec_batches(
dataset.data_handle(),
dataset.extent(0),
dataset.extent(1),
max_batch_size,
resource::get_cuda_stream(res),
device_memory);
workspace_mr);

size_t next_report_offset = 0;
size_t d_report_offset = dataset.extent(0) / 100; // Report progress in 1% steps.
Expand Down
24 changes: 16 additions & 8 deletions cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,19 @@ void sort_knn_graph(raft::resources const& res,
const uint32_t input_graph_degree = knn_graph.extent(1);
IdxT* const input_graph_ptr = knn_graph.data_handle();

auto d_input_graph = raft::make_device_matrix<IdxT, int64_t>(res, graph_size, input_graph_degree);
auto large_tmp_mr = resource::get_large_workspace_resource(res);

auto d_input_graph = raft::make_device_mdarray<IdxT>(
res, large_tmp_mr, raft::make_extents<int64_t>(graph_size, input_graph_degree));

//
// Sorting kNN graph
//
const double time_sort_start = cur_time();
RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs ");

auto d_dataset = raft::make_device_matrix<DataT, int64_t>(res, dataset_size, dataset_dim);
auto d_dataset = raft::make_device_mdarray<DataT>(
res, large_tmp_mr, raft::make_extents<int64_t>(dataset_size, dataset_dim));
raft::copy(d_dataset.data_handle(),
dataset_ptr,
dataset_size * dataset_dim,
Expand Down Expand Up @@ -325,6 +329,7 @@ void optimize(raft::resources const& res,
{
RAFT_LOG_DEBUG(
"# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1));
auto large_tmp_mr = resource::get_large_workspace_resource(res);

RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0),
"Each input array is expected to have the same number of rows");
Expand All @@ -340,15 +345,16 @@ void optimize(raft::resources const& res,
//
// Prune kNN graph
//
auto d_detour_count =
raft::make_device_matrix<uint8_t, int64_t>(res, graph_size, input_graph_degree);
auto d_detour_count = raft::make_device_mdarray<uint8_t>(
res, large_tmp_mr, raft::make_extents<int64_t>(graph_size, input_graph_degree));

RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(),
0xff,
graph_size * input_graph_degree * sizeof(uint8_t),
resource::get_cuda_stream(res)));

auto d_num_no_detour_edges = raft::make_device_vector<uint32_t, int64_t>(res, graph_size);
auto d_num_no_detour_edges = raft::make_device_mdarray<uint32_t>(
res, large_tmp_mr, raft::make_extents<int64_t>(graph_size));
RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(),
0x00,
graph_size * sizeof(uint32_t),
Expand Down Expand Up @@ -468,14 +474,16 @@ void optimize(raft::resources const& res,
graph_size * output_graph_degree * sizeof(IdxT),
resource::get_cuda_stream(res)));

auto d_rev_graph_count = raft::make_device_vector<uint32_t, int64_t>(res, graph_size);
auto d_rev_graph_count = raft::make_device_mdarray<uint32_t>(
res, large_tmp_mr, raft::make_extents<int64_t>(graph_size));
RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(),
0x00,
graph_size * sizeof(uint32_t),
resource::get_cuda_stream(res)));

auto dest_nodes = raft::make_host_vector<IdxT, int64_t>(graph_size);
auto d_dest_nodes = raft::make_device_vector<IdxT, int64_t>(res, graph_size);
auto dest_nodes = raft::make_host_vector<IdxT, int64_t>(graph_size);
auto d_dest_nodes =
raft::make_device_mdarray<IdxT>(res, large_tmp_mr, raft::make_extents<int64_t>(graph_size));

for (uint64_t k = 0; k < output_graph_degree; k++) {
#pragma omp parallel for
Expand Down
7 changes: 5 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,11 @@ class device_matrix_view_from_host {
device_ptr = reinterpret_cast<T*>(attr.devicePointer);
if (device_ptr == NULL) {
// allocate memory and copy over
device_mem_.emplace(
raft::make_device_matrix<T, IdxT>(res, host_view.extent(0), host_view.extent(1)));
// NB: We use the temporary "large" workspace resource here; this structure is supposed to
// live on stack and not returned to a user.
// The user may opt to set this resource to managed memory to allow large allocations.
device_mem_.emplace(make_device_mdarray<T, IdxT>(
res, resource::get_large_workspace_resource(res), host_view.extents()));
raft::copy(device_mem_->data_handle(),
host_view.data_handle(),
host_view.extent(0) * host_view.extent(1),
Expand Down
13 changes: 9 additions & 4 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <raft/core/nvtx.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/map.cuh>
Expand Down Expand Up @@ -180,7 +181,8 @@ void extend(raft::resources const& handle,
RAFT_EXPECTS(new_indices != nullptr || index->size() == 0,
"You must pass data indices when the index is non-empty.");

auto new_labels = raft::make_device_vector<LabelT, IdxT>(handle, n_rows);
auto new_labels = raft::make_device_mdarray<LabelT>(
handle, resource::get_large_workspace_resource(handle), raft::make_extents<IdxT>(n_rows));
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.metric = index->metric();
auto orig_centroids_view =
Expand Down Expand Up @@ -211,7 +213,8 @@ void extend(raft::resources const& handle,
}

auto* list_sizes_ptr = index->list_sizes().data_handle();
auto old_list_sizes_dev = raft::make_device_vector<uint32_t, IdxT>(handle, n_lists);
auto old_list_sizes_dev = raft::make_device_mdarray<uint32_t>(
handle, resource::get_workspace_resource(handle), raft::make_extents<IdxT>(n_lists));
copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream);

// Calculate the centers and sizes on the new data, starting from the original values
Expand Down Expand Up @@ -367,7 +370,8 @@ inline auto build(raft::resources const& handle,
auto trainset_ratio = std::max<size_t>(
1, n_rows / std::max<size_t>(params.kmeans_trainset_fraction * n_rows, index.n_lists()));
auto n_rows_train = n_rows / trainset_ratio;
rmm::device_uvector<T> trainset(n_rows_train * index.dim(), stream);
rmm::device_uvector<T> trainset(
n_rows_train * index.dim(), stream, raft::resource::get_large_workspace_resource(handle));
// TODO: a proper sampling
RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(),
sizeof(T) * index.dim(),
Expand Down Expand Up @@ -427,7 +431,8 @@ inline void fill_refinement_index(raft::resources const& handle,
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries));

rmm::device_uvector<LabelT> new_labels(n_queries * n_candidates, stream);
rmm::device_uvector<LabelT> new_labels(
n_queries * n_candidates, stream, raft::resource::get_workspace_resource(handle));
auto new_labels_view =
raft::make_device_vector_view<LabelT, IdxT>(new_labels.data(), n_queries * n_candidates);
linalg::map_offset(
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ void search(raft::resources const& handle,
uint32_t k,
IdxT* neighbors,
float* distances,
rmm::device_async_resource_ref mr,
IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT;
std::optional<rmm::device_async_resource_ref> mr = std::nullopt,
IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT;

} // namespace raft::neighbors::ivf_flat::detail

Expand All @@ -56,7 +56,7 @@ void search(raft::resources const& handle,
uint32_t k, \
IdxT* neighbors, \
float* distances, \
rmm::device_async_resource_ref mr, \
std::optional<rmm::device_async_resource_ref> mr, \
IvfSampleFilterT sample_filter)

instantiate_raft_neighbors_ivf_flat_detail_search(
Expand Down
16 changes: 10 additions & 6 deletions cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/core/logger.hpp> // RAFT_LOG_TRACE
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp> // workspace resource
#include <raft/core/resources.hpp> // raft::resources
#include <raft/distance/distance_types.hpp> // is_min_close, DistanceType
#include <raft/linalg/gemm.cuh> // raft::linalg::gemm
Expand Down Expand Up @@ -276,8 +277,8 @@ inline void search(raft::resources const& handle,
uint32_t k,
IdxT* neighbors,
float* distances,
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource(),
IvfSampleFilterT sample_filter = IvfSampleFilterT())
std::optional<rmm::device_async_resource_ref> mr = std::nullopt,
IvfSampleFilterT sample_filter = IvfSampleFilterT())
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim());
Expand All @@ -297,15 +298,18 @@ inline void search(raft::resources const& handle,
}

// a batch size heuristic: try to keep the workspace within the specified size
constexpr uint64_t kExpectedWsSize = 1024 * 1024 * 1024;
uint64_t max_ws_size = std::min(resource::get_workspace_free_bytes(handle), kExpectedWsSize);
uint64_t expected_ws_size = 1024 * 1024 * 1024ull;
if (!mr.has_value()) {
mr.emplace(resource::get_workspace_resource(handle));
expected_ws_size = resource::get_workspace_free_bytes(handle);
}

uint64_t ws_size_per_query = 4ull * (2 * n_probes + index.n_lists() + index.dim() + 1) +
(manage_local_topk ? ((sizeof(IdxT) + 4) * n_probes * k)
: (4ull * (max_samples + n_probes + 1)));

const uint32_t max_queries =
std::min<uint32_t>(n_queries, raft::div_rounding_up_safe(max_ws_size, ws_size_per_query));
std::min<uint32_t>(n_queries, raft::div_rounding_up_safe(expected_ws_size, ws_size_per_query));

for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) {
uint32_t queries_batch = min(max_queries, n_queries - offset_q);
Expand All @@ -321,7 +325,7 @@ inline void search(raft::resources const& handle,
raft::distance::is_min_close(index.metric()),
neighbors + offset_q * k,
distances + offset_q * k,
mr,
mr.value(),
sample_filter);
}
}
Expand Down
Loading
Loading