SlangShader { name: "basic" },
SlangShader { name: "draw_2d" },
SlangShader { name: "composite" },
+ SlangShader { name: "radix_sort" },
];
-const SHADERS: &[Shader] = &[
- Shader {
- stage: "comp",
- name: "radix_sort_0_upsweep",
- },
- Shader {
- stage: "comp",
- name: "radix_sort_1_downsweep",
- },
-];
+const SHADERS: &[Shader] = &[];
fn main() {
let out_dir = std::env::var("OUT_DIR").unwrap();
+++ /dev/null
-#ifndef COMPUTE_BINDINGS_INCLUDE
-#define COMPUTE_BINDINGS_INCLUDE
-
-const uint SAMPLER_BILINEAR = 0;
-const uint SAMPLER_BILINEAR_UNNORMALIZED = 1;
-const uint SAMPLER_COUNT = 2;
-
-layout (set = 0, binding = 0) uniform sampler samplers[SAMPLER_COUNT];
-layout (set = 0, binding = 1) uniform texture3D tony_mc_mapface_lut;
-layout (set = 0, binding = 2) uniform texture2D glyph_atlas;
-layout (set = 0, binding = 3) uniform writeonly image2D ui_layer_write;
-layout (set = 0, binding = 3) uniform readonly image2D ui_layer_read;
-layout (set = 0, binding = 4) uniform readonly image2D color_layer;
-layout (set = 0, binding = 5) uniform writeonly image2D composited_output;
-
-#endif
+++ /dev/null
-#ifndef INDIRECT_H
-#define INDIRECT_H
-
-struct VkDispatchIndirectCommand {
- uint x;
- uint y;
- uint z;
-};
-
-#endif
\ No newline at end of file
+++ /dev/null
-#ifndef RADIX_SORT_H
-#define RADIX_SORT_H
-
-const uint RADIX_BITS = 8;
-const uint RADIX_DIGITS = 1 << RADIX_BITS;
-const uint RADIX_MASK = RADIX_DIGITS - 1;
-
-const uint RADIX_WGP_SIZE = 256;
-const uint RADIX_ITEMS_PER_INVOCATION = 16;
-const uint RADIX_ITEMS_PER_WGP = RADIX_WGP_SIZE * RADIX_ITEMS_PER_INVOCATION;
-
-layout(buffer_reference, std430, buffer_reference_align = 4) coherent buffer FinishedRef {
- coherent uint value;
-};
-
-layout(buffer_reference, std430, buffer_reference_align = 4) readonly buffer CountRef {
- uint value;
-};
-
-#endif
--- /dev/null
+module radix_sort;
+
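+// Equivalent of GLSL's gl_SubgroupLtMask: a 128-bit mask, packed into a uint4, with a
+// bit set for every lane index below the calling lane's index.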
+uint4 WaveLtMask() {
+ let lane = WaveGetLaneIndex();
+
+ if (lane < 32)
+ return uint4((1u << lane) - 1, 0, 0, 0);
+
+ if (lane < 64)
+ return uint4(~0u, (1u << (lane - 32)) - 1, 0, 0);
+
+ if (lane < 96)
+ return uint4(~0u, ~0u, (1u << (lane - 64)) - 1, 0);
+
+ return uint4(~0u, ~0u, ~0u, (1u << (lane - 96)) - 1);
+}
+
+[vk::constant_id(0)]
+const uint WAVE_SIZE = 64;
+
+namespace radix_sort {
+public func CalculateGroupCountForItemCount(uint item_count)->uint {
+ return (item_count + (RADIX_ITEMS_PER_GROUP - 1)) / RADIX_ITEMS_PER_GROUP;
+}
+}
+
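+// Sort keys 8 bits at a time: 256 digit buckets per pass, so four passes cover a
+// 32-bit key. Each group of 256 threads owns RADIX_ITEMS_PER_GROUP = 4096 keys per
+// pass.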
+static const uint RADIX_BITS = 8;
+static const uint RADIX_DIGITS = 1 << RADIX_BITS;
+static const uint RADIX_MASK = RADIX_DIGITS - 1;
+
+static const uint RADIX_GROUP_SIZE = RADIX_DIGITS;
+static const uint RADIX_ITEMS_PER_THREAD = 16;
+static const uint RADIX_ITEMS_PER_GROUP = RADIX_GROUP_SIZE * RADIX_ITEMS_PER_THREAD;
+
+static const uint WAVE_COUNT = RADIX_GROUP_SIZE / WAVE_SIZE;
+
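+// Constants for the upsweep pass, mirroring the Rust-side RadixSortUpsweepConstants.
+// The pointer members are plain buffer device addresses; `_pad` keeps them 8-byte
+// aligned after the 4-byte `shift`.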
+struct UpsweepConstants {
+ uint shift;
+ uint _pad;
+ uint *finished_buffer;
+ uint *count_buffer;
+ uint *src_buffer;
+ uint *spine_buffer;
+};
+
+groupshared uint histogram[RADIX_DIGITS];
+
+groupshared bool is_last_group_dynamic;
+groupshared uint carry;
+groupshared uint sums[WAVE_COUNT];
+
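+// Upsweep: each group builds a 256-bin digit histogram of its slice of keys and
+// scatters it to the spine in digit-major (striped) order. The last group to finish
+// then converts the whole spine into an exclusive prefix sum of scatter offsets.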
+[shader("compute")]
+[require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic)]
+[numthreads(RADIX_GROUP_SIZE, 1, 1)]
+void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint3 thread_id_in_group: SV_GroupThreadID) {
+ let shift = constants.shift;
+ let count = constants.count_buffer[0];
+
+ let dispatch_group_count = radix_sort::CalculateGroupCountForItemCount(count);
+ let is_last_group_in_dispatch = group_id.x == dispatch_group_count - 1;
+
+ // Clear local histogram.
+ // Assumes RADIX_GROUP_SIZE == RADIX_DIGITS
+ histogram[thread_id_in_group.x] = 0;
+
+ GroupMemoryBarrierWithGroupSync();
+
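+ // Only the last group in the dispatch can read past `count`, so it takes the
+ // bounds-checked path; every other group processes a full tile unchecked.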
+ if (is_last_group_in_dispatch) {
+ for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
+ const uint src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_id_in_group.x;
+ if (src_index < count) {
+ const uint value = constants.src_buffer[src_index];
+ const uint digit = (value >> shift) & RADIX_MASK;
+ InterlockedAdd(histogram[digit], 1);
+ }
+ }
+ } else {
+ for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
+ const uint src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_id_in_group.x;
+ const uint value = constants.src_buffer[src_index];
+ const uint digit = (value >> shift) & RADIX_MASK;
+ InterlockedAdd(histogram[digit], 1);
+ }
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ // Scatter to the spine. The spine uses a striped (digit-major) layout so the
+ // prefix sum can be calculated efficiently; the group count re-derived from
+ // the item count gives the stride to write at.
+ constants.spine_buffer[(thread_id_in_group.x * dispatch_group_count) + group_id.x] = histogram[thread_id_in_group.x];
+
+ DeviceMemoryBarrierWithGroupSync();
+
+ // Store whether we're the last-executing group in LDS. This contrasts with
+ // the 'static' `is_last_group_in_dispatch`, which represents the group with
+ // the largest group_id in the dispatch.
+ if (thread_id_in_group.x == 0) {
+ var old_value = 0u;
+ InterlockedAdd(*constants.finished_buffer, 1, old_value);
+ is_last_group_dynamic = old_value == dispatch_group_count - 1;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ // Only the last-executing group needs to continue; it mops up the spine
+ // prefix sum.
+ if (!is_last_group_dynamic) {
+ return;
+ }
+
+ // Reset for the next pass.
+ InterlockedExchange(*constants.finished_buffer, 0);
+
+ let wave_id = thread_id_in_group.x / WaveGetLaneCount();
+
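+ // `carry` accumulates the running total across the 256-entry chunks of the flat
+ // spine, turning the per-chunk scans below into one global exclusive prefix sum.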
+ carry = 0;
+ for (uint i = 0; i < dispatch_group_count; i++) {
+ // Load values and calculate partial sums.
+ let value = constants.spine_buffer[i * RADIX_DIGITS + thread_id_in_group.x];
+ let sum = WaveActiveSum(value);
+ let scan = WavePrefixSum(value);
+
+ if (WaveIsFirstLane()) {
+ sums[wave_id] = sum;
+ }
+
+ // Even though we read and write the spine, a group-scope barrier is enough
+ // here: within an iteration each thread reads and writes only its own entry,
+ // and different iterations touch disjoint locations.
+ //
+ // Note the write at the end of the loop body has no barrier against the
+ // spine load at the start of the next iteration.
+ GroupMemoryBarrierWithGroupSync();
+
+ // Load the carry value out of LDS here so we can borrow the barrier
+ // below.
+ let carry_in = carry;
+
+ // Scan partials.
+ if (thread_id_in_group.x < WAVE_COUNT) {
+ sums[thread_id_in_group.x] = WavePrefixSum(sums[thread_id_in_group.x]);
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ // Write out the final prefix sum, combining the carry-in, wave sums,
+ // and local scan.
+ constants.spine_buffer[i * RADIX_DIGITS + thread_id_in_group.x] = carry_in + sums[wave_id] + scan;
+
+ // `sums` in LDS holds exclusive partial sums, so also add the last wave's sum
+ // to get the group total that carries into the next iteration.
+ if (wave_id == WAVE_COUNT - 1 && WaveIsFirstLane()) {
+ InterlockedAdd(carry, sums[WAVE_COUNT - 1] + sum);
+ }
+ }
+}
+
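+// Constants for the downsweep pass; `src_buffer` and `dst_buffer` are expected to be
+// swapped by the host between passes so the keys ping-pong between the two buffers.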
+struct DownsweepConstants {
+ uint shift;
+ uint _pad;
+ uint *count_buffer;
+ uint *spine_buffer;
+ uint *src_buffer;
+ uint *dst_buffer;
+};
+
+groupshared uint spine[RADIX_DIGITS];
+groupshared uint match_masks[WAVE_COUNT][RADIX_DIGITS];
+
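+// Downsweep: each group loads its slice of the spine into LDS, ranks its keys with
+// per-wave match masks, and scatters them to dst_buffer at spine offset plus rank.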
+[shader("compute")]
+[require(spvGroupNonUniformBallot)]
+[numthreads(RADIX_GROUP_SIZE, 1, 1)]
+void downsweep(uniform DownsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint3 thread_id_in_group: SV_GroupThreadID) {
+ let shift = constants.shift;
+ let count = constants.count_buffer[0];
+
+ let dispatch_group_count = radix_sort::CalculateGroupCountForItemCount(count);
+ let is_last_group_in_dispatch = group_id.x == dispatch_group_count - 1;
+
+ let wave_id = thread_id_in_group.x / WaveGetLaneCount();
+
+ // Gather from spine buffer into LDS.
+ spine[thread_id_in_group.x] = constants.spine_buffer[thread_id_in_group.x * dispatch_group_count + group_id.x];
+
+ if (is_last_group_in_dispatch) {
+ for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
+ // Clear shared memory and load values from src buffer.
+ for (uint j = 0; j < WAVE_COUNT; j++) {
+ match_masks[j][thread_id_in_group.x] = 0;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_id_in_group.x;
+ let value = index < count ? constants.src_buffer[index] : 0xffffffff;
+ let digit = (value >> shift) & RADIX_MASK;
+ InterlockedOr(match_masks[wave_id][digit], 1 << WaveGetLaneIndex());
+
+ GroupMemoryBarrierWithGroupSync();
+
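+ // Rank this key among the group's keys that share the same digit: count matching
+ // keys in lower waves, plus matching lanes below ours within our own wave.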
+ uint peer_scan = 0;
+ for (uint j = 0; j < WAVE_COUNT; j++) {
+ if (j < wave_id) {
+ peer_scan += countbits(match_masks[j][digit]);
+ }
+ }
+ peer_scan += countbits(match_masks[wave_id][digit] & WaveLtMask().x);
+
+ if (index < count) {
+ constants.dst_buffer[spine[digit] + peer_scan] = value;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ // Increment the spine with the counts for the workgroup we just
+ // wrote out.
+ for (uint j = 0; j < WAVE_COUNT; j++) {
+ InterlockedAdd(spine[thread_id_in_group.x], countbits(match_masks[j][thread_id_in_group.x]));
+ }
+ }
+ } else {
+ for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
+ // Clear shared memory and load values from src buffer.
+ for (uint j = 0; j < WAVE_COUNT; j++) {
+ match_masks[j][thread_id_in_group.x] = 0;
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_id_in_group.x;
+ let value = constants.src_buffer[index];
+ let digit = (value >> shift) & RADIX_MASK;
+ InterlockedOr(match_masks[wave_id][digit], 1 << WaveGetLaneIndex());
+
+ GroupMemoryBarrierWithGroupSync();
+
+ uint peer_scan = 0;
+ for (uint j = 0; j < WAVE_COUNT; j++) {
+ if (j < wave_id) {
+ peer_scan += countbits(match_masks[j][digit]);
+ }
+ }
+ peer_scan += countbits(match_masks[wave_id][digit] & WaveLtMask().x);
+
+ constants.dst_buffer[spine[digit] + peer_scan] = value;
+
+ if (i != RADIX_ITEMS_PER_THREAD - 1) {
+ GroupMemoryBarrierWithGroupSync();
+
+ // Increment the spine with the counts for the workgroup we just
+ // wrote out.
+ for (uint j = 0; j < WAVE_COUNT; j++) {
+ InterlockedAdd(spine[thread_id_in_group.x], countbits(match_masks[j][thread_id_in_group.x]));
+ }
+ }
+ }
+ }
+}
+++ /dev/null
-#version 460
-
-#extension GL_GOOGLE_include_directive : require
-
-#extension GL_EXT_buffer_reference : require
-#extension GL_EXT_buffer_reference2 : require
-#extension GL_EXT_scalar_block_layout : require
-#extension GL_EXT_control_flow_attributes : require
-
-#extension GL_KHR_shader_subgroup_arithmetic : require
-#extension GL_KHR_shader_subgroup_ballot : require
-#extension GL_KHR_shader_subgroup_shuffle_relative: enable
-#extension GL_KHR_shader_subgroup_vote : require
-
-#include "radix_sort.h"
-
-#include "indirect.h"
-
-layout (constant_id = 0) const uint SUBGROUP_SIZE = 64;
-
-const uint NUM_SUBGROUPS = RADIX_WGP_SIZE / SUBGROUP_SIZE;
-
-layout(buffer_reference, std430, buffer_reference_align = 4) readonly buffer ValuesRef {
- uint values[];
-};
-
-layout(buffer_reference, std430, buffer_reference_align = 4) buffer SpineRef {
- uint values[];
-};
-
-struct RadixSortUpsweepConstants {
- uint shift;
- uint _pad;
- FinishedRef finished_buffer;
- CountRef count_buffer;
- ValuesRef src_buffer;
- SpineRef spine_buffer;
-};
-
-layout(std430, push_constant) uniform RadixSortUpsweepConstantsBlock {
- RadixSortUpsweepConstants constants;
-};
-
-shared uint histogram[RADIX_DIGITS];
-
-shared bool finished;
-shared uint carry;
-shared uint sums[NUM_SUBGROUPS];
-
-layout (local_size_x = RADIX_DIGITS, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
- const uint shift = constants.shift;
- const uint count = constants.count_buffer.value;
- const uint workgroup_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
-
- const bool needs_bounds_check = gl_WorkGroupID.x == workgroup_count - 1;
-
- // Clear local histogram
- histogram[gl_LocalInvocationID.x] = 0;
-
- barrier();
-
- if (needs_bounds_check) {
- for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
- const uint global_id = gl_WorkGroupID.x * gl_WorkGroupSize.x * RADIX_ITEMS_PER_INVOCATION + i * RADIX_DIGITS + gl_LocalInvocationID.x;
- if (global_id < count) {
- const uint value = constants.src_buffer.values[global_id];
- const uint digit = (value >> shift) & RADIX_MASK;
- atomicAdd(histogram[digit], 1);
- }
- }
- } else {
- for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
- const uint global_id = gl_WorkGroupID.x * gl_WorkGroupSize.x * RADIX_ITEMS_PER_INVOCATION + i * RADIX_DIGITS + gl_LocalInvocationID.x;
- const uint value = constants.src_buffer.values[global_id];
- const uint digit = (value >> shift) & RADIX_MASK;
- atomicAdd(histogram[digit], 1);
- }
- }
-
- barrier();
-
- // Scatter to the spine, this is a striped layout so we can efficiently
- // calculate the prefix sum. Re-calculate how many workgroups we dispatched
- // to determine the stride we need to write at.
- constants.spine_buffer.values[(gl_LocalInvocationID.x * workgroup_count) + gl_WorkGroupID.x] = histogram[gl_LocalInvocationID.x];
-
- barrier();
-
- if (gl_SubgroupID == 0 && subgroupElect()) {
- finished = atomicAdd(constants.finished_buffer.value, 1) < workgroup_count - 1;
- }
-
- barrier();
-
- if (finished) {
- return;
- }
-
- // reset for the next pass
- constants.finished_buffer.value = 0;
-
- const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
-
- carry = 0;
- for (uint i = 0; i < workgroup_count; i++) {
- // Load values and calculate partial sums
- const uint value = constants.spine_buffer.values[i * RADIX_DIGITS + local_id];
- const uint sum = subgroupAdd(value);
- const uint scan = subgroupExclusiveAdd(value);
-
- if (subgroupElect()) {
- sums[gl_SubgroupID] = sum;
- }
-
- barrier();
-
- const uint carry_in = carry;
-
- // Scan partials
- if (local_id < NUM_SUBGROUPS) {
- sums[local_id] = subgroupExclusiveAdd(sums[local_id]);
- }
-
- barrier();
-
- // Write out the final prefix sum, combining the carry-in, subgroup sums, and local scan
- constants.spine_buffer.values[i * RADIX_DIGITS + local_id] = carry_in + sums[gl_SubgroupID] + scan;
-
- if (gl_SubgroupID == gl_NumSubgroups - 1 && subgroupElect()) {
- atomicAdd(carry, sums[gl_NumSubgroups - 1] + sum);
- }
- }
-}
\ No newline at end of file
+++ /dev/null
-#version 460
-
-#extension GL_GOOGLE_include_directive : require
-
-#extension GL_EXT_buffer_reference : require
-#extension GL_EXT_buffer_reference2 : require
-#extension GL_EXT_scalar_block_layout : require
-#extension GL_EXT_control_flow_attributes : require
-
-#extension GL_KHR_shader_subgroup_arithmetic : require
-#extension GL_KHR_shader_subgroup_ballot : require
-#extension GL_KHR_shader_subgroup_shuffle_relative: enable
-#extension GL_KHR_shader_subgroup_vote : require
-
-#include "radix_sort.h"
-
-#include "indirect.h"
-
-layout (constant_id = 0) const uint SUBGROUP_SIZE = 64;
-
-const uint NUM_SUBGROUPS = RADIX_WGP_SIZE / SUBGROUP_SIZE;
-
-layout(buffer_reference, std430, buffer_reference_align = 4) readonly buffer SpineRef {
- uint values[];
-};
-
-layout(buffer_reference, std430, buffer_reference_align = 4) buffer ValuesRef {
- uint values[];
-};
-
-struct RadixSortDownsweepConstants {
- uint shift;
- uint _pad;
- CountRef count_buffer;
- SpineRef spine_buffer;
- ValuesRef src_buffer;
- ValuesRef dst_buffer;
-};
-
-layout(std430, push_constant) uniform RadixSortDownsweepConstantsBlock {
- RadixSortDownsweepConstants constants;
-};
-
-shared uint spine[RADIX_DIGITS];
-shared uint match_masks[NUM_SUBGROUPS][RADIX_DIGITS];
-
-layout (local_size_x = RADIX_WGP_SIZE, local_size_y = 1, local_size_z = 1) in;
-
-void main() {
- const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
-
- const uint shift = constants.shift;
- const uint count = constants.count_buffer.value;
- const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
-
- // Gather from spine.
- spine[local_id] = constants.spine_buffer.values[(local_id * wgp_count) + gl_WorkGroupID.x];
-
- const bool needs_bounds_check = gl_WorkGroupID.x == wgp_count - 1;
-
- if (needs_bounds_check) {
- for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
- const uint base = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP + i * RADIX_DIGITS;
-
- if (base >= count)
- break;
-
- // Clear shared memory and load values from src buffer.
- for (uint j = 0; j < NUM_SUBGROUPS; j++) {
- match_masks[j][local_id] = 0;
- }
-
- barrier();
-
- const uint global_id = base + local_id;
- const uint value = global_id < count ? constants.src_buffer.values[global_id] : 0xffffffff;
- const uint digit = (value >> shift) & RADIX_MASK;
- atomicOr(match_masks[gl_SubgroupID][digit], 1 << gl_SubgroupInvocationID);
-
- barrier();
-
- uint peer_scan = 0;
- for (uint i = 0; i < gl_NumSubgroups; i++) {
- if (i < gl_SubgroupID) {
- peer_scan += bitCount(match_masks[i][digit]);
- }
- }
- peer_scan += bitCount(match_masks[gl_SubgroupID][digit] & gl_SubgroupLtMask.x);
-
- if (global_id < count) {
- constants.dst_buffer.values[spine[digit] + peer_scan] = value;
- }
-
- barrier();
-
- // Increment the spine with the counts for the workgroup we just wrote out.
- for (uint i = 0; i < NUM_SUBGROUPS; i++) {
- atomicAdd(spine[local_id], bitCount(match_masks[i][local_id]));
- }
- }
- } else {
- for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
- // Clear shared memory and load values from src buffer.
- for (uint j = 0; j < NUM_SUBGROUPS; j++) {
- match_masks[j][local_id] = 0;
- }
-
- barrier();
-
- const uint global_id = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP + i * RADIX_DIGITS + local_id;
- const uint value = constants.src_buffer.values[global_id];
- const uint digit = (value >> shift) & RADIX_MASK;
- atomicOr(match_masks[gl_SubgroupID][digit], 1 << gl_SubgroupInvocationID);
-
- barrier();
-
- uint peer_scan = 0;
- for (uint i = 0; i < gl_NumSubgroups; i++) {
- if (i < gl_SubgroupID) {
- peer_scan += bitCount(match_masks[i][digit]);
- }
- }
- peer_scan += bitCount(match_masks[gl_SubgroupID][digit] & gl_SubgroupLtMask.x);
-
- constants.dst_buffer.values[spine[digit] + peer_scan] = value;
-
- if (i != RADIX_ITEMS_PER_INVOCATION - 1) {
- barrier();
-
- // Increment the spine with the counts for the workgroup we just wrote out.
- for (uint i = 0; i < NUM_SUBGROUPS; i++) {
- atomicAdd(spine[local_id], bitCount(match_masks[i][local_id]));
- }
- }
- }
- }
-}
std::mem::size_of::<Draw2dRasterizeConstants>(),
);
- let radix_sort_0_upsweep_pipeline = create_compute_pipeline(
- crate::RADIX_SORT_0_UPSWEEP_COMP_SPV,
- "radix_sort_upsweep",
+ let radix_sort_0_upsweep_pipeline = create_compute_pipeline_with_entry(
+ crate::RADIX_SORT_SPV,
+ c"upsweep",
+ "radix sort upsweep",
32,
true,
std::mem::size_of::<RadixSortUpsweepConstants>(),
);
- let radix_sort_1_downsweep_pipeline = create_compute_pipeline(
- crate::RADIX_SORT_1_DOWNSWEEP_COMP_SPV,
+ let radix_sort_1_downsweep_pipeline = create_compute_pipeline_with_entry(
+ crate::RADIX_SORT_SPV,
+ c"downsweep",
"radix_sort_downsweep",
32,
true,