git.nega.tv - josh/narcissus/commitdiff
shark-shaders: Migrate radix_sort pipeline to slang
author Joshua Simmons <josh@nega.tv>
Sun, 12 Oct 2025 10:33:36 +0000 (12:33 +0200)
committer Joshua Simmons <josh@nega.tv>
Tue, 14 Oct 2025 23:12:46 +0000 (01:12 +0200)
title/shark-shaders/build.rs
title/shark-shaders/shaders/bindings_compute.h [deleted file]
title/shark-shaders/shaders/indirect.h [deleted file]
title/shark-shaders/shaders/radix_sort.h [deleted file]
title/shark-shaders/shaders/radix_sort.slang [new file with mode: 0644]
title/shark-shaders/shaders/radix_sort_0_upsweep.comp [deleted file]
title/shark-shaders/shaders/radix_sort_1_downsweep.comp [deleted file]
title/shark-shaders/src/pipelines.rs

diff --git a/title/shark-shaders/build.rs b/title/shark-shaders/build.rs
index 74112484af0d22b14dd37bad908e0e906dc080dd..0e5ececbd75eaa1b2d1270ed69674dc3f5d2a592 100644 (file)
@@ -80,18 +80,10 @@ const SLANG_SHADERS: &[SlangShader] = &[
     SlangShader { name: "basic" },
     SlangShader { name: "draw_2d" },
     SlangShader { name: "composite" },
+    SlangShader { name: "radix_sort" },
 ];
 
-const SHADERS: &[Shader] = &[
-    Shader {
-        stage: "comp",
-        name: "radix_sort_0_upsweep",
-    },
-    Shader {
-        stage: "comp",
-        name: "radix_sort_1_downsweep",
-    },
-];
+const SHADERS: &[Shader] = &[];
 
 fn main() {
     let out_dir = std::env::var("OUT_DIR").unwrap();
diff --git a/title/shark-shaders/shaders/bindings_compute.h b/title/shark-shaders/shaders/bindings_compute.h
deleted file mode 100644 (file)
index 1eb1fd3..0000000
--- a/title/shark-shaders/shaders/bindings_compute.h
+++ /dev/null
@@ -1,16 +0,0 @@
-#ifndef COMPUTE_BINDINGS_INCLUDE
-#define COMPUTE_BINDINGS_INCLUDE
-
-const uint SAMPLER_BILINEAR = 0;
-const uint SAMPLER_BILINEAR_UNNORMALIZED = 1;
-const uint SAMPLER_COUNT = 2;
-
-layout (set = 0, binding = 0) uniform sampler samplers[SAMPLER_COUNT];
-layout (set = 0, binding = 1) uniform texture3D tony_mc_mapface_lut;
-layout (set = 0, binding = 2) uniform texture2D glyph_atlas;
-layout (set = 0, binding = 3) uniform writeonly image2D ui_layer_write;
-layout (set = 0, binding = 3) uniform readonly image2D ui_layer_read;
-layout (set = 0, binding = 4) uniform readonly image2D color_layer;
-layout (set = 0, binding = 5) uniform writeonly image2D composited_output;
-
-#endif
diff --git a/title/shark-shaders/shaders/indirect.h b/title/shark-shaders/shaders/indirect.h
deleted file mode 100644 (file)
index d409cf2..0000000
--- a/title/shark-shaders/shaders/indirect.h
+++ /dev/null
@@ -1,10 +0,0 @@
-#ifndef INDIRECT_H
-#define INDIRECT_H
-
-struct VkDispatchIndirectCommand {
-    uint x;
-    uint y;
-    uint z;
-};
-
-#endif
\ No newline at end of file
diff --git a/title/shark-shaders/shaders/radix_sort.h b/title/shark-shaders/shaders/radix_sort.h
deleted file mode 100644 (file)
index edf5a6c..0000000
--- a/title/shark-shaders/shaders/radix_sort.h
+++ /dev/null
@@ -1,20 +0,0 @@
-#ifndef RADIX_SORT_H
-#define RADIX_SORT_H
-
-const uint RADIX_BITS = 8;
-const uint RADIX_DIGITS = 1 << RADIX_BITS;
-const uint RADIX_MASK = RADIX_DIGITS - 1;
-
-const uint RADIX_WGP_SIZE = 256;
-const uint RADIX_ITEMS_PER_INVOCATION = 16;
-const uint RADIX_ITEMS_PER_WGP = RADIX_WGP_SIZE * RADIX_ITEMS_PER_INVOCATION;
-
-layout(buffer_reference, std430, buffer_reference_align = 4) coherent buffer FinishedRef {
-    coherent uint value;
-};
-
-layout(buffer_reference, std430, buffer_reference_align = 4) readonly buffer CountRef {
-    uint value;
-};
-
-#endif
diff --git a/title/shark-shaders/shaders/radix_sort.slang b/title/shark-shaders/shaders/radix_sort.slang
new file mode 100644 (file)
index 0000000..120569e
--- /dev/null
+++ b/title/shark-shaders/shaders/radix_sort.slang
@@ -0,0 +1,256 @@
+module radix_sort;
+
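+// Emulates GLSL's gl_SubgroupLtMask: a uint4 bitmask of all lanes with an
+// index lower than the current lane, supporting wave sizes up to 128. Note
+// that for lane indices >= 32, the `1u << WaveGetLaneIndex()` terms depend on
+// the hardware masking shift amounts to the low five bits.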
+uint4 WaveLtMask() {
+    if (WaveGetLaneIndex() < 32)
+        return uint4((1u << WaveGetLaneIndex()) - 1, 0, 0, 0);
+
+    if (WaveGetLaneIndex() < 64)
+        return uint4(~0, (1u << WaveGetLaneIndex()) - 1, 0, 0);
+
+    if (WaveGetLaneIndex() < 96)
+        return uint4(~0, ~0, (1u << WaveGetLaneIndex()) - 1, 0);
+
+    return uint4(~0, ~0, ~0, (1u << WaveGetLaneIndex()) - 1);
+}
+
+layout(constant_id = 0) const uint WAVE_SIZE = 64;
+
+namespace radix_sort {
+public func CalculateGroupCountForItemCount(uint item_count)->uint {
+    return (item_count + (RADIX_ITEMS_PER_GROUP - 1)) / RADIX_ITEMS_PER_GROUP;
+}
+}
+
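+// 8-bit digits over 32-bit keys: a full sort is four upsweep/downsweep pass
+// pairs (shift = 0, 8, 16, 24).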
+static const uint RADIX_BITS = 8;
+static const uint RADIX_DIGITS = 1 << RADIX_BITS;
+static const uint RADIX_MASK = RADIX_DIGITS - 1;
+
+static const uint RADIX_GROUP_SIZE = RADIX_DIGITS;
+static const uint RADIX_ITEMS_PER_THREAD = 16;
+static const uint RADIX_ITEMS_PER_GROUP = RADIX_GROUP_SIZE * RADIX_ITEMS_PER_THREAD;
+
+static const uint WAVE_COUNT = RADIX_GROUP_SIZE / WAVE_SIZE;
+
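+// Push-constant layout for the upsweep pass: Slang binds an entry point's
+// `uniform` parameters as push constants, and the raw pointers take the place
+// of the GLSL buffer_reference blocks from the deleted radix_sort.h.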
+struct UpsweepConstants {
+    uint shift;
+    uint _pad;
+    uint *finished_buffer;
+    uint *count_buffer;
+    uint *src_buffer;
+    uint *spine_buffer;
+};
+
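+// Per-group digit counts, accumulated in LDS and then scattered to the spine.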
+groupshared uint histogram[RADIX_DIGITS];
+
+groupshared bool is_last_group_dynamic;
+groupshared uint carry;
+groupshared uint sums[WAVE_COUNT];
+
+[shader("compute")]
+[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic)]
+[numthreads(RADIX_GROUP_SIZE, 1, 1)]
+void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint3 thread_id_in_group: SV_GroupThreadID) {
+    let shift = constants.shift;
+    let count = constants.count_buffer[0];
+
+    let dispatch_group_count = radix_sort::CalculateGroupCountForItemCount(count);
+    let is_last_group_in_dispatch = group_id.x == dispatch_group_count - 1;
+
+    // Clear local histogram.
+    // Assumes RADIX_GROUP_SIZE == RADIX_DIGITS
+    histogram[thread_id_in_group.x] = 0;
+
+    GroupMemoryBarrierWithGroupSync();
+
+    if (is_last_group_in_dispatch) {
+        for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
+            const uint src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_id_in_group.x;
+            if (src_index < count) {
+                const uint value = constants.src_buffer[src_index];
+                const uint digit = (value >> shift) & RADIX_MASK;
+                InterlockedAdd(histogram[digit], 1);
+            }
+        }
+    } else {
+        for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
+            const uint src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_id_in_group.x;
+            const uint value = constants.src_buffer[src_index];
+            const uint digit = (value >> shift) & RADIX_MASK;
+            InterlockedAdd(histogram[digit], 1);
+        }
+    }
+
+    GroupMemoryBarrierWithGroupSync();
+
+    // Scatter to the spine. The spine is stored in a striped (digit-major)
+    // layout, so a single flat prefix sum over it yields each digit's global
+    // base plus its per-group offset. Re-calculate how many groups were
+    // dispatched to determine the stride we need to write at.
+    constants.spine_buffer[(thread_id_in_group.x * dispatch_group_count) + group_id.x] = histogram[thread_id_in_group.x];
+
+    DeviceMemoryBarrierWithGroupSync();
+
+    // Store whether we're the last-executing group in LDS. This contrasts with
+    // the 'static' `is_last_group_in_dispatch`, which represents the group with
+    // the largest group_id in the dispatch.
+    if (thread_id_in_group.x == 0) {
+        var old_value = 0;
+        InterlockedAdd(*constants.finished_buffer, 1, old_value);
+        is_last_group_dynamic = old_value == dispatch_group_count - 1;
+    }
+
+    GroupMemoryBarrierWithGroupSync();
+
+    // Only the last-executing group needs to continue; it mops up the spine
+    // prefix sum.
+    if (!is_last_group_dynamic) {
+        return;
+    }
+
+    // Reset for the next pass.
+    InterlockedExchange(*constants.finished_buffer, 0);
+
+    let wave_id = thread_id_in_group.x / WaveGetLaneCount();
+
+    carry = 0;
+    for (uint i = 0; i < dispatch_group_count; i++) {
+        // Load values and calculate partial sums.
+        let value = constants.spine_buffer[i * RADIX_DIGITS + thread_id_in_group.x];
+        let sum = WaveActiveSum(value);
+        let scan = WavePrefixSum(value);
+
+        if (WaveIsFirstLane()) {
+            sums[wave_id] = sum;
+        }
+
+    // Even though we read and write the spine, a group barrier (rather than
+    // a device barrier) suffices here because each iteration reads and
+    // writes disjoint locations.
+        //
+        // Note the write at the end of the loop body has no barrier against the
+        // spine load at the start of the loop body, across iterations.
+        GroupMemoryBarrierWithGroupSync();
+
+        // Load the carry value out of LDS here so we can borrow the barrier
+        // below.
+        let carry_in = carry;
+
+        // Scan partials.
+        if (thread_id_in_group.x < WAVE_COUNT) {
+            sums[thread_id_in_group.x] = WavePrefixSum(sums[thread_id_in_group.x]);
+        }
+
+        GroupMemoryBarrierWithGroupSync();
+
+        // Write out the final prefix sum, combining the carry-in, wave sums,
+        // and local scan.
+        constants.spine_buffer[i * RADIX_DIGITS + thread_id_in_group.x] = carry_in + sums[wave_id] + scan;
+
+        // `sums` in LDS now contains partials, so we need to also add the wave
+        // sum to get an inclusive prefix sum for the next iteration.
+        if (wave_id == WAVE_COUNT - 1 && WaveIsFirstLane()) {
+            InterlockedAdd(carry, sums[WAVE_COUNT - 1] + sum);
+        }
+    }
+}
+
+struct DownsweepConstants {
+    uint shift;
+    uint _pad;
+    uint *count_buffer;
+    uint *spine_buffer;
+    uint *src_buffer;
+    uint *dst_buffer;
+}
+
+groupshared uint spine[RADIX_DIGITS];
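+// One 32-bit mask per wave per digit: bit N set means lane N of that wave is
+// holding a key with that digit this iteration. Used to rank keys with equal
+// digits so they keep their relative order.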
+groupshared uint match_masks[WAVE_COUNT][RADIX_DIGITS];
+
+[shader("compute")]
+[require(spvGroupNonUniformBallot)]
+[numthreads(RADIX_GROUP_SIZE, 1, 1)]
+void downsweep(uniform DownsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint3 thread_id_in_group: SV_GroupThreadID) {
+    let shift = constants.shift;
+    let count = constants.count_buffer[0];
+
+    let dispatch_group_count = radix_sort::CalculateGroupCountForItemCount(count);
+    let is_last_group_in_dispatch = group_id.x == dispatch_group_count - 1;
+
+    let wave_id = thread_id_in_group.x / WaveGetLaneCount();
+
+    // Gather from spine buffer into LDS.
+    spine[thread_id_in_group.x] = constants.spine_buffer[thread_id_in_group.x * dispatch_group_count + group_id.x];
+
+    if (is_last_group_in_dispatch) {
+        for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
+            // Clear shared memory and load values from src buffer.
+            for (uint j = 0; j < WAVE_COUNT; j++) {
+                match_masks[j][thread_id_in_group.x] = 0;
+            }
+
+            GroupMemoryBarrierWithGroupSync();
+
+            let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_id_in_group.x;
+            let value = index < count ? constants.src_buffer[index] : 0xffffffff;
+            let digit = (value >> shift) & RADIX_MASK;
+            InterlockedOr(match_masks[wave_id][digit], 1 << WaveGetLaneIndex());
+
+            GroupMemoryBarrierWithGroupSync();
+
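+            // Rank this key among the group's keys with the same digit:
+            // count matching lanes in earlier waves, plus matching lower
+            // lanes within our own wave.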
+            uint peer_scan = 0;
+            for (uint j = 0; j < WAVE_COUNT; j++) {
+                if (j < wave_id) {
+                    peer_scan += countbits(match_masks[j][digit]);
+                }
+            }
+            peer_scan += countbits(match_masks[wave_id][digit] & WaveLtMask().x);
+
+            if (index < count) {
+                constants.dst_buffer[spine[digit] + peer_scan] = value;
+            }
+
+            GroupMemoryBarrierWithGroupSync();
+
+            // Increment the spine with the counts for the workgroup we just
+            // wrote out.
+            for (uint i = 0; i < WAVE_COUNT; i++) {
+                InterlockedAdd(spine[thread_id_in_group.x], countbits(match_masks[i][thread_id_in_group.x]));
+            }
+        }
+    } else {
+        for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
+            // Clear shared memory and load values from src buffer.
+            for (uint j = 0; j < WAVE_COUNT; j++) {
+                match_masks[j][thread_id_in_group.x] = 0;
+            }
+
+            GroupMemoryBarrierWithGroupSync();
+
+            let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_id_in_group.x;
+            let value = constants.src_buffer[index];
+            let digit = (value >> shift) & RADIX_MASK;
+            InterlockedOr(match_masks[wave_id][digit], 1 << WaveGetLaneIndex());
+
+            GroupMemoryBarrierWithGroupSync();
+
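+            // Rank within the digit, exactly as in the bounds-checked path
+            // above.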
+            uint peer_scan = 0;
+            for (uint j = 0; j < WAVE_COUNT; j++) {
+                if (j < wave_id) {
+                    peer_scan += countbits(match_masks[j][digit]);
+                }
+            }
+            peer_scan += countbits(match_masks[wave_id][digit] & WaveLtMask().x);
+
+            constants.dst_buffer[spine[digit] + peer_scan] = value;
+
+            if (i != RADIX_ITEMS_PER_THREAD - 1) {
+                GroupMemoryBarrierWithGroupSync();
+
+                // Increment the spine with the counts for the workgroup we just
+                // wrote out.
+                for (uint i = 0; i < WAVE_COUNT; i++) {
+                    InterlockedAdd(spine[thread_id_in_group.x], countbits(match_masks[i][thread_id_in_group.x]));
+                }
+            }
+        }
+    }
+}
diff --git a/title/shark-shaders/shaders/radix_sort_0_upsweep.comp b/title/shark-shaders/shaders/radix_sort_0_upsweep.comp
deleted file mode 100644 (file)
index 1e9e571..0000000
--- a/title/shark-shaders/shaders/radix_sort_0_upsweep.comp
+++ /dev/null
@@ -1,135 +0,0 @@
-#version 460
-
-#extension GL_GOOGLE_include_directive : require
-
-#extension GL_EXT_buffer_reference : require
-#extension GL_EXT_buffer_reference2 : require
-#extension GL_EXT_scalar_block_layout : require
-#extension GL_EXT_control_flow_attributes : require
-
-#extension GL_KHR_shader_subgroup_arithmetic : require
-#extension GL_KHR_shader_subgroup_ballot : require
-#extension GL_KHR_shader_subgroup_shuffle_relative: enable
-#extension GL_KHR_shader_subgroup_vote : require
-
-#include "radix_sort.h"
-
-#include "indirect.h"
-
-layout (constant_id = 0) const uint SUBGROUP_SIZE = 64;
-
-const uint NUM_SUBGROUPS = RADIX_WGP_SIZE / SUBGROUP_SIZE;
-
-layout(buffer_reference, std430, buffer_reference_align = 4) readonly buffer ValuesRef {
-    uint values[];
-};
-
-layout(buffer_reference, std430, buffer_reference_align = 4) buffer SpineRef {
-    uint values[];
-};
-
-struct RadixSortUpsweepConstants {
-    uint shift;
-    uint _pad;
-    FinishedRef finished_buffer;
-    CountRef count_buffer;
-    ValuesRef src_buffer;
-    SpineRef spine_buffer;
-};
-
-layout(std430, push_constant) uniform RadixSortUpsweepConstantsBlock {
-    RadixSortUpsweepConstants constants;
-};
-
-shared uint histogram[RADIX_DIGITS];
-
-shared bool finished;
-shared uint carry;
-shared uint sums[NUM_SUBGROUPS];
-
-layout (local_size_x = RADIX_DIGITS, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
-    const uint shift = constants.shift;
-    const uint count = constants.count_buffer.value;
-    const uint workgroup_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
-
-    const bool needs_bounds_check = gl_WorkGroupID.x == workgroup_count - 1;
-
-    // Clear local histogram
-    histogram[gl_LocalInvocationID.x] = 0;
-
-    barrier();
-
-    if (needs_bounds_check) {
-        for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
-            const uint global_id = gl_WorkGroupID.x * gl_WorkGroupSize.x * RADIX_ITEMS_PER_INVOCATION + i * RADIX_DIGITS + gl_LocalInvocationID.x;
-            if (global_id < count) {
-                const uint value = constants.src_buffer.values[global_id];
-                const uint digit = (value >> shift) & RADIX_MASK;
-                atomicAdd(histogram[digit], 1);
-            }
-        }
-    } else {
-        for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
-            const uint global_id = gl_WorkGroupID.x * gl_WorkGroupSize.x * RADIX_ITEMS_PER_INVOCATION + i * RADIX_DIGITS + gl_LocalInvocationID.x;
-            const uint value = constants.src_buffer.values[global_id];
-            const uint digit = (value >> shift) & RADIX_MASK;
-            atomicAdd(histogram[digit], 1);
-        }
-    }
-
-    barrier();
-
-    // Scatter to the spine, this is a striped layout so we can efficiently
-    // calculate the prefix sum. Re-calculate how many workgroups we dispatched
-    // to determine the stride we need to write at.
-    constants.spine_buffer.values[(gl_LocalInvocationID.x * workgroup_count) + gl_WorkGroupID.x] = histogram[gl_LocalInvocationID.x];
-
-    barrier();
-
-    if (gl_SubgroupID == 0 && subgroupElect()) {
-        finished = atomicAdd(constants.finished_buffer.value, 1) < workgroup_count - 1;
-    }
-
-    barrier();
-
-    if (finished) {
-        return;
-    }
-
-    // reset for the next pass
-    constants.finished_buffer.value = 0;
-
-    const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
-
-    carry = 0;
-    for (uint i = 0; i < workgroup_count; i++) {
-        // Load values and calculate partial sums
-        const uint value = constants.spine_buffer.values[i * RADIX_DIGITS + local_id];
-        const uint sum = subgroupAdd(value);
-        const uint scan = subgroupExclusiveAdd(value);
-
-        if (subgroupElect()) {
-            sums[gl_SubgroupID] = sum;
-        }
-
-        barrier();
-
-        const uint carry_in = carry;
-
-        // Scan partials
-        if (local_id < NUM_SUBGROUPS) {
-            sums[local_id] = subgroupExclusiveAdd(sums[local_id]);
-        }
-
-        barrier();
-
-        // Write out the final prefix sum, combining the carry-in, subgroup sums, and local scan
-        constants.spine_buffer.values[i * RADIX_DIGITS + local_id] = carry_in + sums[gl_SubgroupID] + scan;
-
-        if (gl_SubgroupID == gl_NumSubgroups - 1 && subgroupElect()) {
-            atomicAdd(carry, sums[gl_NumSubgroups - 1] + sum);
-        }
-    }
-}
\ No newline at end of file
diff --git a/title/shark-shaders/shaders/radix_sort_1_downsweep.comp b/title/shark-shaders/shaders/radix_sort_1_downsweep.comp
deleted file mode 100644 (file)
index 23ff1c3..0000000
--- a/title/shark-shaders/shaders/radix_sort_1_downsweep.comp
+++ /dev/null
@@ -1,137 +0,0 @@
-#version 460
-
-#extension GL_GOOGLE_include_directive : require
-
-#extension GL_EXT_buffer_reference : require
-#extension GL_EXT_buffer_reference2 : require
-#extension GL_EXT_scalar_block_layout : require
-#extension GL_EXT_control_flow_attributes : require
-
-#extension GL_KHR_shader_subgroup_arithmetic : require
-#extension GL_KHR_shader_subgroup_ballot : require
-#extension GL_KHR_shader_subgroup_shuffle_relative: enable
-#extension GL_KHR_shader_subgroup_vote : require
-
-#include "radix_sort.h"
-
-#include "indirect.h"
-
-layout (constant_id = 0) const uint SUBGROUP_SIZE = 64;
-
-const uint NUM_SUBGROUPS = RADIX_WGP_SIZE / SUBGROUP_SIZE;
-
-layout(buffer_reference, std430, buffer_reference_align = 4) readonly buffer SpineRef {
-    uint values[];
-};
-
-layout(buffer_reference, std430, buffer_reference_align = 4) buffer ValuesRef {
-    uint values[];
-};
-
-struct RadixSortDownsweepConstants {
-    uint shift;
-    uint _pad;
-    CountRef count_buffer;
-    SpineRef spine_buffer;
-    ValuesRef src_buffer;
-    ValuesRef dst_buffer;
-};
-
-layout(std430, push_constant) uniform RadixSortDownsweepConstantsBlock {
-    RadixSortDownsweepConstants constants;
-};
-
-shared uint spine[RADIX_DIGITS];
-shared uint match_masks[NUM_SUBGROUPS][RADIX_DIGITS];
-
-layout (local_size_x = RADIX_WGP_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
-    const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
-
-    const uint shift = constants.shift;
-    const uint count = constants.count_buffer.value;
-    const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
-
-    // Gather from spine.
-    spine[local_id] = constants.spine_buffer.values[(local_id * wgp_count) + gl_WorkGroupID.x];
-
-    const bool needs_bounds_check = gl_WorkGroupID.x == wgp_count - 1;
-
-    if (needs_bounds_check) {
-        for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
-            const uint base = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP + i * RADIX_DIGITS;
-
-            if (base >= count)
-                break;
-
-            // Clear shared memory and load values from src buffer.
-            for (uint j = 0; j < NUM_SUBGROUPS; j++) {
-                match_masks[j][local_id] = 0;
-            }
-
-            barrier();
-
-            const uint global_id = base + local_id;
-            const uint value = global_id < count ? constants.src_buffer.values[global_id] : 0xffffffff;
-            const uint digit = (value >> shift) & RADIX_MASK;
-            atomicOr(match_masks[gl_SubgroupID][digit], 1 << gl_SubgroupInvocationID);
-
-            barrier();
-
-            uint peer_scan = 0;
-            for (uint i = 0; i < gl_NumSubgroups; i++) {
-                if (i < gl_SubgroupID) {
-                    peer_scan += bitCount(match_masks[i][digit]);
-                }
-            }
-            peer_scan += bitCount(match_masks[gl_SubgroupID][digit] & gl_SubgroupLtMask.x);
-
-            if (global_id < count) {
-                constants.dst_buffer.values[spine[digit] + peer_scan] = value;
-            }
-
-            barrier();
-
-            // Increment the spine with the counts for the workgroup we just wrote out.
-            for (uint i = 0; i < NUM_SUBGROUPS; i++) {
-                atomicAdd(spine[local_id], bitCount(match_masks[i][local_id]));
-            }
-        }
-    } else {
-        for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
-            // Clear shared memory and load values from src buffer.
-            for (uint j = 0; j < NUM_SUBGROUPS; j++) {
-                match_masks[j][local_id] = 0;
-            }
-
-            barrier();
-
-            const uint global_id = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP + i * RADIX_DIGITS + local_id;
-            const uint value = constants.src_buffer.values[global_id];
-            const uint digit = (value >> shift) & RADIX_MASK;
-            atomicOr(match_masks[gl_SubgroupID][digit], 1 << gl_SubgroupInvocationID);
-
-            barrier();
-
-            uint peer_scan = 0;
-            for (uint i = 0; i < gl_NumSubgroups; i++) {
-                if (i < gl_SubgroupID) {
-                    peer_scan += bitCount(match_masks[i][digit]);
-                }
-            }
-            peer_scan += bitCount(match_masks[gl_SubgroupID][digit] & gl_SubgroupLtMask.x);
-
-            constants.dst_buffer.values[spine[digit] + peer_scan] = value;
-
-            if (i != RADIX_ITEMS_PER_INVOCATION - 1) {
-                barrier();
-
-                // Increment the spine with the counts for the workgroup we just wrote out.
-                for (uint i = 0; i < NUM_SUBGROUPS; i++) {
-                    atomicAdd(spine[local_id], bitCount(match_masks[i][local_id]));
-                }
-            }
-        }
-    }
-}
diff --git a/title/shark-shaders/src/pipelines.rs b/title/shark-shaders/src/pipelines.rs
index 19abe02259d913f86db0b0fe386fbd89f2264b87..6c3fef17320686de5bfbdd79ffdb28d5541f0f92 100644 (file)
@@ -464,16 +464,18 @@ impl Pipelines {
             std::mem::size_of::<Draw2dRasterizeConstants>(),
         );
 
-        let radix_sort_0_upsweep_pipeline = create_compute_pipeline(
-            crate::RADIX_SORT_0_UPSWEEP_COMP_SPV,
-            "radix_sort_upsweep",
+        let radix_sort_0_upsweep_pipeline = create_compute_pipeline_with_entry(
+            crate::RADIX_SORT_SPV,
+            c"upsweep",
+            "radix sort upsweep",
             32,
             true,
             std::mem::size_of::<RadixSortUpsweepConstants>(),
         );
 
-        let radix_sort_1_downsweep_pipeline = create_compute_pipeline(
-            crate::RADIX_SORT_1_DOWNSWEEP_COMP_SPV,
+        let radix_sort_1_downsweep_pipeline = create_compute_pipeline_with_entry(
+            crate::RADIX_SORT_SPV,
+            c"downsweep",
             "radix_sort_downsweep",
             32,
             true,