From: Josh Simmons Date: Tue, 5 Nov 2024 22:15:02 +0000 (+0100) Subject: shark-shaders: Small improvements to radix sort X-Git-Url: https://git.nega.tv//gitweb.cgi?a=commitdiff_plain;h=9f95d1578abcde52c61872c02e50f0733b2e2923;p=josh%2Fnarcissus shark-shaders: Small improvements to radix sort --- diff --git a/title/shark-shaders/shaders/radix_sort.h b/title/shark-shaders/shaders/radix_sort.h index cee1983..2a57798 100644 --- a/title/shark-shaders/shaders/radix_sort.h +++ b/title/shark-shaders/shaders/radix_sort.h @@ -9,8 +9,6 @@ const uint RADIX_WGP_SIZE = 256; const uint RADIX_ITEMS_PER_INVOCATION = 16; const uint RADIX_ITEMS_PER_WGP = RADIX_WGP_SIZE * RADIX_ITEMS_PER_INVOCATION; -const uint RADIX_SPINE_WGP_SIZE = 256; - layout(buffer_reference, std430, buffer_reference_align = 4) readonly buffer CountRef { uint value; }; diff --git a/title/shark-shaders/shaders/radix_sort_0_upsweep.comp b/title/shark-shaders/shaders/radix_sort_0_upsweep.comp index 480fbf7..4bff3a4 100644 --- a/title/shark-shaders/shaders/radix_sort_0_upsweep.comp +++ b/title/shark-shaders/shaders/radix_sort_0_upsweep.comp @@ -39,35 +39,33 @@ shared uint histogram[RADIX_DIGITS]; layout (local_size_x = RADIX_DIGITS, local_size_y = 1, local_size_z = 1) in; void main() { - const uint count = constants.count_buffer.value; const uint shift = constants.shift; + const uint count = constants.count_buffer.value; + + const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP; + const bool needs_bounds_check = gl_WorkGroupID.x == wgp_count - 1; // Clear local histogram histogram[gl_LocalInvocationID.x] = 0; barrier(); - const uint start = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP; - const uint end = start + RADIX_ITEMS_PER_WGP; - - const bool skip_bounds_check = end <= count; - - if (skip_bounds_check) { - for (uint i = start; i < end; i += RADIX_DIGITS) { - const uint index = i + gl_LocalInvocationID.x; - const uint value = constants.src_buffer.values[index]; - const uint digit = (value >> shift) & RADIX_MASK; - atomicAdd(histogram[digit], 1); - } - } else { - for (uint i = start; i < end; i += RADIX_DIGITS) { - const uint index = i + gl_LocalInvocationID.x; - if (index < count) { - const uint value = constants.src_buffer.values[index]; + if (needs_bounds_check) { + for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) { + const uint global_id = gl_WorkGroupID.x * gl_WorkGroupSize.x * RADIX_ITEMS_PER_INVOCATION + i * RADIX_DIGITS + gl_LocalInvocationID.x; + if (global_id < count) { + const uint value = constants.src_buffer.values[global_id]; const uint digit = (value >> shift) & RADIX_MASK; atomicAdd(histogram[digit], 1); } } + } else { + for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) { + const uint global_id = gl_WorkGroupID.x * gl_WorkGroupSize.x * RADIX_ITEMS_PER_INVOCATION + i * RADIX_DIGITS + gl_LocalInvocationID.x; + const uint value = constants.src_buffer.values[global_id]; + const uint digit = (value >> shift) & RADIX_MASK; + atomicAdd(histogram[digit], 1); + } } barrier(); @@ -75,6 +73,5 @@ void main() { // Scatter to the spine, this is a striped layout so we can efficiently // calculate the prefix sum. Re-calculate how many workgroups we dispatched // to determine the stride we need to write at. - const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP; constants.spine_buffer.values[(gl_LocalInvocationID.x * wgp_count) + gl_WorkGroupID.x] = histogram[gl_LocalInvocationID.x]; } \ No newline at end of file diff --git a/title/shark-shaders/shaders/radix_sort_1_spine.comp b/title/shark-shaders/shaders/radix_sort_1_spine.comp index a47dc4f..c9459bd 100644 --- a/title/shark-shaders/shaders/radix_sort_1_spine.comp +++ b/title/shark-shaders/shaders/radix_sort_1_spine.comp @@ -36,12 +36,12 @@ layout(std430, push_constant) uniform RadixSortSpineConstantsBlock { layout (constant_id = 0) const uint SUBGROUP_SIZE = 64; -const uint NUM_SUBGROUPS = RADIX_SPINE_WGP_SIZE / SUBGROUP_SIZE; +const uint NUM_SUBGROUPS = RADIX_WGP_SIZE / SUBGROUP_SIZE; +shared uint carry; shared uint sums[NUM_SUBGROUPS]; -shared uint carry_in; -layout (local_size_x = RADIX_SPINE_WGP_SIZE, local_size_y = 1, local_size_z = 1) in; +layout (local_size_x = RADIX_WGP_SIZE, local_size_y = 1, local_size_z = 1) in; void main() { const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID; @@ -51,12 +51,10 @@ void main() { // Re-calculate how many workgroups pushed data into the spine const uint upsweep_wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP; - carry_in = 0; + carry = 0; for (uint i = 0; i < upsweep_wgp_count; i++) { - const uint spine_index = i * RADIX_DIGITS + local_id; - // Load values and calculate partial sums - const uint value = constants.spine_buffer.values[spine_index]; + const uint value = constants.spine_buffer.values[i * RADIX_DIGITS + local_id]; const uint sum = subgroupAdd(value); const uint scan = subgroupExclusiveAdd(value); @@ -66,22 +64,20 @@ void main() { barrier(); + const uint carry_in = carry; + // Scan partials if (local_id < NUM_SUBGROUPS) { sums[local_id] = subgroupExclusiveAdd(sums[local_id]); } - const uint carry = carry_in; - barrier(); // Write out the final prefix sum, combining the carry-in, subgroup sums, and local scan - constants.spine_buffer.values[spine_index] = carry + sums[gl_SubgroupID] + scan; + constants.spine_buffer.values[i * RADIX_DIGITS + local_id] = carry_in + sums[gl_SubgroupID] + scan; if (gl_SubgroupID == gl_NumSubgroups - 1 && subgroupElect()) { - carry_in += sums[gl_SubgroupID] + sum; + atomicAdd(carry, sums[gl_NumSubgroups - 1] + sum); } - - memoryBarrierShared(); } } diff --git a/title/shark-shaders/shaders/radix_sort_2_downsweep.comp b/title/shark-shaders/shaders/radix_sort_2_downsweep.comp index 19fcbcf..bbf7fcb 100644 --- a/title/shark-shaders/shaders/radix_sort_2_downsweep.comp +++ b/title/shark-shaders/shaders/radix_sort_2_downsweep.comp @@ -14,8 +14,6 @@ #extension GL_KHR_shader_subgroup_shuffle_relative: enable #extension GL_KHR_shader_subgroup_vote : require -//#extension GL_EXT_debug_printf : enable - #include "compute_bindings.h" #include "radix_sort.h" @@ -48,64 +46,97 @@ layout(std430, push_constant) uniform RadixSortDownsweepConstantsBlock { RadixSortDownsweepConstants constants; }; -shared uint values[RADIX_WGP_SIZE]; shared uint spine[RADIX_DIGITS]; shared uint match_masks[NUM_SUBGROUPS][RADIX_DIGITS]; layout (local_size_x = RADIX_WGP_SIZE, local_size_y = 1, local_size_z = 1) in; void main() { + const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID; + const uint shift = constants.shift; const uint count = constants.count_buffer.value; const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP; - const uint start = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP; - const uint end = min(start + RADIX_ITEMS_PER_WGP, count); - - const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID; - // Gather from spine. spine[local_id] = constants.spine_buffer.values[(local_id * wgp_count) + gl_WorkGroupID.x]; - for (uint value_base = start; value_base < end; value_base += RADIX_WGP_SIZE) { - // Clear shared memory and load values from src buffer. - for (uint i = 0; i < NUM_SUBGROUPS; i++) { - match_masks[i][local_id] = 0; - } + const bool needs_bounds_check = gl_WorkGroupID.x == wgp_count - 1; - barrier(); + if (needs_bounds_check) { + for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) { + const uint base = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP + i * RADIX_DIGITS; - const uint global_offset = value_base + local_id; - uint value = 0xffffffff; - uint digit = 0xff; + if (base >= count) + break; - if (global_offset < end) { - value = constants.src_buffer.values[global_offset]; - digit = (value >> shift) & RADIX_MASK; + // Clear shared memory and load values from src buffer. + for (uint j = 0; j < NUM_SUBGROUPS; j++) { + match_masks[j][local_id] = 0; + } + + barrier(); + + const uint global_id = base + local_id; + const uint value = global_id < count ? constants.src_buffer.values[global_id] : 0xffffffff; + const uint digit = (value >> shift) & RADIX_MASK; atomicOr(match_masks[gl_SubgroupID][digit], 1 << gl_SubgroupInvocationID); + + barrier(); + + uint peer_scan = 0; + for (uint i = 0; i < gl_NumSubgroups; i++) { + if (i < gl_SubgroupID) { + peer_scan += bitCount(match_masks[i][digit]); + } + } + peer_scan += bitCount(match_masks[gl_SubgroupID][digit] & gl_SubgroupLtMask.x); + + if (global_id < count) { + constants.dst_buffer.values[spine[digit] + peer_scan] = value; + } + + barrier(); + + // Increment the spine with the counts for the workgroup we just wrote out. + for (uint i = 0; i < NUM_SUBGROUPS; i++) { + atomicAdd(spine[local_id], bitCount(match_masks[i][local_id])); + } } + } else { + for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) { + // Clear shared memory and load values from src buffer. + for (uint j = 0; j < NUM_SUBGROUPS; j++) { + match_masks[j][local_id] = 0; + } - barrier(); + barrier(); + + const uint global_id = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP + i * RADIX_DIGITS + local_id; + const uint value = constants.src_buffer.values[global_id]; + const uint digit = (value >> shift) & RADIX_MASK; + atomicOr(match_masks[gl_SubgroupID][digit], 1 << gl_SubgroupInvocationID); - if (global_offset < end) { - const uint peer_mask = match_masks[gl_SubgroupID][digit]; + barrier(); - uint peer_scan = bitCount(peer_mask & gl_SubgroupLtMask.x); + uint peer_scan = 0; for (uint i = 0; i < gl_NumSubgroups; i++) { if (i < gl_SubgroupID) { peer_scan += bitCount(match_masks[i][digit]); } } + peer_scan += bitCount(match_masks[gl_SubgroupID][digit] & gl_SubgroupLtMask.x); - const uint dst_index = spine[digit] + peer_scan; - constants.dst_buffer.values[dst_index] = value; - } + constants.dst_buffer.values[spine[digit] + peer_scan] = value; - barrier(); + if (i != RADIX_ITEMS_PER_INVOCATION - 1) { + barrier(); - // Increment the spine with the counts for the workgroup we just wrote out. - for (uint i = 0; i < NUM_SUBGROUPS; i++) { - atomicAdd(spine[local_id], bitCount(match_masks[i][local_id])); + // Increment the spine with the counts for the workgroup we just wrote out. + for (uint i = 0; i < NUM_SUBGROUPS; i++) { + atomicAdd(spine[local_id], bitCount(match_masks[i][local_id])); + } + } } } }