Skip to content

Commit

Permalink
Simplify code..
Browse files Browse the repository at this point in the history
  • Loading branch information
danpovey committed Jun 10, 2021
1 parent ec49583 commit a0eb91d
Showing 1 changed file with 28 additions and 29 deletions.
57 changes: 28 additions & 29 deletions k2/csrc/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,11 @@ unsigned long long int __forceinline__ __host__ __device__ AtomicCAS(
- The number of buckets is a power of 2 provided by the user to the constructor;
currently no resizing is supported.
- When accessing hash[key], we use bucket_index == key % num_buckets,
leftover_index = 1 | ((key * 2) / num_buckets). This is the
leftover part of the index times 2, plus 1.
bucket_inc = 1 | (((key * 2) / num_buckets) ^ key).
- If the bucket at `bucket_index` is occupied, we look in locations
`(bucket_index + n * leftover_index)%num_buckets` for n = 1, 2, ...;
`(bucket_index + n * bucket_inc)%num_buckets` for n = 1, 2, ...;
this choice ensures that if multiple keys hash to the same bucket,
they don't all access the same sequence of locations; and leftover_index
they don't all access the same sequence of locations; and bucket_inc
being odd ensures we eventually try all locations (of course for
reasonable hash occupancy levels, we shouldn't ever have to try
more than two or three).
Expand Down Expand Up @@ -157,14 +156,14 @@ class Hash {
uint64_t key_value = src_data[i];
if (~key_value == 0) return; // equals -1.. nothing there.
uint64_t key = key_value & key_mask,
leftover_index = 1 | ((key >> new_buckets_num_bitsm1) ^ key);
bucket_inc = 1 | ((key >> new_buckets_num_bitsm1) ^ key);
size_t cur_bucket = key & new_num_buckets_mask;
while (1) {
uint64_t assumed = ~((uint64_t)0),
old_elem = AtomicCAS((unsigned long long*)(data + cur_bucket),
assumed, key_value);
if (old_elem == assumed) return;
cur_bucket = (cur_bucket + leftover_index) & new_num_buckets_mask;
cur_bucket = (cur_bucket + bucket_inc) & new_num_buckets_mask;
// Keep iterating until we find a free spot in the new hash...
}
});
Expand Down Expand Up @@ -285,7 +284,7 @@ class Hash {
uint64_t *old_value = nullptr,
uint64_t **key_value_location = nullptr) const {
uint32_t cur_bucket = static_cast<uint32_t>(key) & num_buckets_mask_,
leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key);
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
constexpr int64_t KEY_MASK = (uint64_t(1)<<NUM_KEY_BITS) - 1,
VALUE_MASK = (uint64_t(1)<< (64 - NUM_KEY_BITS)) - 1;

Expand Down Expand Up @@ -316,10 +315,10 @@ class Hash {
}
// Rotate bucket index until we find a free location. This will
// eventually visit all bucket indexes before it returns to the same
// location, because leftover_index is odd (so only satisfies
// (n * leftover_index) % num_buckets == 0 for n == num_buckets).
// location, because bucket_inc is odd (so only satisfies
// (n * bucket_inc) % num_buckets == 0 for n == num_buckets).
// Note: n here is the number of times we went around the loop.
cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_;
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
}
}

Expand Down Expand Up @@ -351,7 +350,7 @@ class Hash {
constexpr int64_t KEY_MASK = (uint64_t(1) << NUM_KEY_BITS) - 1;

uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key);
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
while (1) {
uint64_t old_elem = data_[cur_bucket];
if (~old_elem == 0) {
Expand All @@ -362,7 +361,7 @@ class Hash {
*key_value_location = data_ + cur_bucket;
return true;
} else {
cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_;
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
}
}
}
Expand Down Expand Up @@ -436,14 +435,14 @@ class Hash {
__forceinline__ __host__ __device__ void Delete(uint64_t key) const {
constexpr int64_t KEY_MASK = (uint64_t(1) << NUM_KEY_BITS) - 1;
uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key);
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
while (1) {
uint64_t old_elem = data_[cur_bucket];
if ((old_elem & KEY_MASK) == key) {
data_[cur_bucket] = ~((uint64_t)0);
return;
} else {
cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_;
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
}
}
}
Expand Down Expand Up @@ -505,7 +504,7 @@ class Hash {
uint64_t *old_value = nullptr,
uint64_t **key_value_location = nullptr) const {
uint32_t cur_bucket = static_cast<uint32_t>(key) & num_buckets_mask_,
leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key);
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
const uint32_t num_key_bits = num_key_bits_;
const uint64_t key_mask = (uint64_t(1) << num_key_bits) - 1,
not_value_mask = (uint64_t(-1) << (64 - num_key_bits));
Expand Down Expand Up @@ -537,10 +536,10 @@ class Hash {
}
// Rotate bucket index until we find a free location. This will
// eventually visit all bucket indexes before it returns to the same
// location, because leftover_index is odd (so only satisfies
// (n * leftover_index) % num_buckets == 0 for n == num_buckets).
// location, because bucket_inc is odd (so only satisfies
// (n * bucket_inc) % num_buckets == 0 for n == num_buckets).
// Note: n here is the number of times we went around the loop.
cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_;
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
}
}

