From 01cd6d09378b9c31d902324c7084a8185f9fc312 Mon Sep 17 00:00:00 2001 From: Kishimisu Date: Fri, 7 Jun 2024 03:54:39 +0200 Subject: [PATCH] Prevent OOB memory access in shaders --- example/tests.js | 2 +- src/PrefixSumKernel.js | 6 +++-- .../prefix_sum_no_bank_conflict.js | 23 ++++++++++++++--- .../optimizations/radix_sort_local_shuffle.js | 17 +++++++++---- src/shaders/prefix_sum.js | 25 ++++++++++++++++--- src/shaders/radix_sort.js | 17 ++++++++----- 6 files changed, 68 insertions(+), 22 deletions(-) diff --git a/example/tests.js b/example/tests.js index 4ab1f7f..cedc6a3 100644 --- a/example/tests.js +++ b/example/tests.js @@ -93,7 +93,7 @@ async function test_radix_sort(device, keys_and_values = false) { isOK = isOK && valuesResult.every((v, i) => keysResult[i] == keys[v]) } - console.log('Test Radix Sort:', element_count, sub_element_count, workgroup_size, isOK ? 'OK' : 'ERROR') + console.log('Test Radix Sort:', element_count, sub_element_count, workgroup_size, check_order, local_shuffle, avoid_bank_conflicts, isOK ? 'OK' : 'ERROR') if (!isOK) { console.log('keys', keys) diff --git a/src/PrefixSumKernel.js b/src/PrefixSumKernel.js index 5423d68..7f223bf 100644 --- a/src/PrefixSumKernel.js +++ b/src/PrefixSumKernel.js @@ -98,7 +98,8 @@ class PrefixSumKernel { 'WORKGROUP_SIZE_X': this.workgroup_size.x, 'WORKGROUP_SIZE_Y': this.workgroup_size.y, 'THREADS_PER_WORKGROUP': this.threads_per_workgroup, - 'ITEMS_PER_WORKGROUP': this.items_per_workgroup + 'ITEMS_PER_WORKGROUP': this.items_per_workgroup, + 'ELEMENT_COUNT': count, } } }) @@ -119,7 +120,8 @@ class PrefixSumKernel { constants: { 'WORKGROUP_SIZE_X': this.workgroup_size.x, 'WORKGROUP_SIZE_Y': this.workgroup_size.y, - 'THREADS_PER_WORKGROUP': this.threads_per_workgroup + 'THREADS_PER_WORKGROUP': this.threads_per_workgroup, + 'ELEMENT_COUNT': count, } } }) diff --git a/src/shaders/optimizations/prefix_sum_no_bank_conflict.js b/src/shaders/optimizations/prefix_sum_no_bank_conflict.js index fae00bf..0e3b225 100644 --- a/src/shaders/optimizations/prefix_sum_no_bank_conflict.js +++ b/src/shaders/optimizations/prefix_sum_no_bank_conflict.js @@ -12,6 +12,7 @@ override WORKGROUP_SIZE_X: u32; override WORKGROUP_SIZE_Y: u32; override THREADS_PER_WORKGROUP: u32; override ITEMS_PER_WORKGROUP: u32; +override ELEMENT_COUNT: u32; const NUM_BANKS: u32 = 32; const LOG_NUM_BANKS: u32 = 5; @@ -43,8 +44,8 @@ fn reduce_downsweep( let s_bi = bi + get_offset(bi); let g_ai = ai + WID * 2; let g_bi = bi + WID * 2; - temp[s_ai] = items[g_ai]; - temp[s_bi] = items[g_bi]; + temp[s_ai] = select(items[g_ai], 0, g_ai >= ELEMENT_COUNT); + temp[s_bi] = select(items[g_bi], 0, g_bi >= ELEMENT_COUNT); var offset: u32 = 1; @@ -91,8 +92,12 @@ fn reduce_downsweep( workgroupBarrier(); // Copy result from shared memory to global memory - items[g_ai] = temp[s_ai]; - items[g_bi] = temp[s_bi]; + if (g_ai < ELEMENT_COUNT) { + items[g_ai] = temp[s_ai]; + } + if (g_bi < ELEMENT_COUNT) { + items[g_bi] = temp[s_bi]; + } } @compute @workgroup_size(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y, 1) @@ -106,9 +111,19 @@ fn add_block_sums( let GID = WID + TID; // Global thread ID let ELM_ID = GID * 2; + + if (ELM_ID >= ELEMENT_COUNT) { + return; + } + let blockSum = blockSums[WORKGROUP_ID]; items[ELM_ID] += blockSum; + + if (ELM_ID + 1 >= ELEMENT_COUNT) { + return; + } + items[ELM_ID + 1] += blockSum; }` diff --git a/src/shaders/optimizations/radix_sort_local_shuffle.js b/src/shaders/optimizations/radix_sort_local_shuffle.js index f97ad00..a3623ea 100644 --- a/src/shaders/optimizations/radix_sort_local_shuffle.js +++ b/src/shaders/optimizations/radix_sort_local_shuffle.js @@ -31,8 +31,12 @@ fn radix_sort( let GID = WID + TID; // Global thread ID // Extract 2 bits from the input - let elm = input[GID]; - let val = values[GID]; + var elm: u32 = 0; + var val: u32 = 0; + if (GID < ELEMENT_COUNT) { + elm = input[GID]; + val = values[GID]; + } let extract_bits: u32 = (elm >> CURRENT_BIT) & 0x3; var bit_prefix_sums = array(0, 0, 0, 0); @@ -58,14 +62,18 @@ fn radix_sort( s_prefix_sum[inOffset + 1] = bitmask; workgroupBarrier(); + var prefix_sum: u32 = 0; + // Prefix sum for (var offset: u32 = 1; offset < THREADS_PER_WORKGROUP; offset *= 2) { if (TID >= offset) { - s_prefix_sum[outOffset] = s_prefix_sum[inOffset] + s_prefix_sum[inOffset - offset]; + prefix_sum = s_prefix_sum[inOffset] + s_prefix_sum[inOffset - offset]; } else { - s_prefix_sum[outOffset] = s_prefix_sum[inOffset]; + prefix_sum = s_prefix_sum[inOffset]; } + s_prefix_sum[outOffset] = prefix_sum; + // Swap buffers outOffset = inOffset; swapOffset = TPW - swapOffset; @@ -75,7 +83,6 @@ fn radix_sort( } // Store prefix sum for current bit - let prefix_sum = s_prefix_sum[inOffset]; bit_prefix_sums[b] = prefix_sum; if (TID == LAST_THREAD) { diff --git a/src/shaders/prefix_sum.js b/src/shaders/prefix_sum.js index f6670f6..d6c2041 100644 --- a/src/shaders/prefix_sum.js +++ b/src/shaders/prefix_sum.js @@ -7,6 +7,7 @@ override WORKGROUP_SIZE_X: u32; override WORKGROUP_SIZE_Y: u32; override THREADS_PER_WORKGROUP: u32; override ITEMS_PER_WORKGROUP: u32; +override ELEMENT_COUNT: u32; var temp: array; @@ -24,8 +25,8 @@ fn reduce_downsweep( let ELM_GID = GID * 2; // Element pair global ID // Load input to shared memory - temp[ELM_TID] = items[ELM_GID]; - temp[ELM_TID + 1] = items[ELM_GID + 1]; + temp[ELM_TID] = select(items[ELM_GID], 0, ELM_GID >= ELEMENT_COUNT); + temp[ELM_TID + 1] = select(items[ELM_GID + 1], 0, ELM_GID + 1 >= ELEMENT_COUNT); var offset: u32 = 1; @@ -67,7 +68,14 @@ fn reduce_downsweep( workgroupBarrier(); // Copy result from shared memory to global memory - items[ELM_GID] = temp[ELM_TID]; + if (ELM_GID >= ELEMENT_COUNT) { + return; + } + items[ELM_GID] = temp[ELM_TID]; + + if (ELM_GID + 1 >= ELEMENT_COUNT) { + return; + } items[ELM_GID + 1] = temp[ELM_TID + 1]; } @@ -80,12 +88,21 @@ fn add_block_sums( let WORKGROUP_ID = w_id.x + w_id.y * w_dim.x; let WID = WORKGROUP_ID * THREADS_PER_WORKGROUP; let GID = WID + TID; // Global thread ID - let ELM_ID = GID * 2; + + if (ELM_ID >= ELEMENT_COUNT) { + return; + } + let blockSum = blockSums[WORKGROUP_ID]; items[ELM_ID] += blockSum; + + if (ELM_ID + 1 >= ELEMENT_COUNT) { + return; + } + items[ELM_ID + 1] += blockSum; }` diff --git a/src/shaders/radix_sort.js b/src/shaders/radix_sort.js index 6f34d91..e5a8e69 100644 --- a/src/shaders/radix_sort.js +++ b/src/shaders/radix_sort.js @@ -24,7 +24,7 @@ fn radix_sort( let GID = WID + TID; // Global thread ID // Extract 2 bits from the input - let elm = input[GID]; + let elm = select(input[GID], 0, GID >= ELEMENT_COUNT); let extract_bits: u32 = (elm >> CURRENT_BIT) & 0x3; var bit_prefix_sums = array(0, 0, 0, 0); @@ -50,14 +50,18 @@ fn radix_sort( s_prefix_sum[inOffset + 1] = bitmask; workgroupBarrier(); + var prefix_sum: u32 = 0; + // Prefix sum for (var offset: u32 = 1; offset < THREADS_PER_WORKGROUP; offset *= 2) { if (TID >= offset) { - s_prefix_sum[outOffset] = s_prefix_sum[inOffset] + s_prefix_sum[inOffset - offset]; + prefix_sum = s_prefix_sum[inOffset] + s_prefix_sum[inOffset - offset]; } else { - s_prefix_sum[outOffset] = s_prefix_sum[inOffset]; + prefix_sum = s_prefix_sum[inOffset]; } + s_prefix_sum[outOffset] = prefix_sum; + // Swap buffers outOffset = inOffset; swapOffset = TPW - swapOffset; @@ -67,7 +71,6 @@ fn radix_sort( } // Store prefix sum for current bit - let prefix_sum = s_prefix_sum[inOffset]; bit_prefix_sums[b] = prefix_sum; if (TID == LAST_THREAD) { @@ -82,8 +85,10 @@ fn radix_sort( inOffset = TID + swapOffset; } - // Store local prefix sum to global memory - local_prefix_sums[GID] = bit_prefix_sums[extract_bits]; + if (GID < ELEMENT_COUNT) { + // Store local prefix sum to global memory + local_prefix_sums[GID] = bit_prefix_sums[extract_bits]; + } }` export default radixSortSource; \ No newline at end of file