Skip to content

Commit

Permalink
Fix hashing code to avoid many iterations.
Browse files Browse the repository at this point in the history
  • Loading branch information
danpovey committed Jun 10, 2021
1 parent 56d4aba commit 046eb7b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 19 deletions.
38 changes: 22 additions & 16 deletions k2/csrc/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ class Hash {
uint64_t key_value = src_data[i];
if (~key_value == 0) return; // equals -1.. nothing there.
uint64_t key = key_value & key_mask,
leftover_index = 1 | (key >> new_buckets_num_bitsm1);
leftover_index = 1 |
((key >> new_buckets_num_bitsm1) ^ (key & new_num_buckets_mask));
size_t cur_bucket = key & new_num_buckets_mask;
while (1) {
uint64_t assumed = ~((uint64_t)0),
Expand Down Expand Up @@ -285,7 +286,8 @@ class Hash {
uint64_t *old_value = nullptr,
uint64_t **key_value_location = nullptr) const {
uint32_t cur_bucket = static_cast<uint32_t>(key) & num_buckets_mask_,
leftover_index = 1 | (key >> buckets_num_bitsm1_);
// Probe stride: XOR-mix the high key bits with the low (bucket) bits and
// force the result odd.  This must be the same expression used by Delete()
// and the other probing code in this class; combining the halves with `&`
// instead of `^` would make Insert() walk a different bucket sequence after
// a collision, so keys stored here could not be located again.
leftover_index =
1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_));
constexpr int64_t KEY_MASK = (uint64_t(1)<<NUM_KEY_BITS) - 1,
VALUE_MASK = (uint64_t(1)<< (64 - NUM_KEY_BITS)) - 1;

Expand Down Expand Up @@ -351,7 +353,8 @@ class Hash {
constexpr int64_t KEY_MASK = (uint64_t(1) << NUM_KEY_BITS) - 1;

uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = 1 | (key >> buckets_num_bitsm1_);
leftover_index =
1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_));
while (1) {
uint64_t old_elem = data_[cur_bucket];
if (~old_elem == 0) {
Expand Down Expand Up @@ -436,7 +439,8 @@ class Hash {
__forceinline__ __host__ __device__ void Delete(uint64_t key) const {
constexpr int64_t KEY_MASK = (uint64_t(1) << NUM_KEY_BITS) - 1;
uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = 1 | (key >> buckets_num_bitsm1_);
leftover_index =
1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_));
while (1) {
uint64_t old_elem = data_[cur_bucket];
if ((old_elem & KEY_MASK) == key) {
Expand Down Expand Up @@ -505,7 +509,8 @@ class Hash {
uint64_t *old_value = nullptr,
uint64_t **key_value_location = nullptr) const {
uint32_t cur_bucket = static_cast<uint32_t>(key) & num_buckets_mask_,
leftover_index = 1 | (key >> buckets_num_bitsm1_);
leftover_index =
1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_));
const uint32_t num_key_bits = num_key_bits_;
const uint64_t key_mask = (uint64_t(1) << num_key_bits) - 1,
not_value_mask = (uint64_t(-1) << (64 - num_key_bits));
Expand Down Expand Up @@ -571,7 +576,8 @@ class Hash {
const int64_t key_mask = (uint64_t(1) << num_key_bits) - 1;

uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = 1 | (key >> buckets_num_bitsm1_);
leftover_index =
1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_));
while (1) {
uint64_t old_elem = data_[cur_bucket];
if (~old_elem == 0) {
Expand Down Expand Up @@ -647,7 +653,8 @@ class Hash {
*/
__forceinline__ __host__ __device__ void Delete(uint64_t key) const {
uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = 1 | (key >> buckets_num_bitsm1_);
leftover_index =
1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_));
const uint64_t key_mask = (uint64_t(1) << num_key_bits_) - 1;
while (1) {
uint64_t old_elem = data_[cur_bucket];
Expand Down Expand Up @@ -737,7 +744,8 @@ class Hash {
// the lowest-order `num_implicit_key_bits_` bits of the bucket index will
// not change when we fail over to the next location. Without this, our
// scheme would not work.
uint32_t leftover_index = (1 | (key >> buckets_num_bitsm1_))
uint32_t leftover_index =
(1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_)))
<< num_implicit_key_bits_;
uint64_t kept_key = key >> num_implicit_key_bits_;

Expand All @@ -746,7 +754,7 @@ class Hash {

K2_DCHECK_EQ((kept_key & ~kept_key_mask) | (value & not_value_mask), 0);

uint64_t new_elem = (value << num_kept_key_bits_) | kept_key;
uint64_t new_elem = (value << num_kept_key_bits_) | kept_key;
while (1) {
uint64_t cur_elem = data_[cur_bucket];
if ((cur_elem & kept_key_mask) == kept_key) {
Expand All @@ -769,11 +777,7 @@ class Hash {
return false; // Another thread inserted this key
}
}
// Rotate bucket index until we find a free location. This will
// eventually visit all bucket indexes before it returns to the same
// location, because leftover_index is odd (so only satisfies
// (n * leftover_index) % num_buckets == 0 for n == num_buckets).
// Note: n here is the number of times we went around the loop.
// Rotate bucket index until we find a free location.
cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_;
}
}
Expand Down Expand Up @@ -804,7 +808,8 @@ class Hash {
const int64_t kept_key_mask = (uint64_t(1) << num_kept_key_bits_) - 1;

uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = (1 | (key >> buckets_num_bitsm1_))
leftover_index =
(1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_)))
<< num_implicit_key_bits_;
uint64_t kept_key = key >> num_implicit_key_bits_;

Expand Down Expand Up @@ -887,7 +892,8 @@ class Hash {
*/
__forceinline__ __host__ __device__ void Delete(uint64_t key) const {
uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = (1 | (key >> buckets_num_bitsm1_))
leftover_index =
(1 | ((key >> buckets_num_bitsm1_) ^ (key & num_buckets_mask_)))
<< num_implicit_key_bits_;
uint64_t kept_key = key >> num_implicit_key_bits_;
const uint64_t kept_key_mask = (uint64_t(1) << num_kept_key_bits_) - 1;
Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/intersect.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ class DeviceIntersector {

// arcs_row_ids_, which always maintained as having the same size as `arcs_`,
// maps from the output arc to the corresponding ostate index that the arc
// leaves from (index into states_). Actually this may be redu
// leaves from (index into states_). Actually this may be redundant.
Array1<int32_t> arcs_row_ids_;

// The hash maps from state-pair, as:
Expand Down
3 changes: 1 addition & 2 deletions k2/csrc/log.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,7 @@ inline K2_CUDA_HOSTDEV LogLevel GetCurrentLogLevel() {
#define K2_CUDA_SAFE_CALL(...) \
do { \
__VA_ARGS__; \
if (!::k2::internal::kDisableDebug && \
k2::internal::EnableCudaDeviceSync()) \
if (k2::internal::EnableCudaDeviceSync()) \
cudaDeviceSynchronize(); \
cudaError_t e = cudaGetLastError(); \
K2_CHECK_CUDA_ERROR(e); \
Expand Down

0 comments on commit 046eb7b

Please sign in to comment.