Expand Down Expand Up @@ -571,7 +570,7 @@ class Hash {
const int64_t key_mask = (uint64_t(1) << num_key_bits) - 1;

uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key);
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
while (1) {
uint64_t old_elem = data_[cur_bucket];
if (~old_elem == 0) {
Expand All @@ -582,7 +581,7 @@ class Hash {
*key_value_location = data_ + cur_bucket;
return true;
} else {
cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_;
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
}
}
}
Expand Down Expand Up @@ -647,15 +646,15 @@ class Hash {
*/
__forceinline__ __host__ __device__ void Delete(uint64_t key) const {
uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = 1 | ((key >> buckets_num_bitsm1_) ^ key);
bucket_inc = 1 | ((key >> buckets_num_bitsm1_) ^ key);
const uint64_t key_mask = (uint64_t(1) << num_key_bits_) - 1;
while (1) {
uint64_t old_elem = data_[cur_bucket];
if ((old_elem & key_mask) == key) {
data_[cur_bucket] = ~((uint64_t)0);
return;
} else {
cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_;
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
}
}
}
Expand Down Expand Up @@ -733,11 +732,11 @@ class Hash {
uint64_t *old_value = nullptr,
uint64_t **key_value_location = nullptr) const {
uint32_t cur_bucket = static_cast<uint32_t>(key) & num_buckets_mask_;
// Shifting `leftover_index` right by num_implicit_key_bits_ ensures that
// Shifting `bucket_inc` right by num_implicit_key_bits_ ensures that
// the lowest-order `num_implicit_key_bits_` bits of the bucket index will
// not change when we fail over to the next location. Without this, our
// scheme would not work.
uint32_t leftover_index = (1 | ((key >> buckets_num_bitsm1_) ^ key))
uint32_t bucket_inc = (1 | ((key >> buckets_num_bitsm1_) ^ key))
<< num_implicit_key_bits_;
uint64_t kept_key = key >> num_implicit_key_bits_;

Expand Down Expand Up @@ -770,7 +769,7 @@ class Hash {
}
}
// Rotate bucket index until we find a free location.
cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_;
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
}
}

Expand Down Expand Up @@ -800,7 +799,7 @@ class Hash {
const int64_t kept_key_mask = (uint64_t(1) << num_kept_key_bits_) - 1;

uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = (1 | ((key >> buckets_num_bitsm1_) ^ key))
bucket_inc = (1 | ((key >> buckets_num_bitsm1_) ^ key))
<< num_implicit_key_bits_;
uint64_t kept_key = key >> num_implicit_key_bits_;

Expand All @@ -814,7 +813,7 @@ class Hash {
*key_value_location = data_ + cur_bucket;
return true;
} else {
cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_;
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
}
}
}
Expand Down Expand Up @@ -883,7 +882,7 @@ class Hash {
*/
__forceinline__ __host__ __device__ void Delete(uint64_t key) const {
uint32_t cur_bucket = key & num_buckets_mask_,
leftover_index = (1 | ((key >> buckets_num_bitsm1_) ^ key))
bucket_inc = (1 | ((key >> buckets_num_bitsm1_) ^ key))
<< num_implicit_key_bits_;
uint64_t kept_key = key >> num_implicit_key_bits_;
const uint64_t kept_key_mask = (uint64_t(1) << num_kept_key_bits_) - 1;
Expand All @@ -893,7 +892,7 @@ class Hash {
data_[cur_bucket] = ~((uint64_t)0);
return;
} else {
cur_bucket = (cur_bucket + leftover_index) & num_buckets_mask_;
cur_bucket = (cur_bucket + bucket_inc) & num_buckets_mask_;
}
}
}
Expand Down

0 comments on commit a0eb91d

Please sign in to comment.