Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Outdated] Scaling workspace resources #2194

Closed
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
4e5d842
[Discussion] Scaling workspace resources
achirkin Feb 22, 2024
952c6b9
Avoid using cudaMemGetInfo and adjust the default workspace size
achirkin Feb 23, 2024
5bf0a76
Add the new resource to the cagra_build.cuh
achirkin Feb 23, 2024
26ae6fc
Use the memory workspaces everywhere across ANN methods
achirkin Feb 23, 2024
9ed2314
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Feb 27, 2024
b68faf2
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Feb 28, 2024
7dad403
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Mar 1, 2024
71a3530
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Mar 6, 2024
494cc6f
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Mar 6, 2024
2ec49fb
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Mar 13, 2024
e0b45c0
Fix style
achirkin Mar 13, 2024
cf7cbd3
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Mar 13, 2024
3f57a63
Merge branch 'branch-24.04' into fea-scaled-workspace-resource
achirkin Apr 4, 2024
3abcdda
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin Apr 4, 2024
d11ef67
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
cjnolet Apr 11, 2024
bf088d1
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin Apr 16, 2024
070a9b6
Merge remote-tracking branch 'rapidsai/branch-24.06' into fea-scaled-…
achirkin Apr 26, 2024
2516692
Style & naming fixes
achirkin Apr 26, 2024
600bf5c
Style & naming fixes
achirkin Apr 26, 2024
9b858f7
Make sure the default resource is not accidentally used instead of th…
achirkin Apr 26, 2024
ddad5fc
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin Apr 29, 2024
d7569aa
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
cjnolet Apr 30, 2024
2db0322
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin May 3, 2024
d475161
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin May 8, 2024
ec9469a
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin May 15, 2024
e048ea0
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin May 16, 2024
d463188
Fix an error coming from automatic merge
achirkin May 16, 2024
94893bc
Merge branch 'branch-24.06' into fea-scaled-workspace-resource
achirkin Jun 6, 2024
5d5674c
Merge branch 'branch-24.08' into fea-scaled-workspace-resource
achirkin Jun 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/failure_callback_resource_adaptor.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <memory>
Expand Down Expand Up @@ -70,13 +71,14 @@ inline auto rmm_oom_callback(std::size_t bytes, void*) -> bool
*/
class shared_raft_resources {
public:
using pool_mr_type = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
using mr_type = rmm::mr::failure_callback_resource_adaptor<pool_mr_type>;
using pool_mr_type = rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource>;
using mr_type = rmm::mr::failure_callback_resource_adaptor<pool_mr_type>;
using large_mr_type = rmm::mr::managed_memory_resource;

shared_raft_resources()
try : orig_resource_{rmm::mr::get_current_device_resource()},
pool_resource_(orig_resource_, 1024 * 1024 * 1024ull),
resource_(&pool_resource_, rmm_oom_callback, nullptr) {
resource_(&pool_resource_, rmm_oom_callback, nullptr), large_mr_() {
rmm::mr::set_current_device_resource(&resource_);
} catch (const std::exception& e) {
auto cuda_status = cudaGetLastError();
Expand All @@ -99,10 +101,16 @@ class shared_raft_resources {

~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); }

auto get_large_memory_resource() noexcept
{
return static_cast<rmm::mr::device_memory_resource*>(&large_mr_);
}

private:
rmm::mr::device_memory_resource* orig_resource_;
pool_mr_type pool_resource_;
mr_type resource_;
large_mr_type large_mr_;
};

/**
Expand All @@ -123,6 +131,12 @@ class configured_raft_resources {
explicit configured_raft_resources(const std::shared_ptr<shared_raft_resources>& shared_res)
: shared_res_{shared_res}, res_{rmm::cuda_stream_view(get_stream_from_global_pool())}
{
// set the large workspace resource to the raft handle, but without the deleter
// (this resource is managed by the shared_res).
raft::resource::set_large_workspace_resource(
res_,
std::shared_ptr<rmm::mr::device_memory_resource>(shared_res_->get_large_memory_resource(),
raft::void_op{}));
}

/** Default constructor creates all resources anew. */
Expand Down
52 changes: 50 additions & 2 deletions cpp/include/raft/core/resource/device_memory_resource.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,16 @@ namespace raft::resource {
* @{
*/

/**
 * A raft resource that holds (shared ownership of) an rmm device memory
 * resource and exposes it through the type-erased `resource` interface.
 */
class memory_resource : public resource {
achirkin marked this conversation as resolved.
Show resolved Hide resolved
public:
// Takes shared ownership of the given rmm memory resource.
explicit memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr) : mr_(mr) {}
~memory_resource() override = default;
// Returns the wrapped rmm resource as an opaque pointer; callers cast it back
// to rmm::mr::device_memory_resource* via the resource accessor.
auto get_resource() -> void* override { return mr_.get(); }

private:
std::shared_ptr<rmm::mr::device_memory_resource> mr_;
};

class limiting_memory_resource : public resource {
public:
limiting_memory_resource(std::shared_ptr<rmm::mr::device_memory_resource> mr,
Expand Down Expand Up @@ -66,6 +76,29 @@ class limiting_memory_resource : public resource {
}
};

/**
 * Factory that constructs the raft::resource holding the rmm memory resource
 * used for large temporary allocations (resource_type::LARGE_MEMORY_RESOURCE).
 *
 * If no resource is supplied, it falls back to the current rmm device
 * resource, wrapped with a no-op deleter (`void_op`) because this factory
 * does not own that global resource.
 */
class large_workspace_resource_factory : public resource_factory {
public:
// `mr` may be null; in that case the current device resource is used (not owned).
explicit large_workspace_resource_factory(
std::shared_ptr<rmm::mr::device_memory_resource> mr = {nullptr})
: mr_{mr ? mr
: std::shared_ptr<rmm::mr::device_memory_resource>{
rmm::mr::get_current_device_resource(), void_op{}}}
{
}
auto get_resource_type() -> resource_type override
{
return resource_type::LARGE_MEMORY_RESOURCE;
}
// The produced resource shares ownership of the underlying rmm resource.
auto make_resource() -> resource* override { return new memory_resource(mr_); }

private:
std::shared_ptr<rmm::mr::device_memory_resource> mr_;
};

/**
* Factory that knows how to construct a specific raft::resource to populate
* the resources instance.
Expand Down Expand Up @@ -144,7 +177,7 @@ class workspace_resource_factory : public resource_factory {
// Note, the workspace does not claim all this memory from the start, so it's still usable by
// the main resource as well.
// This limit is merely an order for algorithm internals to plan the batching accordingly.
return total_size / 2;
return total_size / 4;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The OOM errors we have seen with CAGRA were related to workspace pool grabbing all this place. What about limiting to a much smaller workspace size? (E.g. faiss has 1.5 GiB limit).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is an option, but so far I think it's not necessary. I also think it can hurt performance a little by reducing the batch size in places like ivf_pq::search or ivf_pq::extend.

With the current proposal, ann-bench executable (as a user of raft) set these resources:

  • default - pool on top of device memory
  • limited workspace - shares the same pool with default
  • large workspace - managed memory (without pooling)

Hence the dataset/user allocations do not conflict for the same memory with the workspace (as they both use the same pool). At the same time, large temporary allocations (such as the cagra graph on device) use the managed memory and free it as soon as the algorithm finishes.

}
};

Expand Down Expand Up @@ -241,6 +274,21 @@ inline void set_workspace_to_global_resource(
workspace_resource_factory::default_plain_resource(), allocation_limit, std::nullopt));
};

/**
 * Get the rmm memory resource to use for large temporary allocations.
 *
 * If it has not been set on this `resources` instance yet, lazily registers a
 * default `large_workspace_resource_factory` (which falls back to the current
 * rmm device resource).
 *
 * @param res the raft resources object.
 * @return a non-owning pointer to the large-workspace memory resource.
 */
inline auto get_large_workspace_resource(resources const& res) -> rmm::mr::device_memory_resource*
{
if (!res.has_resource_factory(resource_type::LARGE_MEMORY_RESOURCE)) {
res.add_resource_factory(std::make_shared<large_workspace_resource_factory>());
}
return res.get_resource<rmm::mr::device_memory_resource>(resource_type::LARGE_MEMORY_RESOURCE);
};

/**
 * Set the rmm memory resource to use for large temporary allocations.
 *
 * @param res the raft resources object.
 * @param mr  the memory resource to install; passing null (the default)
 *            resets to the factory's fallback, i.e. the current rmm device
 *            resource.
 */
inline void set_large_workspace_resource(resources const& res,
std::shared_ptr<rmm::mr::device_memory_resource> mr = {
nullptr})
{
res.add_resource_factory(std::make_shared<large_workspace_resource_factory>(mr));
};

/** @} */

} // namespace raft::resource
3 changes: 2 additions & 1 deletion cpp/include/raft/core/resource/resource_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ enum resource_type {
STREAM_VIEW, // view of a cuda stream or a placeholder in
// CUDA-free builds
THRUST_POLICY, // thrust execution policy
WORKSPACE_RESOURCE, // rmm device memory resource
WORKSPACE_RESOURCE, // rmm device memory resource for small temporary allocations
CUBLASLT_HANDLE, // cublasLt handle
CUSTOM, // runtime-shared default-constructible resource
LARGE_MEMORY_RESOURCE, // rmm device memory resource for somewhat large temporary allocations
achirkin marked this conversation as resolved.
Show resolved Hide resolved

LAST_KEY // reserved for the last key
};
Expand Down
42 changes: 28 additions & 14 deletions cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,22 @@ void build_knn_graph(raft::resources const& res,
search_params->lut_dtype = CUDA_R_8U;
search_params->internal_distance_dtype = CUDA_R_32F;
}
const auto top_k = node_degree + 1;
uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f);
gpu_top_k = std::min<IdxT>(std::max(gpu_top_k, top_k), dataset.extent(0));
const auto num_queries = dataset.extent(0);
const auto max_batch_size = 1024;
const auto top_k = node_degree + 1;
uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f);
gpu_top_k = std::min<IdxT>(std::max(gpu_top_k, top_k), dataset.extent(0));
const auto num_queries = dataset.extent(0);
rmm::mr::device_memory_resource* workspace_mr = raft::resource::get_workspace_resource(res);
// Heuristic: how much of the workspace we can spare for the queries.
// The rest is going to be used by ivf_pq::search
const auto workspace_queries_bytes = raft::resource::get_workspace_free_bytes(res) / 5;
auto max_batch_size =
std::min<size_t>(workspace_queries_bytes / sizeof(DataT) / dataset.extent(1), 4096);
// Heuristic: if the workspace is too small for a decent batch size, switch to use the large
// resource with a default batch size.
if (max_batch_size < 128) {
achirkin marked this conversation as resolved.
Show resolved Hide resolved
max_batch_size = 1024;
workspace_mr = raft::resource::get_large_workspace_resource(res);
}
RAFT_LOG_DEBUG(
"IVF-PQ search node_degree: %d, top_k: %d, gpu_top_k: %d, max_batch_size:: %d, n_probes: %u",
node_degree,
Expand All @@ -110,12 +121,17 @@ void build_knn_graph(raft::resources const& res,
max_batch_size,
search_params->n_probes);

auto distances = raft::make_device_matrix<float, int64_t>(res, max_batch_size, gpu_top_k);
auto neighbors = raft::make_device_matrix<int64_t, int64_t>(res, max_batch_size, gpu_top_k);
auto refined_distances = raft::make_device_matrix<float, int64_t>(res, max_batch_size, top_k);
auto refined_neighbors = raft::make_device_matrix<int64_t, int64_t>(res, max_batch_size, top_k);
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_batch_size, gpu_top_k);
auto queries_host = raft::make_host_matrix<DataT, int64_t>(max_batch_size, dataset.extent(1));
rmm::mr::device_memory_resource* large_mr = raft::resource::get_large_workspace_resource(res);
auto distances = raft::make_device_mdarray<float>(
res, large_mr, raft::make_extents<int64_t>(max_batch_size, gpu_top_k));
auto neighbors = raft::make_device_mdarray<int64_t>(
res, large_mr, raft::make_extents<int64_t>(max_batch_size, gpu_top_k));
auto refined_distances = raft::make_device_mdarray<float>(
res, large_mr, raft::make_extents<int64_t>(max_batch_size, top_k));
auto refined_neighbors = raft::make_device_mdarray<int64_t>(
res, large_mr, raft::make_extents<int64_t>(max_batch_size, top_k));
auto neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_batch_size, gpu_top_k);
auto queries_host = raft::make_host_matrix<DataT, int64_t>(max_batch_size, dataset.extent(1));
auto refined_neighbors_host = raft::make_host_matrix<int64_t, int64_t>(max_batch_size, top_k);
auto refined_distances_host = raft::make_host_matrix<float, int64_t>(max_batch_size, top_k);

