Skip to content

Commit

Permalink
Merge pull request #3 from kishimisu/failing-test-fix
Browse files Browse the repository at this point in the history
Fix shader OOB memory access
  • Loading branch information
kishimisu authored Jun 7, 2024
2 parents df8ef15 + 01cd6d0 commit 9797fcb
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 22 deletions.
2 changes: 1 addition & 1 deletion example/tests.js
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async function test_radix_sort(device, keys_and_values = false) {
isOK = isOK && valuesResult.every((v, i) => keysResult[i] == keys[v])
}

console.log('Test Radix Sort:', element_count, sub_element_count, workgroup_size, isOK ? 'OK' : 'ERROR')
console.log('Test Radix Sort:', element_count, sub_element_count, workgroup_size, check_order, local_shuffle, avoid_bank_conflicts, isOK ? 'OK' : 'ERROR')

if (!isOK) {
console.log('keys', keys)
Expand Down
6 changes: 4 additions & 2 deletions src/PrefixSumKernel.js
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ class PrefixSumKernel {
'WORKGROUP_SIZE_X': this.workgroup_size.x,
'WORKGROUP_SIZE_Y': this.workgroup_size.y,
'THREADS_PER_WORKGROUP': this.threads_per_workgroup,
'ITEMS_PER_WORKGROUP': this.items_per_workgroup
'ITEMS_PER_WORKGROUP': this.items_per_workgroup,
'ELEMENT_COUNT': count,
}
}
})
Expand All @@ -119,7 +120,8 @@ class PrefixSumKernel {
constants: {
'WORKGROUP_SIZE_X': this.workgroup_size.x,
'WORKGROUP_SIZE_Y': this.workgroup_size.y,
'THREADS_PER_WORKGROUP': this.threads_per_workgroup
'THREADS_PER_WORKGROUP': this.threads_per_workgroup,
'ELEMENT_COUNT': count,
}
}
})
Expand Down
23 changes: 19 additions & 4 deletions src/shaders/optimizations/prefix_sum_no_bank_conflict.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ override WORKGROUP_SIZE_X: u32;
override WORKGROUP_SIZE_Y: u32;
override THREADS_PER_WORKGROUP: u32;
override ITEMS_PER_WORKGROUP: u32;
override ELEMENT_COUNT: u32;
const NUM_BANKS: u32 = 32;
const LOG_NUM_BANKS: u32 = 5;
Expand Down Expand Up @@ -43,8 +44,8 @@ fn reduce_downsweep(
let s_bi = bi + get_offset(bi);
let g_ai = ai + WID * 2;
let g_bi = bi + WID * 2;
temp[s_ai] = items[g_ai];
temp[s_bi] = items[g_bi];
temp[s_ai] = select(items[g_ai], 0, g_ai >= ELEMENT_COUNT);
temp[s_bi] = select(items[g_bi], 0, g_bi >= ELEMENT_COUNT);
var offset: u32 = 1;
Expand Down Expand Up @@ -91,8 +92,12 @@ fn reduce_downsweep(
workgroupBarrier();
// Copy result from shared memory to global memory
items[g_ai] = temp[s_ai];
items[g_bi] = temp[s_bi];
if (g_ai < ELEMENT_COUNT) {
items[g_ai] = temp[s_ai];
}
if (g_bi < ELEMENT_COUNT) {
items[g_bi] = temp[s_bi];
}
}
@compute @workgroup_size(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y, 1)
Expand All @@ -106,9 +111,19 @@ fn add_block_sums(
let GID = WID + TID; // Global thread ID
let ELM_ID = GID * 2;
if (ELM_ID >= ELEMENT_COUNT) {
return;
}
let blockSum = blockSums[WORKGROUP_ID];
items[ELM_ID] += blockSum;
if (ELM_ID + 1 >= ELEMENT_COUNT) {
return;
}
items[ELM_ID + 1] += blockSum;
}`

Expand Down
17 changes: 12 additions & 5 deletions src/shaders/optimizations/radix_sort_local_shuffle.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ fn radix_sort(
let GID = WID + TID; // Global thread ID
// Extract 2 bits from the input
let elm = input[GID];
let val = values[GID];
var elm: u32 = 0;
var val: u32 = 0;
if (GID < ELEMENT_COUNT) {
elm = input[GID];
val = values[GID];
}
let extract_bits: u32 = (elm >> CURRENT_BIT) & 0x3;
var bit_prefix_sums = array<u32, 4>(0, 0, 0, 0);
Expand All @@ -58,14 +62,18 @@ fn radix_sort(
s_prefix_sum[inOffset + 1] = bitmask;
workgroupBarrier();
var prefix_sum: u32 = 0;
// Prefix sum
for (var offset: u32 = 1; offset < THREADS_PER_WORKGROUP; offset *= 2) {
if (TID >= offset) {
s_prefix_sum[outOffset] = s_prefix_sum[inOffset] + s_prefix_sum[inOffset - offset];
prefix_sum = s_prefix_sum[inOffset] + s_prefix_sum[inOffset - offset];
} else {
s_prefix_sum[outOffset] = s_prefix_sum[inOffset];
prefix_sum = s_prefix_sum[inOffset];
}
s_prefix_sum[outOffset] = prefix_sum;
// Swap buffers
outOffset = inOffset;
swapOffset = TPW - swapOffset;
Expand All @@ -75,7 +83,6 @@ fn radix_sort(
}
// Store prefix sum for current bit
let prefix_sum = s_prefix_sum[inOffset];
bit_prefix_sums[b] = prefix_sum;
if (TID == LAST_THREAD) {
Expand Down
25 changes: 21 additions & 4 deletions src/shaders/prefix_sum.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ override WORKGROUP_SIZE_X: u32;
override WORKGROUP_SIZE_Y: u32;
override THREADS_PER_WORKGROUP: u32;
override ITEMS_PER_WORKGROUP: u32;
override ELEMENT_COUNT: u32;
var<workgroup> temp: array<u32, ITEMS_PER_WORKGROUP*2>;
Expand All @@ -24,8 +25,8 @@ fn reduce_downsweep(
let ELM_GID = GID * 2; // Element pair global ID
// Load input to shared memory
temp[ELM_TID] = items[ELM_GID];
temp[ELM_TID + 1] = items[ELM_GID + 1];
temp[ELM_TID] = select(items[ELM_GID], 0, ELM_GID >= ELEMENT_COUNT);
temp[ELM_TID + 1] = select(items[ELM_GID + 1], 0, ELM_GID + 1 >= ELEMENT_COUNT);
var offset: u32 = 1;
Expand Down Expand Up @@ -67,7 +68,14 @@ fn reduce_downsweep(
workgroupBarrier();
// Copy result from shared memory to global memory
items[ELM_GID] = temp[ELM_TID];
if (ELM_GID >= ELEMENT_COUNT) {
return;
}
items[ELM_GID] = temp[ELM_TID];
if (ELM_GID + 1 >= ELEMENT_COUNT) {
return;
}
items[ELM_GID + 1] = temp[ELM_TID + 1];
}
Expand All @@ -80,12 +88,21 @@ fn add_block_sums(
let WORKGROUP_ID = w_id.x + w_id.y * w_dim.x;
let WID = WORKGROUP_ID * THREADS_PER_WORKGROUP;
let GID = WID + TID; // Global thread ID
let ELM_ID = GID * 2;
if (ELM_ID >= ELEMENT_COUNT) {
return;
}
let blockSum = blockSums[WORKGROUP_ID];
items[ELM_ID] += blockSum;
if (ELM_ID + 1 >= ELEMENT_COUNT) {
return;
}
items[ELM_ID + 1] += blockSum;
}`

Expand Down
17 changes: 11 additions & 6 deletions src/shaders/radix_sort.js
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn radix_sort(
let GID = WID + TID; // Global thread ID
// Extract 2 bits from the input
let elm = input[GID];
let elm = select(input[GID], 0, GID >= ELEMENT_COUNT);
let extract_bits: u32 = (elm >> CURRENT_BIT) & 0x3;
var bit_prefix_sums = array<u32, 4>(0, 0, 0, 0);
Expand All @@ -50,14 +50,18 @@ fn radix_sort(
s_prefix_sum[inOffset + 1] = bitmask;
workgroupBarrier();
var prefix_sum: u32 = 0;
// Prefix sum
for (var offset: u32 = 1; offset < THREADS_PER_WORKGROUP; offset *= 2) {
if (TID >= offset) {
s_prefix_sum[outOffset] = s_prefix_sum[inOffset] + s_prefix_sum[inOffset - offset];
prefix_sum = s_prefix_sum[inOffset] + s_prefix_sum[inOffset - offset];
} else {
s_prefix_sum[outOffset] = s_prefix_sum[inOffset];
prefix_sum = s_prefix_sum[inOffset];
}
s_prefix_sum[outOffset] = prefix_sum;
// Swap buffers
outOffset = inOffset;
swapOffset = TPW - swapOffset;
Expand All @@ -67,7 +71,6 @@ fn radix_sort(
}
// Store prefix sum for current bit
let prefix_sum = s_prefix_sum[inOffset];
bit_prefix_sums[b] = prefix_sum;
if (TID == LAST_THREAD) {
Expand All @@ -82,8 +85,10 @@ fn radix_sort(
inOffset = TID + swapOffset;
}
// Store local prefix sum to global memory
local_prefix_sums[GID] = bit_prefix_sums[extract_bits];
if (GID < ELEMENT_COUNT) {
// Store local prefix sum to global memory
local_prefix_sums[GID] = bit_prefix_sums[extract_bits];
}
}`

export default radixSortSource;

0 comments on commit 9797fcb

Please sign in to comment.