Skip to content

Commit

Permalink
Merge branch 'branch-24.10' into nccl-clique-in-raft-handle
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Sep 18, 2024
2 parents 1892f0a + e60cd7d commit 3b3925a
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ repos:
# Explicitly specify the pyproject.toml at the repo root, not per-project.
args: ["--config", "pyproject.toml"]
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 7.1.1
hooks:
- id: flake8
args: ["--config=.flake8"]
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/core/bitmap.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ _RAFT_HOST_DEVICE inline bool bitmap_view<bitmap_t, index_t>::test(const index_t
}

template <typename bitmap_t, typename index_t>
_RAFT_HOST_DEVICE void bitmap_view<bitmap_t, index_t>::set(const index_t row,
const index_t col,
bool new_value) const
_RAFT_DEVICE void bitmap_view<bitmap_t, index_t>::set(const index_t row,
const index_t col,
bool new_value) const
{
set(row * cols_ + col, new_value);
}
Expand Down
24 changes: 6 additions & 18 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ _RAFT_HOST_DEVICE bool bitset_view<bitset_t, index_t>::operator[](const index_t
}

template <typename bitset_t, typename index_t>
_RAFT_HOST_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_index,
bool set_value) const
_RAFT_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_index,
bool set_value) const
{
const index_t bit_element = sample_index / bitset_element_size;
const index_t bit_index = sample_index % bitset_element_size;
Expand All @@ -60,18 +60,12 @@ _RAFT_HOST_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_
}
}

template <typename bitset_t, typename index_t>
_RAFT_HOST_DEVICE inline index_t bitset_view<bitset_t, index_t>::n_elements() const
{
return raft::ceildiv(bitset_len_, bitset_element_size);
}

template <typename bitset_t, typename index_t>
bitset<bitset_t, index_t>::bitset(const raft::resources& res,
raft::device_vector_view<const index_t, index_t> mask_index,
index_t bitset_len,
bool default_value)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
: bitset_{std::size_t(raft::div_rounding_up_safe(bitset_len, bitset_element_size)),
raft::resource::get_cuda_stream(res)},
bitset_len_{bitset_len}
{
Expand All @@ -83,26 +77,20 @@ template <typename bitset_t, typename index_t>
bitset<bitset_t, index_t>::bitset(const raft::resources& res,
index_t bitset_len,
bool default_value)
: bitset_{std::size_t(raft::ceildiv(bitset_len, bitset_element_size)),
: bitset_{std::size_t(raft::div_rounding_up_safe(bitset_len, bitset_element_size)),
raft::resource::get_cuda_stream(res)},
bitset_len_{bitset_len}
{
reset(res, default_value);
}

template <typename bitset_t, typename index_t>
index_t bitset<bitset_t, index_t>::n_elements() const
{
return raft::ceildiv(bitset_len_, bitset_element_size);
}

template <typename bitset_t, typename index_t>
void bitset<bitset_t, index_t>::resize(const raft::resources& res,
index_t new_bitset_len,
bool default_value)
{
auto old_size = raft::ceildiv(bitset_len_, bitset_element_size);
auto new_size = raft::ceildiv(new_bitset_len, bitset_element_size);
auto old_size = raft::div_rounding_up_safe(bitset_len_, bitset_element_size);
auto new_size = raft::div_rounding_up_safe(new_bitset_len, bitset_element_size);
bitset_.resize(new_size);
bitset_len_ = new_bitset_len;
if (old_size < new_size) {
Expand Down
11 changes: 9 additions & 2 deletions cpp/include/raft/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/integer_utils.hpp>

namespace raft::core {
/**
Expand Down Expand Up @@ -89,7 +90,10 @@ struct bitset_view {
/**
* @brief Get the number of elements used by the bitset representation.
*/
inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t;
inline _RAFT_HOST_DEVICE auto n_elements() const -> index_t
{
return raft::div_rounding_up_safe(bitset_len_, bitset_element_size);
}

inline auto to_mdspan() -> raft::device_vector_view<bitset_t, index_t>
{
Expand Down Expand Up @@ -173,7 +177,10 @@ struct bitset {
/**
* @brief Get the number of elements used by the bitset representation.
*/
inline auto n_elements() const -> index_t;
inline auto n_elements() const -> index_t
{
return raft::div_rounding_up_safe(bitset_len_, bitset_element_size);
}

/** @brief Get an mdspan view of the current bitset */
inline auto to_mdspan() -> raft::device_vector_view<bitset_t, index_t>
Expand Down
10 changes: 5 additions & 5 deletions cpp/include/raft/sparse/op/detail/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ struct TupleComp {
* @param vals vals array from coo matrix
* @param stream: cuda stream to use
*/
template <typename T>
void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t stream)
template <typename T, typename IdxT = int>
void coo_sort(IdxT m, IdxT n, IdxT nnz, IdxT* rows, IdxT* cols, T* vals, cudaStream_t stream)
{
auto coo_indices = thrust::make_zip_iterator(thrust::make_tuple(rows, cols));

Expand All @@ -83,10 +83,10 @@ void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t
* @param in: COO to sort by row
* @param stream: the cuda stream to use
*/
template <typename T>
void coo_sort(COO<T>* const in, cudaStream_t stream)
template <typename T, typename IdxT = int>
void coo_sort(COO<T, IdxT>* const in, cudaStream_t stream)
{
coo_sort<T>(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream);
coo_sort<T, IdxT>(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream);
}

/**
Expand Down
14 changes: 7 additions & 7 deletions cpp/include/raft/sparse/op/sort.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,8 +37,8 @@ namespace op {
* @param vals vals array from coo matrix
* @param stream: cuda stream to use
*/
template <typename T>
void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t stream)
template <typename T, typename IdxT = int>
void coo_sort(IdxT m, IdxT n, IdxT nnz, IdxT* rows, IdxT* cols, T* vals, cudaStream_t stream)
{
detail::coo_sort(m, n, nnz, rows, cols, vals, stream);
}
Expand All @@ -49,10 +49,10 @@ void coo_sort(int m, int n, int nnz, int* rows, int* cols, T* vals, cudaStream_t
* @param in: COO to sort by row
* @param stream: the cuda stream to use
*/
template <typename T>
void coo_sort(COO<T>* const in, cudaStream_t stream)
template <typename T, typename IdxT = int>
void coo_sort(COO<T, IdxT>* const in, cudaStream_t stream)
{
coo_sort<T>(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream);
coo_sort<T, IdxT>(in->n_rows, in->n_cols, in->nnz, in->rows(), in->cols(), in->vals(), stream);
}

/**
Expand All @@ -75,4 +75,4 @@ void coo_sort_by_weight(
}; // end NAMESPACE sparse
}; // end NAMESPACE raft

#endif
#endif

0 comments on commit 3b3925a

Please sign in to comment.