Expand All @@ -124,15 +140,13 @@ void build_knn_graph(raft::resources const& res,
bool first = true;
const auto start_clock = std::chrono::system_clock::now();

rmm::mr::device_memory_resource* device_memory = raft::resource::get_workspace_resource(res);

raft::spatial::knn::detail::utils::batch_load_iterator<DataT> vec_batches(
dataset.data_handle(),
dataset.extent(0),
dataset.extent(1),
max_batch_size,
resource::get_cuda_stream(res),
device_memory);
workspace_mr);

size_t next_report_offset = 0;
size_t d_report_offset = dataset.extent(0) / 100; // Report progress in 1% steps.
Expand Down
24 changes: 16 additions & 8 deletions cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,19 @@ void sort_knn_graph(raft::resources const& res,
const uint32_t input_graph_degree = knn_graph.extent(1);
IdxT* const input_graph_ptr = knn_graph.data_handle();

auto d_input_graph = raft::make_device_matrix<IdxT, int64_t>(res, graph_size, input_graph_degree);
auto large_tmp_mr = resource::get_large_workspace_resource(res);

auto d_input_graph = raft::make_device_mdarray<IdxT>(
res, large_tmp_mr, raft::make_extents<int64_t>(graph_size, input_graph_degree));

//
// Sorting kNN graph
//
const double time_sort_start = cur_time();
RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs ");

auto d_dataset = raft::make_device_matrix<DataT, int64_t>(res, dataset_size, dataset_dim);
auto d_dataset = raft::make_device_mdarray<DataT>(
res, large_tmp_mr, raft::make_extents<int64_t>(dataset_size, dataset_dim));
raft::copy(d_dataset.data_handle(),
dataset_ptr,
dataset_size * dataset_dim,
Expand Down Expand Up @@ -325,6 +329,7 @@ void optimize(raft::resources const& res,
{
RAFT_LOG_DEBUG(
"# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1));
auto large_tmp_mr = resource::get_large_workspace_resource(res);

RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0),
"Each input array is expected to have the same number of rows");
Expand All @@ -340,15 +345,16 @@ void optimize(raft::resources const& res,
//
// Prune kNN graph
//
auto d_detour_count =
raft::make_device_matrix<uint8_t, int64_t>(res, graph_size, input_graph_degree);
auto d_detour_count = raft::make_device_mdarray<uint8_t>(
res, large_tmp_mr, raft::make_extents<int64_t>(graph_size, input_graph_degree));

RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(),
0xff,
graph_size * input_graph_degree * sizeof(uint8_t),
resource::get_cuda_stream(res)));

auto d_num_no_detour_edges = raft::make_device_vector<uint32_t, int64_t>(res, graph_size);
auto d_num_no_detour_edges = raft::make_device_mdarray<uint32_t>(
res, large_tmp_mr, raft::make_extents<int64_t>(graph_size));
RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(),
0x00,
graph_size * sizeof(uint32_t),
Expand Down Expand Up @@ -468,14 +474,16 @@ void optimize(raft::resources const& res,
graph_size * output_graph_degree * sizeof(IdxT),
resource::get_cuda_stream(res)));

auto d_rev_graph_count = raft::make_device_vector<uint32_t, int64_t>(res, graph_size);
auto d_rev_graph_count = raft::make_device_mdarray<uint32_t>(
res, large_tmp_mr, raft::make_extents<int64_t>(graph_size));
RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(),
0x00,
graph_size * sizeof(uint32_t),
resource::get_cuda_stream(res)));

