git.nega.tv - josh/narcissus/commitdiff
shark-shaders: Improve the performance of draw 2d shaders
author    Joshua Simmons <josh@nega.tv>
Thu, 16 Oct 2025 20:53:54 +0000 (22:53 +0200)
committer Joshua Simmons <josh@nega.tv>
Sat, 18 Oct 2025 21:40:34 +0000 (23:40 +0200)
The major improvement is to track the accumulated alpha value during resolve
and use it to determine a conservative cut-off point for command drawing.
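
A scalar sketch of that cut-off (function name illustrative), assuming each
command contributes the alpha of the tile area it fully covers, or 0.0 when it
does not cover the tile; the shader does the same per wave with ballots and an
alpha carry between coarse words:

    // Walk the tile's commands topmost-first, accumulating coverage
    // front-to-back. Once the accumulated alpha crosses the threshold,
    // everything below is occluded and can be culled.
    fn conservative_cutoff(cover_alphas_top_to_bottom: &[f32]) -> usize {
        let mut alpha = 0.0f32;
        for (i, &a) in cover_alphas_top_to_bottom.iter().enumerate() {
            alpha = a + alpha * (1.0 - a);
            if alpha > 0.999 {
                return i + 1;
            }
        }
        cover_alphas_top_to_bottom.len()
    }

    fn main() {
        // An opaque rect on top culls everything beneath it.
        assert_eq!(conservative_cutoff(&[1.0, 0.25, 0.5]), 1);
        // Two faint rects never reach the threshold.
        assert_eq!(conservative_cutoff(&[0.0625, 0.0625]), 2);
    }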

During resolve, remove coarse words whose command masks became empty after
culling.
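
A minimal sketch of that compaction (name and layout illustrative), assuming
`coarse` and `fine` hold one tile's packed words and per-command masks with the
topmost word at index `hi` and index 0 reserved for the count slot, as in the
shader; emptied words are skipped and survivors are packed down from `hi`,
mirroring the shader's `lo--` write-back:

    fn compact(coarse: &mut [u32], fine: &mut [u32], hi: usize) -> usize {
        let mut lo = hi;
        for i in (1..=hi).rev() {
            if fine[i] != 0 {
                // `lo >= i`, so we only overwrite entries already visited.
                coarse[lo] = coarse[i];
                fine[lo] = fine[i];
                lo -= 1;
            }
        }
        lo // surviving words occupy (lo, hi]
    }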

Additionally, use an indirect dispatch for the rasterization work rather than
launching workgroups which would immediately terminate.
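
A sketch of how that dispatch is sized, assuming a 32x32 tile is rasterized by
16 groups of 8x8 threads with the group count carried in the z component of the
indirect command (clear resets it to (1, 1, 0), resolve bumps z by 16 per
surviving tile, and the host swaps the fixed-size dispatch for
cmd_dispatch_indirect on the same buffer):

    #[repr(C)]
    struct VkDispatchIndirectCommand {
        x: u32,
        y: u32,
        z: u32,
    }

    fn tile_dispatch(surviving_tiles: u32) -> VkDispatchIndirectCommand {
        VkDispatchIndirectCommand { x: 1, y: 1, z: surviving_tiles * 16 }
    }

    fn main() {
        let cmd = tile_dispatch(3);
        assert_eq!((cmd.x, cmd.y, cmd.z), (1, 1, 48));
    }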

shark-shaders/shaders/composite.slang
shark-shaders/shaders/draw_2d.slang
shark-shaders/shaders/radix_sort.slang
shark-shaders/src/pipelines.rs
shark/src/draw.rs
shark/src/main.rs

index f02ace93f25271122ec98265a312c533e71f1576..e3c5833e31208879ba4e1da112a5ad89d972ad37 100644 (file)
@@ -2,8 +2,6 @@
 import bindings_samplers;
 import bindings_compute;
 
-import draw_2d;
-
 float srgb_oetf(float a) {
     return (.0031308f >= a) ? 12.92f * a : 1.055f * pow(a, .4166666666666667f) - .055f;
 }
@@ -21,17 +19,18 @@ float3 tony_mc_mapface(float3 stimulus) {
 
 struct CompositeConstants {
     uint2 tile_resolution;
-    Draw2d::Tile *tile_buffer;
+    uint *tile_mask_buffer;
 }
 
 [shader("compute")]
 [numthreads(8, 8, 1)]
 void main(uniform CompositeConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID) {
-    let tile_coord = group_id.xy * WorkgroupSize().xy / Draw2d::TILE_SIZE;
-    let tile_index = tile_coord.y * constants.tile_resolution.x + tile_coord.x;
+    let tile_coord = thread_id.xy / 32;
 
-    let lo = constants.tile_buffer[tile_index].index_min;
-    let hi = constants.tile_buffer[tile_index].index_max;
+    let stride = (constants.tile_resolution.x + 31) / 32;
+    let index = tile_coord.y * stride + tile_coord.x / 32;
+    let word = constants.tile_mask_buffer[index];
+    let mask = 1u << (tile_coord.x & 31);
 
     // Display transform
     let stimulus = color_layer.Load(thread_id.xy).rgb;
@@ -39,7 +38,7 @@ void main(uniform CompositeConstants constants, uint3 thread_id: SV_DispatchThre
     var composited = srgb_oetf(transformed);
 
     // UI composite
-    if (lo != hi) {
+    if ((word & mask) != 0) {
         let ui = ui_layer.Load(thread_id.xy).rgba;
         composited = ui.rgb + (composited * (1.0 - ui.a));
     }
index a9fe81ceb6014f55a36bc61d466c860b568d0fbb..539b9647968965802a0854837afa1950279c5ab8 100644 (file)
@@ -6,11 +6,6 @@ import sdf;
 
 namespace Draw2d {
 public static const uint TILE_SIZE = 32;
-
-public struct Tile {
-    public uint index_min;
-    public uint index_max;
-}
 }
 
 static const uint MAX_TILES = 256;
@@ -63,15 +58,31 @@ struct CmdGlyph {
 };
 
 struct ClearConstants {
+    uint2 tile_resolution;
     uint *finished_buffer;
     uint *coarse_buffer;
+    uint *tile_mask_buffer;
+    VkDispatchIndirectCommand *tile_dispatch_buffer;
 }
 
 [shader("compute")]
-[numthreads(1, 1, 1)]
-void clear(uniform ClearConstants constants) {
+[numthreads(64, 1, 1)]
+void clear(uniform ClearConstants constants, uint thread_index_in_group: SV_GroupIndex) {
+    let stride = (constants.tile_resolution.x + 31) / 32;
+    let size = constants.tile_resolution.y * stride;
+
+    for (uint i = 0; i < size; i += 64) {
+        let index = i + thread_index_in_group;
+        if (index < size) {
+            constants.tile_mask_buffer[index] = 0;
+        }
+    }
+
     constants.finished_buffer[0] = 0;
     constants.coarse_buffer[0] = 0;
+    constants.tile_dispatch_buffer.x = 1;
+    constants.tile_dispatch_buffer.y = 1;
+    constants.tile_dispatch_buffer.z = 0;
 }
 
 struct ScatterConstants {
@@ -87,13 +98,13 @@ struct ScatterConstants {
 };
 
 [vk::specialization_constant]
-const int WGP_SIZE = 64;
+const int WAVE_SIZE = 64;
 
 groupshared uint scatter_intersected_tiles[BITMAP_SIZE];
 
 [shader("compute")]
 [require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic, spvGroupNonUniformVote)]
-[numthreads(WGP_SIZE, 1, 1)]
+[numthreads(WAVE_SIZE, 1, 1)]
 void scatter(uniform ScatterConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID) {
     let in_bounds = thread_id.x < constants.draw_buffer_len;
 
@@ -265,7 +276,7 @@ struct VkDispatchIndirectCommand {
 struct SortConstants {
     uint coarse_buffer_len;
     uint _pad;
-    VkDispatchIndirectCommand *indirect_dispatch_buffer;
+    VkDispatchIndirectCommand *sort_dispatch_buffer;
     uint *coarse_buffer;
 };
 
@@ -284,9 +295,9 @@ void sort(uniform SortConstants constants) {
     let count = min(constants.coarse_buffer_len, constants.coarse_buffer[0]);
     constants.coarse_buffer[0] = count;
 
-    constants.indirect_dispatch_buffer.x = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
-    constants.indirect_dispatch_buffer.y = 1;
-    constants.indirect_dispatch_buffer.z = 1;
+    constants.sort_dispatch_buffer.x = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP;
+    constants.sort_dispatch_buffer.y = 1;
+    constants.sort_dispatch_buffer.z = 1;
 }
 
 struct ResolveConstants {
@@ -298,31 +309,30 @@ struct ResolveConstants {
     Glyph *glyph_buffer;
     uint *coarse_buffer;
     uint *fine_buffer;
-    Draw2d::Tile *tile_buffer;
+    uint4 *tile_buffer;
+    uint *tile_mask_buffer;
+    VkDispatchIndirectCommand *tile_dispatch_buffer;
 };
 
 [shader("compute")]
-[require(spvGroupNonUniformBallot, spvGroupNonUniformVote)]
-[numthreads(WGP_SIZE, 1, 1)]
+[require(spvGroupNonUniformBallot, spvGroupNonUniformShuffle, spvGroupNonUniformVote)]
+[numthreads(WAVE_SIZE, 1, 1)]
 void resolve(uniform ResolveConstants constants, uint3 thread_id: SV_DispatchThreadID) {
     let x = thread_id.y;
     let y = thread_id.z;
-    let tile_offset = constants.tile_stride * y + x;
     let search = ((y & 0xff) << 24) | ((x & 0xff) << 16);
     let count = constants.coarse_buffer[0];
 
     if (count == 0) {
-        constants.tile_buffer[tile_offset].index_min = 0;
-        constants.tile_buffer[tile_offset].index_max = 0;
         return;
     }
 
     // Binary search for the upper bound of the tile.
-    uint base = 0;
+    var base = 0;
     {
-        uint n = count;
+        var max_iters = 32;
+        var n = count;
         uint mid;
-        uint max_iters = 32;
         while (max_iters-- > 0 && (mid = n / 2) > 0) {
             let value = constants.coarse_buffer[1 + base + mid] & 0xffff0000;
             base = value > search ? base : base + mid;
@@ -333,22 +343,24 @@ void resolve(uniform ResolveConstants constants, uint3 thread_id: SV_DispatchThr
     let tile_min = uint2(x, y) * Draw2d::TILE_SIZE;
     let tile_max = tile_min + Draw2d::TILE_SIZE;
 
-    bool hit_opaque = false;
-    uint lo = base + 1;
+    var alpha_carry = 0.0;
+    var hit_opaque = false;
+    var lo = base + 1;
     let hi = base + 1;
-    for (; !hit_opaque && lo > 0; lo--) {
-        let i = lo;
+    for (var i = base + 1; !hit_opaque && i > 0; i--) {
         let packed = constants.coarse_buffer[i];
 
+        // If we leave the tile we're done.
         if ((packed & 0xffff0000) != (search & 0xffff0000)) {
             break;
         }
 
         let draw_offset = packed & 0xffff;
-        let draw_index = draw_offset * WGP_SIZE + WaveGetLaneIndex();
+        let draw_index = draw_offset * WAVE_SIZE + WaveGetLaneIndex();
 
-        bool intersects = false;
-        bool opaque_tile = false;
+        var intersects_tile = false;
+        var covers_tile = false;
+        var covers_tile_alpha = 0.0;
 
         if (draw_index < constants.draw_buffer_len) {
             var cmd_min = float2(99999.9);
@@ -362,7 +374,7 @@ void resolve(uniform ResolveConstants constants, uint3 thread_id: SV_DispatchThr
 
             // If the tile doesn't intersect the scissor region it doesn't need to do work here.
             if (any(scissor.offset_max < tile_min) || any(scissor.offset_min > tile_max)) {
-                intersects = false;
+                intersects_tile = false;
             } else {
                 for (;;) {
                     let scalar_type = WaveReadLaneFirst(cmd_type);
@@ -374,18 +386,21 @@ void resolve(uniform ResolveConstants constants, uint3 thread_id: SV_DispatchThr
                             cmd_min = cmd_rect.position;
                             cmd_max = cmd_rect.position + cmd_rect.bound;
 
-                            const bool background_opaque = (cmd_rect.background_color & 0xff000000) == 0xff000000;
-                            if (background_opaque) {
-                                let border_width = float((packed_type >> 16) & 0xff);
-                                let border_opaque = (cmd_rect.border_color & 0xff000000) == 0xff000000;
-                                let border_radii = unpackUnorm4x8ToFloat(cmd_rect.border_radii);
-                                let max_border_radius = max(border_radii.x, max(border_radii.y, max(border_radii.z, border_radii.w))) * 255.0;
-                                let shrink = ((2.0 - sqrt(2.0)) * max_border_radius) + (border_opaque ? 0.0 : border_width);
-
-                                let cmd_shrunk_min = max(scissor.offset_min, cmd_min + shrink);
-                                let cmd_shrunk_max = min(scissor.offset_max, cmd_max - shrink);
-                                opaque_tile = all(cmd_shrunk_max > cmd_shrunk_min) && all(tile_min > cmd_shrunk_min) && all(tile_max < cmd_shrunk_max);
-                            }
+                            let background_alpha = cmd_rect.background_color >> 24;
+                            let border_alpha = cmd_rect.border_color >> 24;
+                            let border_matches_background = background_alpha == border_alpha;
+
+                            let border_width = float((packed_type >> 16) & 0xff);
+                            let border_radii = unpackUnorm4x8ToFloat(cmd_rect.border_radii);
+                            let max_border_radius = max(border_radii.x, max(border_radii.y, max(border_radii.z, border_radii.w))) * 255.0;
+                            let shrink = ((2.0 - sqrt(2.0)) * max_border_radius) + (border_matches_background ? 0.0 : border_width);
+
+                            let cmd_shrunk_min = max(scissor.offset_min, cmd_min + shrink);
+                            let cmd_shrunk_max = min(scissor.offset_max, cmd_max - shrink);
+
+                            covers_tile_alpha = float(background_alpha) * (1.0 / 255.0);
+                            covers_tile = all(cmd_shrunk_max > cmd_shrunk_min) && all(tile_min > cmd_shrunk_min) && all(tile_max < cmd_shrunk_max);
+
                             break;
                         case CmdType::Glyph:
                             let cmd_glyph = reinterpret<CmdGlyph>(constants.draw_buffer[draw_index]);
@@ -400,39 +415,86 @@ void resolve(uniform ResolveConstants constants, uint3 thread_id: SV_DispatchThr
 
                 cmd_min = max(cmd_min, scissor.offset_min);
                 cmd_max = min(cmd_max, scissor.offset_max);
-                intersects = !(any(tile_max < cmd_min) || any(tile_min > cmd_max));
+                intersects_tile = !(any(tile_max < cmd_min) || any(tile_min > cmd_max));
+            }
+        }
+
+        var intersects_mask = WaveActiveBallot(intersects_tile).x;
+
+        if (WaveActiveAnyTrue(covers_tile)) {
+            let transparent = covers_tile && covers_tile_alpha == 0.0;
+            let non_transparent = covers_tile && covers_tile_alpha != 0.0;
+            let opaque = covers_tile && covers_tile_alpha == 1.0;
+
+            let transparent_tile_ballot = WaveActiveBallot(transparent).x;
+            intersects_mask &= ~transparent_tile_ballot;
+
+            if (WaveActiveAnyTrue(opaque)) {
+                let opaque_tile_ballot = WaveActiveBallot(opaque).x;
+                let opaque_mask = (1 << firstbithigh(opaque_tile_ballot)) - 1;
+                hit_opaque = true;
+                intersects_mask &= ~opaque_mask;
+            } else if (WaveActiveAnyTrue(non_transparent)) {
+                // First we read in the alpha carry from the previous coarse
+                // word, if any.
+                var alpha = WaveReadLaneFirst(alpha_carry);
+
+                // Then each lane blends itself and every command above it
+                // (higher lanes) into the alpha carry.
+                for (var i = WAVE_SIZE; i-- > 0;) {
+                    if (i >= WaveGetLaneIndex()) {
+                        let cmd_alpha = WaveReadLaneAt(covers_tile_alpha, i);
+                        alpha = cmd_alpha + alpha * (1.0 - cmd_alpha);
+                    }
+                }
+
+                // Check if any lanes went beyond the threshold.
+                let considered_opaque = alpha > 0.999;
+                if (WaveActiveAnyTrue(considered_opaque)) {
+                    let opaque_tile_ballot = WaveActiveBallot(considered_opaque).x;
+                    let opaque_mask = (1 << firstbithigh(opaque_tile_ballot)) - 1;
+                    hit_opaque = true;
+                    intersects_mask &= ~opaque_mask;
+                }
+
+                // Update the alpha carry for the next coarse word; lane 0's
+                // value folds every command in this word into the previous carry.
+                alpha_carry = WaveReadLaneFirst(alpha);
             }
         }
 
-        var intersects_mask = WaveActiveBallot(intersects).x;
-
-        if (WaveActiveAnyTrue(opaque_tile)) {
-            let opaque_tile_ballot = WaveActiveBallot(opaque_tile);
-            // TODO: Needs to check all live words of the ballot...
-            let first_opaque_tile = firstbithigh(opaque_tile_ballot).x;
-            let opaque_mask = ~((1 << first_opaque_tile) - 1);
-            intersects_mask &= opaque_mask;
-            constants.fine_buffer[i] = intersects_mask;
-            hit_opaque = true;
-        } else {
-            constants.fine_buffer[i] = intersects_mask;
+        if (intersects_mask != 0) {
+            constants.coarse_buffer[lo] = packed;
+            constants.fine_buffer[lo] = intersects_mask;
+            lo--;
         }
     }
 
-    constants.tile_buffer[tile_offset].index_min = lo + 1;
-    constants.tile_buffer[tile_offset].index_max = hi + 1;
+    if (WaveIsFirstLane()) {
+        if (lo != hi) {
+            uint dispatch;
+            InterlockedAdd(constants.tile_dispatch_buffer.z, 16, dispatch);
+
+            let offset = dispatch >> 4;
+            constants.tile_buffer[offset] = uint4(x, y, lo + 1, hi + 1);
+
+            let stride = (constants.tile_stride + 31) / 32;
+            let mask = 1u << (x & 31);
+            InterlockedOr(constants.tile_mask_buffer[y * stride + x / 32], mask);
+        }
+    }
 }
 
 struct RasterizeConstants {
     uint tile_stride;
     uint _pad;
 
-    Cmd *draw_buffer;
-    Scissor *scissor_buffer;
-    Glyph *glyph_buffer;
-    uint *coarse_buffer;
-    uint *fine_buffer;
-    Draw2d::Tile *tile_buffer;
+    Ptr<Cmd, Access::Read> draw_buffer;
+    Ptr<Scissor, Access::Read> scissor_buffer;
+    Ptr<Glyph, Access::Read> glyph_buffer;
+    Ptr<uint, Access::Read> coarse_buffer;
+    Ptr<uint, Access::Read> fine_buffer;
+    Ptr<uint4, Access::Read> tile_buffer;
 };
 
 /// x = (((index >> 2) & 0x0007) & 0xFFFE) | index & 0x0001
@@ -452,21 +514,20 @@ float3 plasma_quintic(float x) {
 
 [shader("compute")]
 [numthreads(8, 8, 1)]
-void rasterize(uniform RasterizeConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID) {
-    let tile_coord = group_id.xy * WorkgroupSize().xy / Draw2d::TILE_SIZE;
-    let tile_index = tile_coord.y * constants.tile_stride + tile_coord.x;
+void rasterize(uniform RasterizeConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 thread_id_in_group: SV_GroupThreadID) {
+    let tile_index = thread_id.z / 16;
+    let x = thread_id.z & 3;
+    let y = (thread_id.z >> 2) & 3;
 
-    let lo = constants.tile_buffer[tile_index].index_min;
-    let hi = constants.tile_buffer[tile_index].index_max;
-
-    if (lo == hi) {
-        return;
-    }
+    let tile = constants.tile_buffer[tile_index];
+    let position = tile.xy * Draw2d::TILE_SIZE + uint2(x, y) * WorkgroupSize().xy + thread_id_in_group.xy;
+    let lo = tile.z;
+    let hi = tile.w;
 
 #if DEBUG_SHOW_TILES == 1
 
-    let color = plasma_quintic(float(hi - lo) / 50.0);
-    ui_layer_write.Store(thread_id.xy, float4(color, 1.0));
+    let color = plasma_quintic(float(hi - lo) / 16.0);
+    ui_layer.Store(position, float4(color, 1.0));
 
 #elif DEBUG_SHOW_TILES == 2
 
@@ -474,26 +535,28 @@ void rasterize(uniform RasterizeConstants constants, uint3 thread_id: SV_Dispatc
     for (uint i = lo; i < hi; i++) {
         count += countbits(constants.fine_buffer[i]);
     }
-    let color = plasma_quintic(float(count) / 600.0);
-    ui_layer_write.Store(thread_id.xy, float4(color, 1.0));
+    let color = count == 1 ? float3(1.0, 0.0, 0.0) : plasma_quintic(float(count) / 300.0);
+    ui_layer.Store(position, float4(color, 1.0));
 
 #else
 
-    let sample_center = thread_id.xy + float2(0.5);
+    let sample_center = float2(position) + float2(0.5);
     var accum = float4(0.0);
 
-    for (uint i = lo; i < hi; i++) {
+    var i = lo;
+    // lo != hi, or the group wouldn't have been dispatched.
+    do {
+        let base = (constants.coarse_buffer[i] & 0xffff) * 32;
         var bitmap = constants.fine_buffer[i];
 
-        while (bitmap != 0) {
+        // Any bitmap in the fine buffer is non-zero.
+        do {
             let index = firstbitlow(bitmap);
             bitmap ^= bitmap & -bitmap;
 
-            let base_index = (constants.coarse_buffer[i] & 0xffff) * 32;
-            let cmd = constants.draw_buffer[base_index + index];
+            let cmd = constants.draw_buffer[base + index];
             let cmd_type = cmd.packed_type >> 24;
             let cmd_scissor = cmd.packed_type & 0xffff;
-
             let scissor = constants.scissor_buffer[cmd_scissor];
 
             var primitive_color = float4(0.0);
@@ -545,12 +608,16 @@ void rasterize(uniform RasterizeConstants constants, uint3 thread_id: SV_Dispatc
                 let glyph = constants.glyph_buffer[cmd_glyph.index];
                 let cmd_min = cmd_glyph.position + glyph.offset_min;
                 let cmd_max = cmd_glyph.position + glyph.offset_max;
-                if (all(sample_center >= max(scissor.offset_min, cmd_min)) && all(sample_center <= min(scissor.offset_max, cmd_max))) {
-                    let glyph_size = glyph.offset_max - glyph.offset_min;
-                    let uv = lerp(glyph.atlas_min, glyph.atlas_max, (sample_center - cmd_min) / glyph_size);
-                    let color = unpackUnorm4x8ToFloat(cmd_glyph.color).bgra;
-                    let coverage = glyph_atlas.SampleLevel(samplers[Sampler::BilinearUnnormalized], uv, 0.0).r * color.a;
-                    primitive_color = color * coverage;
+                if (all(sample_center >= cmd_min) && all(sample_center <= cmd_max)) {
+                    let cmd_min_clipped = max(scissor.offset_min, cmd_min);
+                    let cmd_max_clipped = min(scissor.offset_max, cmd_max);
+                    if (all(sample_center >= cmd_min_clipped) && all(sample_center <= cmd_max_clipped)) {
+                        let glyph_size = glyph.offset_max - glyph.offset_min;
+                        let uv = lerp(glyph.atlas_min, glyph.atlas_max, (sample_center - cmd_min) / glyph_size);
+                        let color = unpackUnorm4x8ToFloat(cmd_glyph.color).bgra;
+                        let coverage = glyph_atlas.SampleLevel(samplers[Sampler::BilinearUnnormalized], uv, 0.0).r * color.a;
+                        primitive_color = color * coverage;
+                    }
                 }
                 break;
             }
@@ -558,10 +625,12 @@ void rasterize(uniform RasterizeConstants constants, uint3 thread_id: SV_Dispatc
 
             // does it blend?
             accum.rgba = primitive_color.rgba + accum.rgba * (1.0 - primitive_color.a);
-        }
-    }
+        } while (bitmap != 0);
+
+        i++;
+    } while (i != hi);
 
-    ui_layer.Store(thread_id.xy, accum);
+    ui_layer.Store(position, accum);
 
 #endif
 }
index 120569e3639668780731c8b947e97fef37dc4f96..d7d3f0742acb37c40bc60b7b0199d22d7ef23dea 100644 (file)
@@ -34,10 +34,10 @@ static const uint WAVE_COUNT = RADIX_GROUP_SIZE / WAVE_SIZE;
 struct UpsweepConstants {
     uint shift;
     uint _pad;
-    uint *finished_buffer;
-    uint *count_buffer;
-    uint *src_buffer;
-    uint *spine_buffer;
+    Ptr<uint, Access::ReadWrite> finished_buffer;
+    Ptr<uint, Access::Read> count_buffer;
+    Ptr<uint, Access::Read> src_buffer;
+    Ptr<uint, Access::ReadWrite> spine_buffer;
 };
 
 groupshared uint histogram[RADIX_DIGITS];
@@ -49,7 +49,7 @@ groupshared uint sums[WAVE_COUNT];
 [shader("compute")]
 [require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic)]
 [numthreads(RADIX_GROUP_SIZE, 1, 1)]
-void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint3 thread_id_in_group: SV_GroupThreadID) {
+void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint thread_index_in_group: SV_GroupIndex) {
     let shift = constants.shift;
     let count = constants.count_buffer[0];
 
@@ -58,46 +58,57 @@ void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThr
 
     // Clear local histogram.
     // Assumes RADIX_GROUP_SIZE == RADIX_DIGITS
-    histogram[thread_id_in_group.x] = 0;
+    histogram[thread_index_in_group] = 0;
 
+    // Ensure we've finished clearing the LDS histogram.
     GroupMemoryBarrierWithGroupSync();
 
     if (is_last_group_in_dispatch) {
         for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
-            const uint src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_id_in_group.x;
-            if (src_index < count) {
-                const uint value = constants.src_buffer[src_index];
-                const uint digit = (value >> shift) & RADIX_MASK;
-                InterlockedAdd(histogram[digit], 1);
-            }
+            let src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_index_in_group;
+            // This will count out-of-bounds values into the last histogram
+            // bucket, but since it's only happening for the last group in a
+            // dispatch, and because it's the last bucket, it won't affect the
+            // results.
+            let value = src_index < count ? constants.src_buffer[src_index] : 0xffffffff;
+            let digit = (value >> shift) & RADIX_MASK;
+            InterlockedAdd(histogram[digit], 1);
         }
     } else {
         for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
-            const uint src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_id_in_group.x;
-            const uint value = constants.src_buffer[src_index];
-            const uint digit = (value >> shift) & RADIX_MASK;
+            let src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_index_in_group;
+            let value = constants.src_buffer[src_index];
+            let digit = (value >> shift) & RADIX_MASK;
             InterlockedAdd(histogram[digit], 1);
         }
     }
 
+    // Ensure we've finished updating the histogram in LDS.
     GroupMemoryBarrierWithGroupSync();
 
     // Scatter to the spine, this is a striped layout so we can efficiently
     // calculate the prefix sum. Re-calculate how many workgroups we dispatched
     // to determine the stride we need to write at.
-    constants.spine_buffer[(thread_id_in_group.x * dispatch_group_count) + group_id.x] = histogram[thread_id_in_group.x];
+    //
+    // Note the spine buffer size is rounded up so there's no need for bounds
+    // checking.
+    constants.spine_buffer[(thread_index_in_group * dispatch_group_count) + group_id.x] = histogram[thread_index_in_group];
 
+    // Ensure the spine has been written before we increment the finished
+    // atomic counter.
     DeviceMemoryBarrierWithGroupSync();
 
     // Store whether we're the last-executing group in LDS. This contrasts with
     // the 'static' `is_last_group_in_dispatch`, which represents the group with
     // the largest group_id in the dispatch.
-    if (thread_id_in_group.x == 0) {
+    if (thread_index_in_group == 0) {
         var old_value = 0;
         InterlockedAdd(*constants.finished_buffer, 1, old_value);
         is_last_group_dynamic = old_value == dispatch_group_count - 1;
     }
 
+    // Ensure all waves read the value of `is_last_group_dynamic` that we just
+    // filled.
     GroupMemoryBarrierWithGroupSync();
 
     // Only the last-executing group needs to continue, it will mop up the spine
@@ -106,20 +117,21 @@ void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThr
         return;
     }
 
-    // Reset for the next pass.
-    InterlockedExchange(*constants.finished_buffer, 0);
+    // Reset for the next pass. This can be a simple store since there's a
+    // barrier between passes and we are the only group executing at this point.
+    *constants.finished_buffer = 0;
 
-    let wave_id = thread_id_in_group.x / WaveGetLaneCount();
+    let wave_index = thread_index_in_group / WaveGetLaneCount();
 
     carry = 0;
     for (uint i = 0; i < dispatch_group_count; i++) {
         // Load values and calculate partial sums.
-        let value = constants.spine_buffer[i * RADIX_DIGITS + thread_id_in_group.x];
+        let value = constants.spine_buffer[i * RADIX_DIGITS + thread_index_in_group];
         let sum = WaveActiveSum(value);
         let scan = WavePrefixSum(value);
 
         if (WaveIsFirstLane()) {
-            sums[wave_id] = sum;
+            sums[wave_index] = sum;
         }
 
         // Even though we read and write from the spine, this can be a group
@@ -135,19 +147,20 @@ void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThr
         let carry_in = carry;
 
         // Scan partials.
-        if (thread_id_in_group.x < WAVE_COUNT) {
-            sums[thread_id_in_group.x] = WavePrefixSum(sums[thread_id_in_group.x]);
+        if (thread_index_in_group < WAVE_COUNT) {
+            sums[thread_index_in_group] = WavePrefixSum(sums[thread_index_in_group]);
         }
 
+        // Make sure we've finished turning the partial sums into a prefix sum.
         GroupMemoryBarrierWithGroupSync();
 
         // Write out the final prefix sum, combining the carry-in, wave sums,
         // and local scan.
-        constants.spine_buffer[i * RADIX_DIGITS + thread_id_in_group.x] = carry_in + sums[wave_id] + scan;
+        constants.spine_buffer[i * RADIX_DIGITS + thread_index_in_group] = carry_in + sums[wave_index] + scan;
 
         // `sums` in LDS now contains partials, so we need to also add the wave
         // sum to get an inclusive prefix sum for the next iteration.
-        if (wave_id == WAVE_COUNT - 1 && WaveIsFirstLane()) {
+        if (wave_index == WAVE_COUNT - 1 && WaveIsFirstLane()) {
             InterlockedAdd(carry, sums[WAVE_COUNT - 1] + sum);
         }
     }
@@ -156,10 +169,10 @@ void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThr
 struct DownsweepConstants {
     uint shift;
     uint _pad;
-    uint *count_buffer;
-    uint *spine_buffer;
-    uint *src_buffer;
-    uint *dst_buffer;
+    Ptr<uint, Access::Read> count_buffer;
+    Ptr<uint, Access::Read> spine_buffer;
+    Ptr<uint, Access::Read> src_buffer;
+    Ptr<uint, Access::ReadWrite> dst_buffer;
 }
 
 groupshared uint spine[RADIX_DIGITS];
@@ -168,77 +181,75 @@ groupshared uint match_masks[WAVE_COUNT][RADIX_DIGITS];
 [shader("compute")]
 [require(spvGroupNonUniformBallot)]
 [numthreads(RADIX_GROUP_SIZE, 1, 1)]
-void downsweep(uniform DownsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint3 thread_id_in_group: SV_GroupThreadID) {
+void downsweep(uniform DownsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint thread_index_in_group: SV_GroupIndex) {
     let shift = constants.shift;
     let count = constants.count_buffer[0];
 
     let dispatch_group_count = radix_sort::CalculateGroupCountForItemCount(count);
     let is_last_group_in_dispatch = group_id.x == dispatch_group_count - 1;
 
-    let wave_id = thread_id_in_group.x / WaveGetLaneCount();
+    let wave_index = thread_index_in_group / WaveGetLaneCount();
 
     // Gather from spine buffer into LDS.
-    spine[thread_id_in_group.x] = constants.spine_buffer[thread_id_in_group.x * dispatch_group_count + group_id.x];
+    spine[thread_index_in_group] = constants.spine_buffer[thread_index_in_group * dispatch_group_count + group_id.x];
 
     if (is_last_group_in_dispatch) {
         for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
             // Clear shared memory and load values from src buffer.
             for (uint j = 0; j < WAVE_COUNT; j++) {
-                match_masks[j][thread_id_in_group.x] = 0;
+                match_masks[j][thread_index_in_group] = 0;
             }
 
             GroupMemoryBarrierWithGroupSync();
 
-            let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_id_in_group.x;
+            let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_index_in_group;
             let value = index < count ? constants.src_buffer[index] : 0xffffffff;
             let digit = (value >> shift) & RADIX_MASK;
-            InterlockedOr(match_masks[wave_id][digit], 1 << WaveGetLaneIndex());
+            InterlockedOr(match_masks[wave_index][digit], 1 << WaveGetLaneIndex());
 
             GroupMemoryBarrierWithGroupSync();
 
-            uint peer_scan = 0;
+            var peer_scan = 0;
             for (uint j = 0; j < WAVE_COUNT; j++) {
-                if (j < wave_id) {
-                    peer_scan += countbits(match_masks[j][digit]);
-                }
+                peer_scan += j < wave_index ? countbits(match_masks[j][digit]) : 0;
             }
-            peer_scan += countbits(match_masks[wave_id][digit] & WaveLtMask().x);
+            peer_scan += countbits(match_masks[wave_index][digit] & WaveLtMask().x);
 
             if (index < count) {
                 constants.dst_buffer[spine[digit] + peer_scan] = value;
             }
 
-            GroupMemoryBarrierWithGroupSync();
+            if (i != RADIX_ITEMS_PER_THREAD - 1) {
+                GroupMemoryBarrierWithGroupSync();
 
-            // Increment the spine with the counts for the workgroup we just
-            // wrote out.
-            for (uint i = 0; i < WAVE_COUNT; i++) {
-                InterlockedAdd(spine[thread_id_in_group.x], countbits(match_masks[i][thread_id_in_group.x]));
+                // Increment the spine with the counts for the workgroup we just
+                // wrote out.
+                for (uint i = 0; i < WAVE_COUNT; i++) {
+                    InterlockedAdd(spine[thread_index_in_group], countbits(match_masks[i][thread_index_in_group]));
+                }
             }
         }
     } else {
         for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) {
             // Clear shared memory and load values from src buffer.
             for (uint j = 0; j < WAVE_COUNT; j++) {
-                match_masks[j][thread_id_in_group.x] = 0;
+                match_masks[j][thread_index_in_group] = 0;
             }
 
             GroupMemoryBarrierWithGroupSync();
 
-            let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_id_in_group.x;
+            let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_index_in_group;
             let value = constants.src_buffer[index];
             let digit = (value >> shift) & RADIX_MASK;
-            InterlockedOr(match_masks[wave_id][digit], 1 << WaveGetLaneIndex());
+            InterlockedOr(match_masks[wave_index][digit], 1 << WaveGetLaneIndex());
 
             GroupMemoryBarrierWithGroupSync();
 
-            uint peer_scan = 0;
+            var peer_scan = 0;
             for (uint j = 0; j < WAVE_COUNT; j++) {
-                if (j < wave_id) {
-                    peer_scan += countbits(match_masks[j][digit]);
-                }
+                peer_scan += j < wave_index ? countbits(match_masks[j][digit]) : 0;
             }
-            peer_scan += countbits(match_masks[wave_id][digit] & WaveLtMask().x);
+            peer_scan += countbits(match_masks[wave_index][digit] & WaveLtMask().x);
 
             constants.dst_buffer[spine[digit] + peer_scan] = value;
 
@@ -248,7 +259,7 @@ void downsweep(uniform DownsweepConstants constants, uint3 thread_id: SV_Dispatc
                 // Increment the spine with the counts for the workgroup we just
                 // wrote out.
                 for (uint i = 0; i < WAVE_COUNT; i++) {
-                    InterlockedAdd(spine[thread_id_in_group.x], countbits(match_masks[i][thread_id_in_group.x]));
+                    InterlockedAdd(spine[thread_index_in_group], countbits(match_masks[i][thread_index_in_group]));
                 }
             }
         }
index 6c3fef17320686de5bfbdd79ffdb28d5541f0f92..0ae7369699c176a554415f5cab2984716059405d 100644 (file)
@@ -176,8 +176,12 @@ pub struct BasicConstants<'a> {
 
 #[repr(C)]
 pub struct Draw2dClearConstants<'a> {
+    pub tile_resolution_x: u32,
+    pub tile_resolution_y: u32,
     pub finished_buffer_address: BufferAddress<'a>,
     pub coarse_buffer_address: BufferAddress<'a>,
+    pub tile_mask_buffer_address: BufferAddress<'a>,
+    pub tile_dispatch_buffer_address: BufferAddress<'a>,
 }
 
 #[repr(C)]
@@ -198,7 +202,7 @@ pub struct Draw2dScatterConstants<'a> {
 pub struct Draw2dSortConstants<'a> {
     pub coarse_buffer_len: u32,
     pub _pad: u32,
-    pub indirect_dispatch_buffer_address: BufferAddress<'a>,
+    pub sort_dispatch_buffer_address: BufferAddress<'a>,
     pub coarse_buffer_address: BufferAddress<'a>,
 }
 
@@ -213,6 +217,8 @@ pub struct Draw2dResolveConstants<'a> {
     pub coarse_buffer_address: BufferAddress<'a>,
     pub fine_buffer_address: BufferAddress<'a>,
     pub tile_buffer_address: BufferAddress<'a>,
+    pub tile_mask_buffer_address: BufferAddress<'a>,
+    pub tile_dispatch_buffer_address: BufferAddress<'a>,
 }
 
 #[repr(C)]
@@ -232,7 +238,7 @@ pub struct Draw2dRasterizeConstants<'a> {
 pub struct CompositeConstants<'a> {
     pub tile_resolution_x: u32,
     pub tile_resolution_y: u32,
-    pub tile_buffer_address: BufferAddress<'a>,
+    pub tile_mask_buffer_address: BufferAddress<'a>,
 }
 
 #[repr(C)]
index 37dac905d9377796b75cf80a5443912ac2aeb67f..05db1bacf284645346f322376aee58d28b9d348d 100644 (file)
@@ -750,16 +750,14 @@ impl<'gpu> DrawState<'gpu> {
                 ],
             );
 
-            let tile_buffer = gpu.request_transient_buffer(
+            let tile_mask_buffer = gpu.request_transient_buffer(
                 frame,
                 thread_token,
                 BufferUsageFlags::STORAGE,
-                self.tile_resolution_x as usize
-                    * self.tile_resolution_y as usize
-                    * std::mem::size_of::<u32>()
-                    * 2,
+                (self.tile_resolution_x.div_ceil(32) * self.tile_resolution_y) as usize
+                    * std::mem::size_of::<u32>(),
             );
-            let tile_buffer_address = gpu.get_buffer_address(tile_buffer.to_arg());
+            let tile_mask_buffer_address = gpu.get_buffer_address(tile_mask_buffer.to_arg());
 
             // Render UI
             {
@@ -769,6 +767,17 @@ impl<'gpu> DrawState<'gpu> {
                     microshades::PURPLE_RGBA_F32[3],
                 );
 
+                let tile_buffer = gpu.request_transient_buffer(
+                    frame,
+                    thread_token,
+                    BufferUsageFlags::STORAGE,
+                    self.tile_resolution_x as usize
+                        * self.tile_resolution_y as usize
+                        * std::mem::size_of::<u32>()
+                        * 4,
+                );
+                let tile_buffer_address = gpu.get_buffer_address(tile_buffer.to_arg());
+
                 let draw_buffer = gpu.request_transient_buffer_with_data(
                     frame,
                     thread_token,
@@ -800,7 +809,14 @@ impl<'gpu> DrawState<'gpu> {
                     COARSE_BUFFER_LEN * std::mem::size_of::<u32>(),
                 );
 
-                let indirect_dispatch_buffer = gpu.request_transient_buffer(
+                let sort_dispatch_buffer = gpu.request_transient_buffer(
+                    frame,
+                    thread_token,
+                    BufferUsageFlags::INDIRECT,
+                    3 * std::mem::size_of::<u32>(),
+                );
+
+                let tile_dispatch_buffer = gpu.request_transient_buffer(
                     frame,
                     thread_token,
                     BufferUsageFlags::INDIRECT,
@@ -832,8 +848,10 @@ impl<'gpu> DrawState<'gpu> {
                 let scissor_buffer_address = gpu.get_buffer_address(scissor_buffer.to_arg());
                 let glyph_buffer_address = gpu.get_buffer_address(glyph_buffer.to_arg());
                 let coarse_buffer_address = gpu.get_buffer_address(coarse_buffer.to_arg());
-                let indirect_dispatch_buffer_address =
-                    gpu.get_buffer_address(indirect_dispatch_buffer.to_arg());
+                let sort_dispatch_buffer_address =
+                    gpu.get_buffer_address(sort_dispatch_buffer.to_arg());
+                let tile_dispatch_buffer_address =
+                    gpu.get_buffer_address(tile_dispatch_buffer.to_arg());
                 let finished_buffer_address = gpu.get_buffer_address(finished_buffer.to_arg());
                 let tmp_buffer_address = gpu.get_buffer_address(tmp_buffer.to_arg());
                 let spine_buffer_address = gpu.get_buffer_address(spine_buffer.to_arg());
@@ -845,8 +863,12 @@ impl<'gpu> DrawState<'gpu> {
                     ShaderStageFlags::COMPUTE,
                     0,
                     &Draw2dClearConstants {
+                        tile_resolution_x: self.tile_resolution_x,
+                        tile_resolution_y: self.tile_resolution_y,
                         finished_buffer_address,
                         coarse_buffer_address,
+                        tile_mask_buffer_address,
+                        tile_dispatch_buffer_address,
                     },
                 );
                 gpu.cmd_dispatch(cmd_encoder, 1, 1, 1);
@@ -905,7 +927,7 @@ impl<'gpu> DrawState<'gpu> {
                         // -1 due to the count taking up a single slot in the buffer.
                         coarse_buffer_len: COARSE_BUFFER_LEN as u32 - 1,
                         _pad: 0,
-                        indirect_dispatch_buffer_address,
+                        sort_dispatch_buffer_address,
                         coarse_buffer_address,
                     },
                 );
@@ -951,7 +973,7 @@ impl<'gpu> DrawState<'gpu> {
                             spine_buffer_address,
                         },
                     );
-                    gpu.cmd_dispatch_indirect(cmd_encoder, indirect_dispatch_buffer.to_arg(), 0);
+                    gpu.cmd_dispatch_indirect(cmd_encoder, sort_dispatch_buffer.to_arg(), 0);
 
                     gpu.cmd_barrier(
                         cmd_encoder,
@@ -981,7 +1003,7 @@ impl<'gpu> DrawState<'gpu> {
                             spine_buffer_address,
                         },
                     );
-                    gpu.cmd_dispatch_indirect(cmd_encoder, indirect_dispatch_buffer.to_arg(), 0);
+                    gpu.cmd_dispatch_indirect(cmd_encoder, sort_dispatch_buffer.to_arg(), 0);
 
                     gpu.cmd_barrier(
                         cmd_encoder,
@@ -1012,6 +1034,8 @@ impl<'gpu> DrawState<'gpu> {
                         coarse_buffer_address,
                         fine_buffer_address: tmp_buffer_address,
                         tile_buffer_address,
+                        tile_mask_buffer_address,
+                        tile_dispatch_buffer_address,
                     },
                 );
                 gpu.cmd_dispatch(
@@ -1025,7 +1049,7 @@ impl<'gpu> DrawState<'gpu> {
                     cmd_encoder,
                     Some(&GlobalBarrier {
                         prev_access: &[Access::ComputeWrite],
-                        next_access: &[Access::ComputeOtherRead],
+                        next_access: &[Access::ComputeOtherRead, Access::IndirectBuffer],
                     }),
                     &[],
                 );
@@ -1047,12 +1071,7 @@ impl<'gpu> DrawState<'gpu> {
                         tile_buffer_address,
                     },
                 );
-                gpu.cmd_dispatch(
-                    cmd_encoder,
-                    self.width.div_ceil(8),
-                    self.height.div_ceil(8),
-                    1,
-                );
+                gpu.cmd_dispatch_indirect(cmd_encoder, tile_dispatch_buffer.to_arg(), 0);
 
                 gpu.cmd_end_debug_marker(cmd_encoder);
             }
@@ -1097,7 +1116,7 @@ impl<'gpu> DrawState<'gpu> {
                     &CompositeConstants {
                         tile_resolution_x: self.tile_resolution_x,
                         tile_resolution_y: self.tile_resolution_y,
-                        tile_buffer_address,
+                        tile_mask_buffer_address,
                     },
                 );
                 gpu.cmd_dispatch(
index fe4652f6e52c4eb6dbaf4198ba5efe0201761f42..281f99253dda1c3375008e578564a619c74f1f93 100644 (file)
@@ -384,9 +384,11 @@ pub fn main() {
                 let y = height / 2.0 + w * c;
 
                 ui_state.push_scissor(vec2(x - w, y - h), vec2(x + w, y + h), true);
-                ui_state.rect(
-                    0.0, 0.0, width, height, 0.0, [0.0; 4], 0xffffffff, 0xffffffff,
-                );
+                for _ in 0..200 {
+                    ui_state.rect(
+                        0.0, 0.0, width, height, 0.0, [0.0; 4], 0x10101010, 0x10101010,
+                    );
+                }
                 ui_state.pop_scissor();
 
                 ui_state.push_scissor(vec2(x - w, y - h), vec2(x + w, y + h), true);
@@ -423,10 +425,27 @@ pub fn main() {
                         y - 200.0,
                         400.0,
                         400.0,
-                        100.0,
+                        50.0,
                         [100.0, 50.0, 25.0, 0.0],
                         0x33333333,
-                        microshades::BLUE_RGBA8[4].rotate_right(8),
+                        (microshades::BLUE_RGBA8[4] >> 8) | 0xff00_0000,
+                    );
+                }
+
+                for i in 0..50 {
+                    let (s, c) = sin_cos_pi_f32(game_state.time * 0.1 + i as f32 * 0.04);
+
+                    let x = width / 2.0 + w * 0.5 * s;
+                    let y = height / 2.0 + w * 0.5 * c;
+                    ui_state.rect(
+                        x - 200.0,
+                        y - 200.0,
+                        400.0,
+                        400.0,
+                        10.0,
+                        [10.0, 10.0, 10.0, 10.0],
+                        0xffff0000,
+                        0x0,
                     );
                 }