Skip to content

Commit

Permalink
C API for renumbering the samples (#3724)
Browse files Browse the repository at this point in the history
Definition of C API for renumbering of sampling results.

Authors:
  - Chuck Hastings (https://github.com/ChuckHastings)
  - Seunghwa Kang (https://github.com/seunghwak)
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Seunghwa Kang (https://github.com/seunghwak)

URL: #3724
  • Loading branch information
ChuckHastings authored Jul 25, 2023
1 parent 4a1537b commit b09763d
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 30 deletions.
26 changes: 26 additions & 0 deletions cpp/include/cugraph_c/sampling_algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ typedef enum cugraph_prior_sources_behavior_t {
cugraph_error_code_t cugraph_sampling_options_create(cugraph_sampling_options_t** options,
cugraph_error_t** error);

/**
* @brief Set flag to renumber results
*
* @param options - opaque pointer to the sampling options
* @param value - Boolean value to assign to the option
*/
void cugraph_sampling_set_renumber_results(cugraph_sampling_options_t* options, bool_t value);

/**
* @brief Set flag to sample with_replacement
*
Expand Down Expand Up @@ -446,6 +454,24 @@ cugraph_type_erased_device_array_view_t* cugraph_sample_result_get_index(
cugraph_type_erased_device_array_view_t* cugraph_sample_result_get_offsets(
const cugraph_sample_result_t* result);

/**
* @brief Get the renumber map
*
* @param [in] result The result from a sampling algorithm
* @return type erased array pointing to the renumber map
*/
cugraph_type_erased_device_array_view_t* cugraph_sample_result_get_renumber_map(
const cugraph_sample_result_t* result);

/**
* @brief Get the renumber map offsets
*
* @param [in] result The result from a sampling algorithm
* @return type erased array pointing to the renumber map offsets
*/
cugraph_type_erased_device_array_view_t* cugraph_sample_result_get_renumber_map_offsets(
const cugraph_sample_result_t* result);

/**
* @brief Free a sampling result
*
Expand Down
54 changes: 53 additions & 1 deletion cpp/src/c_api/uniform_neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct cugraph_sampling_options_t {
bool_t return_hops_{FALSE};
prior_sources_behavior_t prior_sources_behavior_{prior_sources_behavior_t::DEFAULT};
bool_t dedupe_sources_{FALSE};
bool_t renumber_results_{FALSE};
};

struct cugraph_sample_result_t {
Expand All @@ -48,6 +49,8 @@ struct cugraph_sample_result_t {
cugraph_type_erased_device_array_t* hop_{nullptr};
cugraph_type_erased_device_array_t* label_{nullptr};
cugraph_type_erased_device_array_t* offsets_{nullptr};
cugraph_type_erased_device_array_t* renumber_map_{nullptr};
cugraph_type_erased_device_array_t* renumber_map_offsets_{nullptr};
};

} // namespace c_api
Expand Down Expand Up @@ -226,6 +229,22 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct
vertex_partition_lasts,
do_expensive_check_);

std::optional<rmm::device_uvector<vertex_t>> renumber_map{std::nullopt};
std::optional<rmm::device_uvector<size_t>> renumber_map_offsets{std::nullopt};

if (options_.renumber_results_) {
std::tie(src, dst, renumber_map, renumber_map_offsets) = cugraph::renumber_sampled_edgelist(
handle_,
std::move(src),
hop ? std::make_optional(raft::device_span<int32_t const>{hop->data(), hop->size()})
: std::nullopt,
std::move(dst),
std::make_optional(std::make_tuple(
raft::device_span<label_t const>{edge_label->data(), edge_label->size()},
raft::device_span<size_t const>{offsets->data(), offsets->size()})),
do_expensive_check_);
}

result_ = new cugraph::c_api::cugraph_sample_result_t{
new cugraph::c_api::cugraph_type_erased_device_array_t(src, graph_->vertex_type_),
new cugraph::c_api::cugraph_type_erased_device_array_t(dst, graph_->vertex_type_),
Expand All @@ -242,7 +261,13 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct
? new cugraph::c_api::cugraph_type_erased_device_array_t(edge_label.value(), INT32)
: nullptr,
(offsets) ? new cugraph::c_api::cugraph_type_erased_device_array_t(offsets.value(), SIZE_T)
: nullptr};
: nullptr,
(renumber_map) ? new cugraph::c_api::cugraph_type_erased_device_array_t(
renumber_map.value(), graph_->vertex_type_)
: nullptr,
(renumber_map_offsets) ? new cugraph::c_api::cugraph_type_erased_device_array_t(
renumber_map_offsets.value(), SIZE_T)
: nullptr};
}
}
};
Expand All @@ -263,6 +288,13 @@ extern "C" cugraph_error_code_t cugraph_sampling_options_create(
return CUGRAPH_SUCCESS;
}

extern "C" void cugraph_sampling_set_renumber_results(cugraph_sampling_options_t* options,
bool_t value)
{
auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_sampling_options_t*>(options);
internal_pointer->renumber_results_ = value;
}

extern "C" void cugraph_sampling_set_with_replacement(cugraph_sampling_options_t* options,
bool_t value)
{
Expand Down Expand Up @@ -386,6 +418,26 @@ extern "C" cugraph_type_erased_device_array_view_t* cugraph_sample_result_get_of
internal_pointer->offsets_->view());
}

extern "C" cugraph_type_erased_device_array_view_t* cugraph_sample_result_get_renumber_map(
const cugraph_sample_result_t* result)
{
auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_sample_result_t const*>(result);
return internal_pointer->renumber_map_ == nullptr
? NULL
: reinterpret_cast<cugraph_type_erased_device_array_view_t*>(
internal_pointer->renumber_map_->view());
}

extern "C" cugraph_type_erased_device_array_view_t* cugraph_sample_result_get_renumber_map_offsets(
const cugraph_sample_result_t* result)
{
auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_sample_result_t const*>(result);
return internal_pointer->renumber_map_ == nullptr
? NULL
: reinterpret_cast<cugraph_type_erased_device_array_view_t*>(
internal_pointer->renumber_map_offsets_->view());
}

extern "C" cugraph_error_code_t cugraph_test_uniform_neighborhood_sample_result_create(
const cugraph_resource_handle_t* handle,
const cugraph_type_erased_device_array_view_t* srcs,
Expand Down
Loading

0 comments on commit b09763d

Please sign in to comment.