auto dest_nodes = raft::make_host_vector<IdxT, int64_t>(graph_size);
auto d_dest_nodes = raft::make_device_vector<IdxT, int64_t>(res, graph_size);
auto dest_nodes = raft::make_host_vector<IdxT, int64_t>(graph_size);
auto d_dest_nodes =
raft::make_device_mdarray<IdxT>(res, large_tmp_mr, raft::make_extents<int64_t>(graph_size));

for (uint64_t k = 0; k < output_graph_degree; k++) {
#pragma omp parallel for
Expand Down
7 changes: 5 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,11 @@ class device_matrix_view_from_host {
device_ptr = reinterpret_cast<T*>(attr.devicePointer);
if (device_ptr == NULL) {
// allocate memory and copy over
device_mem_.emplace(
raft::make_device_matrix<T, IdxT>(res, host_view.extent(0), host_view.extent(1)));
// NB: We use the temporary "large" workspace resource here; this structure is supposed to
// live on stack and not returned to a user.
// The user may opt to set this resource to managed memory to allow large allocations.
device_mem_.emplace(make_device_mdarray<T, IdxT>(
res, resource::get_large_workspace_resource(res), host_view.extents()));
raft::copy(device_mem_->data_handle(),
host_view.data_handle(),
host_view.extent(0) * host_view.extent(1),
Expand Down
13 changes: 9 additions & 4 deletions cpp/include/raft/neighbors/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <raft/core/nvtx.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resources.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/map.cuh>
Expand Down Expand Up @@ -180,7 +181,8 @@ void extend(raft::resources const& handle,
RAFT_EXPECTS(new_indices != nullptr || index->size() == 0,
"You must pass data indices when the index is non-empty.");

auto new_labels = raft::make_device_vector<LabelT, IdxT>(handle, n_rows);
auto new_labels = raft::make_device_mdarray<LabelT>(
handle, resource::get_large_workspace_resource(handle), raft::make_extents<IdxT>(n_rows));
raft::cluster::kmeans_balanced_params kmeans_params;
kmeans_params.metric = index->metric();
auto orig_centroids_view =
Expand Down Expand Up @@ -211,7 +213,8 @@ void extend(raft::resources const& handle,
}

auto* list_sizes_ptr = index->list_sizes().data_handle();
auto old_list_sizes_dev = raft::make_device_vector<uint32_t, IdxT>(handle, n_lists);
auto old_list_sizes_dev = raft::make_device_mdarray<uint32_t>(
handle, resource::get_workspace_resource(handle), raft::make_extents<IdxT>(n_lists));
copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream);

// Calculate the centers and sizes on the new data, starting from the original values
Expand Down Expand Up @@ -367,7 +370,8 @@ inline auto build(raft::resources const& handle,
auto trainset_ratio = std::max<size_t>(
1, n_rows / std::max<size_t>(params.kmeans_trainset_fraction * n_rows, index.n_lists()));
auto n_rows_train = n_rows / trainset_ratio;
rmm::device_uvector<T> trainset(n_rows_train * index.dim(), stream);
rmm::device_uvector<T> trainset(
n_rows_train * index.dim(), stream, raft::resource::get_large_workspace_resource(handle));
// TODO: a proper sampling
RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(),
sizeof(T) * index.dim(),
Expand Down Expand Up @@ -427,7 +431,8 @@ inline void fill_refinement_index(raft::resources const& handle,
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries));

rmm::device_uvector<LabelT> new_labels(n_queries * n_candidates, stream);
rmm::device_uvector<LabelT> new_labels(
n_queries * n_candidates, stream, raft::resource::get_workspace_resource(handle));
auto new_labels_view =
raft::make_device_vector_view<LabelT, IdxT>(new_labels.data(), n_queries * n_candidates);
linalg::map_offset(
Expand Down
Loading
Loading