]> git.nega.tv - josh/narcissus/commitdiff
shark-shaders: Small improvements to radix sort
authorJosh Simmons <josh@nega.tv>
Tue, 5 Nov 2024 22:15:02 +0000 (23:15 +0100)
committerJosh Simmons <josh@nega.tv>
Tue, 5 Nov 2024 22:15:02 +0000 (23:15 +0100)
title/shark-shaders/shaders/radix_sort.h
title/shark-shaders/shaders/radix_sort_0_upsweep.comp
title/shark-shaders/shaders/radix_sort_1_spine.comp
title/shark-shaders/shaders/radix_sort_2_downsweep.comp

index cee19831bf9ff2ac32f50535b2519dd0f6a3594b..2a57798beac308e0a95050b8b43dfe7c671a3c93 100644 (file)
@@ -9,8 +9,6 @@ const uint RADIX_WGP_SIZE = 256;
 const uint RADIX_ITEMS_PER_INVOCATION = 16;
 const uint RADIX_ITEMS_PER_WGP = RADIX_WGP_SIZE * RADIX_ITEMS_PER_INVOCATION;
 
-const uint RADIX_SPINE_WGP_SIZE = 256;
-
 layout(buffer_reference, std430, buffer_reference_align = 4) readonly buffer CountRef {
     uint value;
 };
