// Radix sort upsweep pass (GLSL compute shader, shown as a diff:
// '-' lines are the old implementation, '+' lines are its replacement).
// Each workgroup of RADIX_DIGITS invocations builds a shared-memory
// histogram of its slice of the source keys, keyed by the digit at
// `shift`, then scatters that histogram into the spine buffer in a
// striped layout for the spine prefix-sum pass.
layout (local_size_x = RADIX_DIGITS, local_size_y = 1, local_size_z = 1) in;
void main() {
- const uint count = constants.count_buffer.value;
const uint shift = constants.shift;
+ const uint count = constants.count_buffer.value;
+
+ // Ceil-div: number of workgroups dispatched. Only the last workgroup can
+ // own a partially filled slice, so only it pays for per-element bounds
+ // checks (the old code decided per-slice via `skip_bounds_check`).
+ const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
+ const bool needs_bounds_check = gl_WorkGroupID.x == wgp_count - 1;
// Clear local histogram
histogram[gl_LocalInvocationID.x] = 0;
barrier();
- const uint start = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP;
- const uint end = start + RADIX_ITEMS_PER_WGP;
-
- const bool skip_bounds_check = end <= count;
-
- if (skip_bounds_check) {
- for (uint i = start; i < end; i += RADIX_DIGITS) {
- const uint index = i + gl_LocalInvocationID.x;
- const uint value = constants.src_buffer.values[index];
- const uint digit = (value >> shift) & RADIX_MASK;
- atomicAdd(histogram[digit], 1);
- }
- } else {
- for (uint i = start; i < end; i += RADIX_DIGITS) {
- const uint index = i + gl_LocalInvocationID.x;
- if (index < count) {
- const uint value = constants.src_buffer.values[index];
+ // Last workgroup: every load is guarded against reading past `count`.
+ // NOTE(review): the index math assumes gl_WorkGroupSize.x (== RADIX_DIGITS)
+ // * RADIX_ITEMS_PER_INVOCATION == RADIX_ITEMS_PER_WGP — confirm against
+ // the definitions in radix_sort.h.
+ if (needs_bounds_check) {
+ for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
+ const uint global_id = gl_WorkGroupID.x * gl_WorkGroupSize.x * RADIX_ITEMS_PER_INVOCATION + i * RADIX_DIGITS + gl_LocalInvocationID.x;
+ if (global_id < count) {
+ const uint value = constants.src_buffer.values[global_id];
const uint digit = (value >> shift) & RADIX_MASK;
atomicAdd(histogram[digit], 1);
}
}
+ // Interior workgroups: slice is known full, so no per-lane bounds checks.
+ } else {
+ for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
+ const uint global_id = gl_WorkGroupID.x * gl_WorkGroupSize.x * RADIX_ITEMS_PER_INVOCATION + i * RADIX_DIGITS + gl_LocalInvocationID.x;
+ const uint value = constants.src_buffer.values[global_id];
+ const uint digit = (value >> shift) & RADIX_MASK;
+ atomicAdd(histogram[digit], 1);
+ }
}
barrier();
// Scatter to the spine, this is a striped layout so we can efficiently
// calculate the prefix sum. Re-calculate how many workgroups we dispatched
// to determine the stride we need to write at.
- const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
constants.spine_buffer.values[(gl_LocalInvocationID.x * wgp_count) + gl_WorkGroupID.x] = histogram[gl_LocalInvocationID.x];
}
\ No newline at end of file
// Radix sort spine pass (diff view: '-' old, '+' new): a single workgroup
// walks the striped spine buffer column by column, turning the upsweep's
// per-workgroup digit counts into a global exclusive prefix sum, carrying
// the running total across columns in shared `carry`.
layout (constant_id = 0) const uint SUBGROUP_SIZE = 64;
-const uint NUM_SUBGROUPS = RADIX_SPINE_WGP_SIZE / SUBGROUP_SIZE;
+const uint NUM_SUBGROUPS = RADIX_WGP_SIZE / SUBGROUP_SIZE;
+// Running carry across loop iterations (replaces `carry_in` below); one
+// elected lane folds each column's total into it with atomicAdd.
+shared uint carry;
shared uint sums[NUM_SUBGROUPS];
-shared uint carry_in;
-layout (local_size_x = RADIX_SPINE_WGP_SIZE, local_size_y = 1, local_size_z = 1) in;
+layout (local_size_x = RADIX_WGP_SIZE, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
// Re-calculate how many workgroups pushed data into the spine
// NOTE(review): `count` is not declared anywhere in this hunk — presumably
// loaded from the push constants in omitted context lines; confirm.
const uint upsweep_wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
// Every invocation writes the same 0; the barrier inside the loop orders
// this store before the first read of `carry`.
- carry_in = 0;
+ carry = 0;
for (uint i = 0; i < upsweep_wgp_count; i++) {
- const uint spine_index = i * RADIX_DIGITS + local_id;
-
// Load values and calculate partial sums
- const uint value = constants.spine_buffer.values[spine_index];
+ const uint value = constants.spine_buffer.values[i * RADIX_DIGITS + local_id];
const uint sum = subgroupAdd(value);
const uint scan = subgroupExclusiveAdd(value);
barrier();
+ // Snapshot the carry before the elected lane below bumps it for the
+ // next column.
+ const uint carry_in = carry;
+
// Scan partials
// NOTE(review): nothing visible stores `sum` into sums[gl_SubgroupID]
// before this scan — presumably done in context lines omitted from this
// hunk (it must happen before the barrier above); verify.
if (local_id < NUM_SUBGROUPS) {
sums[local_id] = subgroupExclusiveAdd(sums[local_id]);
}
- const uint carry = carry_in;
-
barrier();
// Write out the final prefix sum, combining the carry-in, subgroup sums, and local scan
- constants.spine_buffer.values[spine_index] = carry + sums[gl_SubgroupID] + scan;
+ constants.spine_buffer.values[i * RADIX_DIGITS + local_id] = carry_in + sums[gl_SubgroupID] + scan;
if (gl_SubgroupID == gl_NumSubgroups - 1 && subgroupElect()) {
- carry_in += sums[gl_SubgroupID] + sum;
+ // One lane of the last subgroup folds this column's total (scanned
+ // partials of the earlier subgroups plus its own reduction) into the
+ // carry. The barrier at the top of the next iteration orders this
+ // write before the next snapshot; in GLSL compute, barrier() also
+ // orders shared-memory accesses, which is why the old explicit
+ // memoryBarrierShared() could be dropped.
+ atomicAdd(carry, sums[gl_NumSubgroups - 1] + sum);
}
-
- memoryBarrierShared();
}
}
#extension GL_KHR_shader_subgroup_shuffle_relative: enable
#extension GL_KHR_shader_subgroup_vote : require
-//#extension GL_EXT_debug_printf : enable
-
#include "compute_bindings.h"
#include "radix_sort.h"
RadixSortDownsweepConstants constants;
};
-shared uint values[RADIX_WGP_SIZE];
// Per-digit running output offsets for this workgroup, gathered from the
// prefix-summed spine buffer.
shared uint spine[RADIX_DIGITS];
// One 32-bit occupancy mask per (subgroup, digit): bit N set means lane N of
// that subgroup holds a key with that digit in the current chunk.
// NOTE(review): 32-bit masks (`1 << lane`, and `.x` of gl_SubgroupLtMask
// below) assume a subgroup size of at most 32 lanes — confirm against the
// SUBGROUP_SIZE specialization constant.
shared uint match_masks[NUM_SUBGROUPS][RADIX_DIGITS];
layout (local_size_x = RADIX_WGP_SIZE, local_size_y = 1, local_size_z = 1) in;
// Radix sort downsweep pass (diff view: '-' old, '+' new): ranks each key
// within the workgroup via the match masks and scatters it to dst at
// spine[digit] + rank, keeping the sort stable.
void main() {
+ const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
+
const uint shift = constants.shift;
const uint count = constants.count_buffer.value;
const uint wgp_count = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
- const uint start = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP;
- const uint end = min(start + RADIX_ITEMS_PER_WGP, count);
-
- const uint local_id = gl_SubgroupID * gl_SubgroupSize + gl_SubgroupInvocationID;
-
// Gather from spine.
// Striped layout: stride wgp_count, mirroring how the upsweep wrote it.
// The first barrier inside the loop below orders this store before any
// invocation reads spine[].
spine[local_id] = constants.spine_buffer.values[(local_id * wgp_count) + gl_WorkGroupID.x];
- for (uint value_base = start; value_base < end; value_base += RADIX_WGP_SIZE) {
- // Clear shared memory and load values from src buffer.
- for (uint i = 0; i < NUM_SUBGROUPS; i++) {
- match_masks[i][local_id] = 0;
- }
+ // Only the last workgroup can hold a partial slice; everyone else takes
+ // the unguarded fast path in the else branch.
+ const bool needs_bounds_check = gl_WorkGroupID.x == wgp_count - 1;
- barrier();
+ if (needs_bounds_check) {
+ for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
+ const uint base = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP + i * RADIX_DIGITS;
- const uint global_offset = value_base + local_id;
- uint value = 0xffffffff;
- uint digit = 0xff;
+ // Whole chunk lies past the end: all invocations exit together, so
+ // the barriers below stay uniform.
+ if (base >= count)
+ break;
- if (global_offset < end) {
- value = constants.src_buffer.values[global_offset];
- digit = (value >> shift) & RADIX_MASK;
+ // Clear shared memory and load values from src buffer.
+ for (uint j = 0; j < NUM_SUBGROUPS; j++) {
+ match_masks[j][local_id] = 0;
+ }
+
+ barrier();
+
+ // Out-of-range lanes take a sentinel key of 0xffffffff (digit ==
+ // RADIX_MASK) so the mask building stays uniform; they are excluded
+ // from the dst write below.
+ const uint global_id = base + local_id;
+ const uint value = global_id < count ? constants.src_buffer.values[global_id] : 0xffffffff;
+ const uint digit = (value >> shift) & RADIX_MASK;
atomicOr(match_masks[gl_SubgroupID][digit], 1 << gl_SubgroupInvocationID);
+
+ barrier();
+
+ // Rank of this key among equal digits in the workgroup: full masks of
+ // earlier subgroups plus lower lanes of this subgroup's own mask.
+ uint peer_scan = 0;
+ for (uint i = 0; i < gl_NumSubgroups; i++) {
+ if (i < gl_SubgroupID) {
+ peer_scan += bitCount(match_masks[i][digit]);
+ }
+ }
+ peer_scan += bitCount(match_masks[gl_SubgroupID][digit] & gl_SubgroupLtMask.x);
+
+ if (global_id < count) {
+ constants.dst_buffer.values[spine[digit] + peer_scan] = value;
+ }
+
+ barrier();
+
+ // Increment the spine with the counts for the workgroup we just wrote out.
+ // NOTE(review): no barrier separates these match_masks reads from the
+ // next iteration's clear of match_masks — looks like a shared-memory
+ // hazard; verify. Also, the sentinel lanes inflate spine[RADIX_MASK]
+ // here, which appears harmless only because the next iteration breaks.
+ for (uint i = 0; i < NUM_SUBGROUPS; i++) {
+ atomicAdd(spine[local_id], bitCount(match_masks[i][local_id]));
+ }
}
+ // Fast path: every chunk is full, no per-lane bounds checks needed.
+ } else {
+ for (uint i = 0; i < RADIX_ITEMS_PER_INVOCATION; i++) {
+ // Clear shared memory and load values from src buffer.
+ for (uint j = 0; j < NUM_SUBGROUPS; j++) {
+ match_masks[j][local_id] = 0;
+ }
- barrier();
+ barrier();
+
+ const uint global_id = gl_WorkGroupID.x * RADIX_ITEMS_PER_WGP + i * RADIX_DIGITS + local_id;
+ const uint value = constants.src_buffer.values[global_id];
+ const uint digit = (value >> shift) & RADIX_MASK;
+ atomicOr(match_masks[gl_SubgroupID][digit], 1 << gl_SubgroupInvocationID);
- if (global_offset < end) {
- const uint peer_mask = match_masks[gl_SubgroupID][digit];
+ barrier();
- uint peer_scan = bitCount(peer_mask & gl_SubgroupLtMask.x);
+ // Rank among equal digits: earlier subgroups' full masks plus the
+ // lower lanes of this subgroup's mask.
+ uint peer_scan = 0;
for (uint i = 0; i < gl_NumSubgroups; i++) {
if (i < gl_SubgroupID) {
peer_scan += bitCount(match_masks[i][digit]);
}
}
+ peer_scan += bitCount(match_masks[gl_SubgroupID][digit] & gl_SubgroupLtMask.x);
- const uint dst_index = spine[digit] + peer_scan;
- constants.dst_buffer.values[dst_index] = value;
- }
+ constants.dst_buffer.values[spine[digit] + peer_scan] = value;
- barrier();
+ // Skip the spine update after the final chunk — nothing reads it again.
+ if (i != RADIX_ITEMS_PER_INVOCATION - 1) {
+ barrier();
- // Increment the spine with the counts for the workgroup we just wrote out.
- for (uint i = 0; i < NUM_SUBGROUPS; i++) {
- atomicAdd(spine[local_id], bitCount(match_masks[i][local_id]));
+ // Increment the spine with the counts for the workgroup we just wrote out.
+ // NOTE(review): as in the bounds-checked path, these match_masks
+ // reads race the next iteration's clear (no intervening barrier);
+ // verify.
+ for (uint i = 0; i < NUM_SUBGROUPS; i++) {
+ atomicAdd(spine[local_id], bitCount(match_masks[i][local_id]));
+ }
+ }
}
}
}