index 480fbf7d0040ee379a6da439fcd3c02f7916fa89..4bff3a4346f787a0c51a800a3586eabf6db2e239 100644 (file)
@@ -39,35 +39,33 @@ shared uint histogram[RADIX_DIGITS];
 layout (local_size_x = RADIX_DIGITS, local_size_y = 1, local_size_z = 1) in;
 
 void main() {
-    const uint count = constants.count_buffer.value;
     const uint shift = constants.shift;
+    const uint count = constants.count_buffer.value;
+
+    const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
+    const bool needs_bounds_check = gl_WorkGroupID.x == wgp_count - 1;
 
     // Clear local histogram
     histogram[gl_LocalInvocationID.x] = 0;
 
     barrier();
 
-    const uint start = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP;
-    const uint end = start + RADIX_ITEMS_PER_WGP;
-
-    const bool skip_bounds_check = end <= count;
-
-    if (skip_bounds_check) {
-        for (uint i = start; i < end; i += RADIX_DIGITS) {
-            const uint index = i + gl_LocalInvocationID.x;
-            const uint value = constants.src_buffer.values[index];
-            const uint digit = (value >> shift) & RADIX_MASK;
-            atomicAdd(histogram[digit], 1);
-        }
-    } else {
-        for (uint i = start; i < end; i += RADIX_DIGITS) {
-            const uint index = i + gl_LocalInvocationID.x;
-            if (index < count) {
-                const uint value = constants.src_buffer.values[index];
+    if (needs_bounds_check) {
+        for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
+            const uint global_id = gl_WorkGroupID.x * gl_WorkGroupSize.x * RADIX_ITEMS_PER_INVOCATION + i * RADIX_DIGITS + gl_LocalInvocationID.x;
+            if (global_id < count) {
+                const uint value = constants.src_buffer.values[global_id];
                 const uint digit = (value >> shift) & RADIX_MASK;
                 atomicAdd(histogram[digit], 1);
             }
         }
+    } else {
+        for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
+            const uint global_id = gl_WorkGroupID.x * gl_WorkGroupSize.x * RADIX_ITEMS_PER_INVOCATION + i * RADIX_DIGITS + gl_LocalInvocationID.x;
+            const uint value = constants.src_buffer.values[global_id];
+            const uint digit = (value >> shift) & RADIX_MASK;
+            atomicAdd(histogram[digit], 1);
+        }
     }
 
     barrier();
@@ -75,6 +73,5 @@ void main() {
     // Scatter to the spine, this is a striped layout so we can efficiently
     // calculate the prefix sum. Re-calculate how many workgroups we dispatched
     // to determine the stride we need to write at.
-    const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
     constants.spine_buffer.values[(gl_LocalInvocationID.x * wgp_count) + gl_WorkGroupID.x] = histogram[gl_LocalInvocationID.x];
 }
\ No newline at end of file
index a47dc4fcf15f5b24b571073516fc051f2f5b9232..c9459bd1e2397199bdef73699915c4ac886ef3e4 100644 (file)
@@ -36,12 +36,12 @@ layout(std430, push_constant) uniform RadixSortSpineConstantsBlock {
 
 layout (constant_id = 0) const uint SUBGROUP_SIZE = 64;
 
-const uint NUM_SUBGROUPS = RADIX_SPINE_WGP_SIZE / SUBGROUP_SIZE;
+const uint NUM_SUBGROUPS = RADIX_WGP_SIZE / SUBGROUP_SIZE;
 
+shared uint carry;
 shared uint sums[NUM_SUBGROUPS];
-shared uint carry_in;
 
-layout (local_size_x = RADIX_SPINE_WGP_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout (local_size_x = RADIX_WGP_SIZE, local_size_y = 1, local_size_z = 1) in;
 
 void main() {
     const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
@@ -51,12 +51,10 @@ void main() {
     // Re-calculate how many workgroups pushed data into the spine
     const uint upsweep_wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
 
-    carry_in = 0;
+    carry = 0;
     for (uint i = 0; i < upsweep_wgp_count; i++) {
-        const uint spine_index = i * RADIX_DIGITS + local_id;
-
         // Load values and calculate partial sums
-        const uint value = constants.spine_buffer.values[spine_index];
+        const uint value = constants.spine_buffer.values[i * RADIX_DIGITS + local_id];
         const uint sum = subgroupAdd(value);
         const uint scan = subgroupExclusiveAdd(value);
 
@@ -66,22 +64,20 @@ void main() {
 
         barrier();
 
+        const uint carry_in = carry;
+
         // Scan partials
         if (local_id < NUM_SUBGROUPS) {
             sums[local_id] = subgroupExclusiveAdd(sums[local_id]);
         }
 
-        const uint carry = carry_in;
-
         barrier();
 
         // Write out the final prefix sum, combining the carry-in, subgroup sums, and local scan
-        constants.spine_buffer.values[spine_index] = carry + sums[gl_SubgroupID] + scan;
+        constants.spine_buffer.values[i * RADIX_DIGITS + local_id] = carry_in + sums[gl_SubgroupID] + scan;
 
         if (gl_SubgroupID == gl_NumSubgroups - 1 && subgroupElect()) {
-            carry_in += sums[gl_SubgroupID] + sum;
+            atomicAdd(carry, sums[gl_NumSubgroups - 1] + sum);
         }
-
-        memoryBarrierShared();
     }
 }
index 19fcbcf00738b9a1d4b37dc6f26f8b385d82e37b..bbf7fcb0487b1a659e82921058a3a294814c7f86 100644 (file)
@@ -14,8 +14,6 @@
 #extension GL_KHR_shader_subgroup_shuffle_relative: enable
 #extension GL_KHR_shader_subgroup_vote : require
 
-//#extension GL_EXT_debug_printf : enable
-
 #include "compute_bindings.h"
 
 #include "radix_sort.h"
@@ -48,64 +46,97 @@ layout(std430, push_constant) uniform RadixSortDownsweepConstantsBlock {
     RadixSortDownsweepConstants constants;
 };
 
-shared uint values[RADIX_WGP_SIZE];
 shared uint spine[RADIX_DIGITS];
 shared uint match_masks[NUM_SUBGROUPS][RADIX_DIGITS];
 
 layout (local_size_x = RADIX_WGP_SIZE, local_size_y = 1, local_size_z = 1) in;
 
 void main() {
+    const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
+
     const uint shift = constants.shift;
     const uint count = constants.count_buffer.value;
     const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
 
-    const uint start = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP;
-    const uint end = min(start + RADIX_ITEMS_PER_WGP, count);
-
-    const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
-
     // Gather from spine.
     spine[local_id] = constants.spine_buffer.values[(local_id * wgp_count) + gl_WorkGroupID.x];
 
-    for (uint value_base = start; value_base < end; value_base += RADIX_WGP_SIZE) {
-        // Clear shared memory and load values from src buffer.
-        for (uint i = 0; i < NUM_SUBGROUPS; i++) {
-            match_masks[i][local_id] = 0;
-        }
+    const bool needs_bounds_check = gl_WorkGroupID.x == wgp_count - 1;
 
-        barrier();
+    if (needs_bounds_check) {
+        for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
+            const uint base = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP + i * RADIX_DIGITS;
 
-        const uint global_offset = value_base + local_id;
-        uint value = 0xffffffff;
-        uint digit = 0xff;
+            if (base >= count)
+                break;
 
-        if (global_offset < end) {
-            value = constants.src_buffer.values[global_offset];
-            digit = (value >> shift) & RADIX_MASK;
+            // Clear shared memory and load values from src buffer.
+            for (uint j = 0; j < NUM_SUBGROUPS; j++) {
+                match_masks[j][local_id] = 0;
+            }
+
+            barrier();
+
+            const uint global_id = base + local_id;
+            const uint value = global_id < count ? constants.src_buffer.values[global_id] : 0xffffffff;
+            const uint digit = (value >> shift) & RADIX_MASK;
             atomicOr(match_masks[gl_SubgroupID][digit], 1 << gl_SubgroupInvocationID);
+
+            barrier();
+
+            uint peer_scan = 0;
+            for (uint i = 0; i < gl_NumSubgroups; i++) {
+                if (i < gl_SubgroupID) {
+                    peer_scan += bitCount(match_masks[i][digit]);
+                }
+            }
+            peer_scan += bitCount(match_masks[gl_SubgroupID][digit] & gl_SubgroupLtMask.x);
+
+            if (global_id < count) {
+                constants.dst_buffer.values[spine[digit] + peer_scan] = value;
+            }
+
+            barrier();
+
+            // Increment the spine with the counts for the workgroup we just wrote out.
+            for (uint i = 0; i < NUM_SUBGROUPS; i++) {
+                atomicAdd(spine[local_id], bitCount(match_masks[i][local_id]));
+            }
         }
+    } else {
+        for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
+            // Clear shared memory and load values from src buffer.
+            for (uint j = 0; j < NUM_SUBGROUPS; j++) {
+                match_masks[j][local_id] = 0;
+            }
 
-        barrier();
+            barrier();
+
+            const uint global_id = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP + i * RADIX_DIGITS + local_id;
+            const uint value = constants.src_buffer.values[global_id];
+            const uint digit = (value >> shift) & RADIX_MASK;
+            atomicOr(match_masks[gl_SubgroupID][digit], 1 << gl_SubgroupInvocationID);
 
-        if (global_offset < end) {
-            const uint peer_mask = match_masks[gl_SubgroupID][digit];
+            barrier();
 
-            uint peer_scan = bitCount(peer_mask & gl_SubgroupLtMask.x);
+            uint peer_scan = 0;
             for (uint i = 0; i < gl_NumSubgroups; i++) {
                 if (i < gl_SubgroupID) {
                     peer_scan += bitCount(match_masks[i][digit]);
                 }
             }
+            peer_scan += bitCount(match_masks[gl_SubgroupID][digit] & gl_SubgroupLtMask.x);
 
-            const uint dst_index = spine[digit] + peer_scan;
-            constants.dst_buffer.values[dst_index] = value;
-        }
+            constants.dst_buffer.values[spine[digit] + peer_scan] = value;
 
-        barrier();
+            if (i != RADIX_ITEMS_PER_INVOCATION - 1) {
+                barrier();
 
-        // Increment the spine with the counts for the workgroup we just wrote out.
-        for (uint i = 0; i < NUM_SUBGROUPS; i++) {
-            atomicAdd(spine[local_id], bitCount(match_masks[i][local_id]));
+                // Increment the spine with the counts for the workgroup we just wrote out.
+                for (uint i = 0; i < NUM_SUBGROUPS; i++) {
+                    atomicAdd(spine[local_id], bitCount(match_masks[i][local_id]));
+                }
+            }
         }
     }
 }