From 79ea22679b81d77e6f6ac33ab003d6868efcf948 Mon Sep 17 00:00:00 2001 From: Joshua Simmons Date: Thu, 16 Oct 2025 22:53:54 +0200 Subject: [PATCH] shark-shaders: Improve the performance of draw 2d shaders Major improvement is to track the alpha value during resolve, and use it to determine a conservative cut-off point for command drawing. During resolve, remove words which became empty after culling. Additionally use an indirect dispatch for drawing work, rather than launching workgroups which will immediately terminate. --- title/shark-shaders/shaders/composite.slang | 15 +- title/shark-shaders/shaders/draw_2d.slang | 251 ++++++++++++------- title/shark-shaders/shaders/radix_sort.slang | 121 +++++---- title/shark-shaders/src/pipelines.rs | 10 +- title/shark/src/draw.rs | 59 +++-- title/shark/src/main.rs | 29 ++- 6 files changed, 304 insertions(+), 181 deletions(-) diff --git a/title/shark-shaders/shaders/composite.slang b/title/shark-shaders/shaders/composite.slang index f02ace9..e3c5833 100644 --- a/title/shark-shaders/shaders/composite.slang +++ b/title/shark-shaders/shaders/composite.slang @@ -2,8 +2,6 @@ import bindings_samplers; import bindings_compute; -import draw_2d; - float srgb_oetf(float a) { return (.0031308f >= a) ? 12.92f * a : 1.055f * pow(a, .4166666666666667f) - .055f; } @@ -21,17 +19,18 @@ float3 tony_mc_mapface(float3 stimulus) { struct CompositeConstants { uint2 tile_resolution; - Draw2d::Tile *tile_buffer; + uint *tile_mask_buffer; } [shader("compute")] [numthreads(8, 8, 1)] void main(uniform CompositeConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID) { - let tile_coord = group_id.xy * WorkgroupSize().xy / Draw2d::TILE_SIZE; - let tile_index = tile_coord.y * constants.tile_resolution.x + tile_coord.x; + let tile_coord = thread_id.xy / 32; - let lo = constants.tile_buffer[tile_index].index_min; - let hi = constants.tile_buffer[tile_index].index_max; + let stride = (constants.tile_resolution.x + 31) / 32; + let index = tile_coord.y * stride + tile_coord.x / 32; + let word = constants.tile_mask_buffer[index]; + let mask = 1u << (tile_coord.x & 31); // Display transform let stimulus = color_layer.Load(thread_id.xy).rgb; @@ -39,7 +38,7 @@ void main(uniform CompositeConstants constants, uint3 thread_id: SV_DispatchThre var composited = srgb_oetf(transformed); // UI composite - if (lo != hi) { + if ((word & mask) != 0) { let ui = ui_layer.Load(thread_id.xy).rgba; composited = ui.rgb + (composited * (1.0 - ui.a)); } diff --git a/title/shark-shaders/shaders/draw_2d.slang b/title/shark-shaders/shaders/draw_2d.slang index a9fe81c..539b964 100644 --- a/title/shark-shaders/shaders/draw_2d.slang +++ b/title/shark-shaders/shaders/draw_2d.slang @@ -6,11 +6,6 @@ import sdf; namespace Draw2d { public static const uint TILE_SIZE = 32; - -public struct Tile { - public uint index_min; - public uint index_max; -} } static const uint MAX_TILES = 256; @@ -63,15 +58,31 @@ struct CmdGlyph { }; struct ClearConstants { + uint2 tile_resolution; uint *finished_buffer; uint *coarse_buffer; + uint *tile_mask_buffer; + VkDispatchIndirectCommand *tile_dispatch_buffer; } [shader("compute")] -[numthreads(1, 1, 1)] -void clear(uniform ClearConstants constants) { +[numthreads(64, 1, 1)] +void clear(uniform ClearConstants constants, uint thread_index_in_group: SV_GroupIndex) { + let stride = (constants.tile_resolution.x + 31) / 32; + let size = constants.tile_resolution.y * stride; + + for (uint i = 0; i < size; i += 64) { + let index = i + thread_index_in_group; + if (index < size) { + constants.tile_mask_buffer[index] = 0; + } + } + constants.finished_buffer[0] = 0; constants.coarse_buffer[0] = 0; + constants.tile_dispatch_buffer.x = 1; + constants.tile_dispatch_buffer.y = 1; + constants.tile_dispatch_buffer.z = 0; } struct ScatterConstants { @@ -87,13 +98,13 @@ struct ScatterConstants { }; [vk::specialization_constant] -const int WGP_SIZE = 64; +const int WAVE_SIZE = 64; groupshared uint scatter_intersected_tiles[BITMAP_SIZE]; [shader("compute")] [require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic, spvGroupNonUniformVote)] -[numthreads(WGP_SIZE, 1, 1)] +[numthreads(WAVE_SIZE, 1, 1)] void scatter(uniform ScatterConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID) { let in_bounds = thread_id.x < constants.draw_buffer_len; @@ -265,7 +276,7 @@ struct VkDispatchIndirectCommand { struct SortConstants { uint coarse_buffer_len; uint _pad; - VkDispatchIndirectCommand *indirect_dispatch_buffer; + VkDispatchIndirectCommand *sort_dispatch_buffer; uint *coarse_buffer; }; @@ -284,9 +295,9 @@ void sort(uniform SortConstants constants) { let count = min(constants.coarse_buffer_len, constants.coarse_buffer[0]); constants.coarse_buffer[0] = count; - constants.indirect_dispatch_buffer.x = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP; - constants.indirect_dispatch_buffer.y = 1; - constants.indirect_dispatch_buffer.z = 1; + constants.sort_dispatch_buffer.x = (count + (RADIX_ITEMS_PER_WGP - 1)) / RADIX_ITEMS_PER_WGP; + constants.sort_dispatch_buffer.y = 1; + constants.sort_dispatch_buffer.z = 1; } struct ResolveConstants { @@ -298,31 +309,30 @@ struct ResolveConstants { Glyph *glyph_buffer; uint *coarse_buffer; uint *fine_buffer; - Draw2d::Tile *tile_buffer; + uint4 *tile_buffer; + uint *tile_mask_buffer; + VkDispatchIndirectCommand *tile_dispatch_buffer; }; [shader("compute")] -[require(spvGroupNonUniformBallot, spvGroupNonUniformVote)] -[numthreads(WGP_SIZE, 1, 1)] +[require(spvGroupNonUniformBallot, spvGroupNonUniformShuffle, spvGroupNonUniformVote)] +[numthreads(WAVE_SIZE, 1, 1)] void resolve(uniform ResolveConstants constants, uint3 thread_id: SV_DispatchThreadID) { let x = thread_id.y; let y = thread_id.z; - let tile_offset = constants.tile_stride * y + x; let search = ((y & 0xff) << 24) | ((x & 0xff) << 16); let count = constants.coarse_buffer[0]; if (count == 0) { - constants.tile_buffer[tile_offset].index_min = 0; - constants.tile_buffer[tile_offset].index_max = 0; return; } // Binary search for the upper bound of the tile. - uint base = 0; + var base = 0; { - uint n = count; + var max_iters = 32; + var n = count; uint mid; - uint max_iters = 32; while (max_iters-- > 0 && (mid = n / 2) > 0) { let value = constants.coarse_buffer[1 + base + mid] & 0xffff0000; base = value > search ? base : base + mid; @@ -333,22 +343,24 @@ void resolve(uniform ResolveConstants constants, uint3 thread_id: SV_DispatchThr let tile_min = uint2(x, y) * Draw2d::TILE_SIZE; let tile_max = tile_min + Draw2d::TILE_SIZE; - bool hit_opaque = false; - uint lo = base + 1; + var alpha_carry = 0.0; + var hit_opaque = false; + var lo = base + 1; let hi = base + 1; - for (; !hit_opaque && lo > 0; lo--) { - let i = lo; + for (var i = base + 1; !hit_opaque && i > 0; i--) { let packed = constants.coarse_buffer[i]; + // If we leave the tile we're done. if ((packed & 0xffff0000) != (search & 0xffff0000)) { break; } let draw_offset = packed & 0xffff; - let draw_index = draw_offset * WGP_SIZE + WaveGetLaneIndex(); + let draw_index = draw_offset * WAVE_SIZE + WaveGetLaneIndex(); - bool intersects = false; - bool opaque_tile = false; + var intersects_tile = false; + var covers_tile = false; + var covers_tile_alpha = 0.0; if (draw_index < constants.draw_buffer_len) { var cmd_min = float2(99999.9); @@ -362,7 +374,7 @@ void resolve(uniform ResolveConstants constants, uint3 thread_id: SV_DispatchThr // If the tile doesn't intersect the scissor region it doesn't need to do work here. if (any(scissor.offset_max < tile_min) || any(scissor.offset_min > tile_max)) { - intersects = false; + intersects_tile = false; } else { for (;;) { let scalar_type = WaveReadLaneFirst(cmd_type); @@ -374,18 +386,21 @@ void resolve(uniform ResolveConstants constants, uint3 thread_id: SV_DispatchThr cmd_min = cmd_rect.position; cmd_max = cmd_rect.position + cmd_rect.bound; - const bool background_opaque = (cmd_rect.background_color & 0xff000000) == 0xff000000; - if (background_opaque) { - let border_width = float((packed_type >> 16) & 0xff); - let border_opaque = (cmd_rect.border_color & 0xff000000) == 0xff000000; - let border_radii = unpackUnorm4x8ToFloat(cmd_rect.border_radii); - let max_border_radius = max(border_radii.x, max(border_radii.y, max(border_radii.z, border_radii.w))) * 255.0; - let shrink = ((2.0 - sqrt(2.0)) * max_border_radius) + (border_opaque ? 0.0 : border_width); - - let cmd_shrunk_min = max(scissor.offset_min, cmd_min + shrink); - let cmd_shrunk_max = min(scissor.offset_max, cmd_max - shrink); - opaque_tile = all(cmd_shrunk_max > cmd_shrunk_min) && all(tile_min > cmd_shrunk_min) && all(tile_max < cmd_shrunk_max); - } + let background_alpha = cmd_rect.background_color >> 24; + let border_alpha = cmd_rect.border_color >> 24; + let border_matches_background = background_alpha == border_alpha; + + let border_width = float((packed_type >> 16) & 0xff); + let border_radii = unpackUnorm4x8ToFloat(cmd_rect.border_radii); + let max_border_radius = max(border_radii.x, max(border_radii.y, max(border_radii.z, border_radii.w))) * 255.0; + let shrink = ((2.0 - sqrt(2.0)) * max_border_radius) + (border_matches_background ? 0.0 : border_width); + + let cmd_shrunk_min = max(scissor.offset_min, cmd_min + shrink); + let cmd_shrunk_max = min(scissor.offset_max, cmd_max - shrink); + + covers_tile_alpha = float(background_alpha) * (1.0 / 255.0); + covers_tile = all(cmd_shrunk_max > cmd_shrunk_min) && all(tile_min > cmd_shrunk_min) && all(tile_max < cmd_shrunk_max); + break; case CmdType::Glyph: let cmd_glyph = reinterpret(constants.draw_buffer[draw_index]); @@ -400,39 +415,86 @@ void resolve(uniform ResolveConstants constants, uint3 thread_id: SV_DispatchThr cmd_min = max(cmd_min, scissor.offset_min); cmd_max = min(cmd_max, scissor.offset_max); - intersects = !(any(tile_max < cmd_min) || any(tile_min > cmd_max)); + intersects_tile = !(any(tile_max < cmd_min) || any(tile_min > cmd_max)); + } + } + + var intersects_mask = WaveActiveBallot(intersects_tile).x; + + if (WaveActiveAnyTrue(covers_tile)) { + let transparent = covers_tile && covers_tile_alpha == 0.0; + let non_transparent = covers_tile && covers_tile_alpha != 0.0; + let opaque = covers_tile && covers_tile_alpha == 1.0; + + let transparent_tile_ballot = WaveActiveBallot(transparent).x; + intersects_mask &= ~transparent_tile_ballot; + + if (WaveActiveAnyTrue(opaque)) { + let opaque_tile_ballot = WaveActiveBallot(opaque).x; + let opaque_mask = (1 << firstbithigh(opaque_tile_ballot)) - 1; + hit_opaque = true; + intersects_mask &= ~opaque_mask; + } else if (WaveActiveAnyTrue(non_transparent)) { + // First we read in the alpha carry from the previous coarse + // word, if any. + var alpha = WaveReadLaneFirst(alpha_carry); + + // Then each wave calculates its alpha by blending its + // predecessors into itself. + for (var i = WAVE_SIZE; i-- > 0;) { + if (i >= WaveGetLaneIndex()) { + let cmd_alpha = WaveReadLaneAt(covers_tile_alpha, i); + alpha = cmd_alpha + alpha * (1.0 - cmd_alpha); + } + } + + // Check if any lanes went beyond the threshold. + let considered_opaque = alpha > 0.999; + if (WaveActiveAnyTrue(considered_opaque)) { + let opaque_tile_ballot = WaveActiveBallot(considered_opaque).x; + let opaque_mask = (1 << firstbithigh(opaque_tile_ballot)) - 1; + hit_opaque = true; + intersects_mask &= ~opaque_mask; + } + + // If they didn't go beyond the threshold, we need to update + // the alpha carry value. + alpha_carry = WaveReadLaneFirst(alpha); } } - var intersects_mask = WaveActiveBallot(intersects).x; - - if (WaveActiveAnyTrue(opaque_tile)) { - let opaque_tile_ballot = WaveActiveBallot(opaque_tile); - // TODO: Needs to check all live words of the ballot... - let first_opaque_tile = firstbithigh(opaque_tile_ballot).x; - let opaque_mask = ~((1 << first_opaque_tile) - 1); - intersects_mask &= opaque_mask; - constants.fine_buffer[i] = intersects_mask; - hit_opaque = true; - } else { - constants.fine_buffer[i] = intersects_mask; + if (intersects_mask != 0) { + constants.coarse_buffer[lo] = packed; + constants.fine_buffer[lo] = intersects_mask; + lo--; } } - constants.tile_buffer[tile_offset].index_min = lo + 1; - constants.tile_buffer[tile_offset].index_max = hi + 1; + if (WaveIsFirstLane()) { + if (lo != hi) { + uint dispatch; + InterlockedAdd(constants.tile_dispatch_buffer.z, 16, dispatch); + + let offset = dispatch >> 4; + constants.tile_buffer[offset] = uint4(x, y, lo + 1, hi + 1); + + let stride = (constants.tile_stride + 31) / 32; + let mask = 1u << (x & 31); + InterlockedOr(constants.tile_mask_buffer[y * stride + x / 32], 1 << (x & 31)); + } + } } struct RasterizeConstants { uint tile_stride; uint _pad; - Cmd *draw_buffer; - Scissor *scissor_buffer; - Glyph *glyph_buffer; - uint *coarse_buffer; - uint *fine_buffer; - Draw2d::Tile *tile_buffer; + Ptr draw_buffer; + Ptr scissor_buffer; + Ptr glyph_buffer; + Ptr coarse_buffer; + Ptr fine_buffer; + Ptr tile_buffer; }; /// x = (((index >> 2) & 0x0007) & 0xFFFE) | index & 0x0001 @@ -452,21 +514,20 @@ float3 plasma_quintic(float x) { [shader("compute")] [numthreads(8, 8, 1)] -void rasterize(uniform RasterizeConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID) { - let tile_coord = group_id.xy * WorkgroupSize().xy / Draw2d::TILE_SIZE; - let tile_index = tile_coord.y * constants.tile_stride + tile_coord.x; +void rasterize(uniform RasterizeConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 thread_id_in_group: SV_GroupThreadID) { + let tile_index = thread_id.z / 16; + let x = thread_id.z & 3; + let y = (thread_id.z >> 2) & 3; - let lo = constants.tile_buffer[tile_index].index_min; - let hi = constants.tile_buffer[tile_index].index_max; - - if (lo == hi) { - return; - } + let tile = constants.tile_buffer[tile_index]; + let position = tile.xy * Draw2d::TILE_SIZE + uint2(x, y) * WorkgroupSize().xy + thread_id_in_group.xy; + let lo = tile.z; + let hi = tile.w; #if DEBUG_SHOW_TILES == 1 - let color = plasma_quintic(float(hi - lo) / 50.0); - ui_layer_write.Store(thread_id.xy, float4(color, 1.0)); + let color = plasma_quintic(float(hi - lo) / 16.0); + ui_layer.Store(position, float4(color, 1.0)); #elif DEBUG_SHOW_TILES == 2 @@ -474,26 +535,28 @@ void rasterize(uniform RasterizeConstants constants, uint3 thread_id: SV_Dispatc for (uint i = lo; i < hi; i++) { count += countbits(constants.fine_buffer[i]); } - let color = plasma_quintic(float(count) / 600.0); - ui_layer_write.Store(thread_id.xy, float4(color, 1.0)); + let color = count == 1 ? float3(1.0, 0.0, 0.0) : plasma_quintic(float(count) / 300.0); + ui_layer.Store(position, float4(color, 1.0)); #else - let sample_center = thread_id.xy + float2(0.5); + let sample_center = float2(position) + float2(0.5); var accum = float4(0.0); - for (uint i = lo; i < hi; i++) { + var i = lo; + // lo != hi, or the group wouldn't have been dispatched. + do { + let base = (constants.coarse_buffer[i] & 0xffff) * 32; var bitmap = constants.fine_buffer[i]; - while (bitmap != 0) { + // Any bitmap in the fine buffer is non-zero. + do { let index = firstbitlow(bitmap); bitmap ^= bitmap & -bitmap; - let base_index = (constants.coarse_buffer[i] & 0xffff) * 32; - let cmd = constants.draw_buffer[base_index + index]; + let cmd = constants.draw_buffer[base + index]; let cmd_type = cmd.packed_type >> 24; let cmd_scissor = cmd.packed_type & 0xffff; - let scissor = constants.scissor_buffer[cmd_scissor]; var primitive_color = float4(0.0); @@ -545,12 +608,16 @@ void rasterize(uniform RasterizeConstants constants, uint3 thread_id: SV_Dispatc let glyph = constants.glyph_buffer[cmd_glyph.index]; let cmd_min = cmd_glyph.position + glyph.offset_min; let cmd_max = cmd_glyph.position + glyph.offset_max; - if (all(sample_center >= max(scissor.offset_min, cmd_min)) && all(sample_center <= min(scissor.offset_max, cmd_max))) { - let glyph_size = glyph.offset_max - glyph.offset_min; - let uv = lerp(glyph.atlas_min, glyph.atlas_max, (sample_center - cmd_min) / glyph_size); - let color = unpackUnorm4x8ToFloat(cmd_glyph.color).bgra; - let coverage = glyph_atlas.SampleLevel(samplers[Sampler::BilinearUnnormalized], uv, 0.0).r * color.a; - primitive_color = color * coverage; + if (all(sample_center >= cmd_min) && all(sample_center <= cmd_max)) { + let cmd_min_clipped = max(scissor.offset_min, cmd_min); + let cmd_max_clipped = min(scissor.offset_max, cmd_max); + if (all(sample_center >= cmd_min_clipped) && all(sample_center <= cmd_max_clipped)) { + let glyph_size = glyph.offset_max - glyph.offset_min; + let uv = lerp(glyph.atlas_min, glyph.atlas_max, (sample_center - cmd_min) / glyph_size); + let color = unpackUnorm4x8ToFloat(cmd_glyph.color).bgra; + let coverage = glyph_atlas.SampleLevel(samplers[Sampler::BilinearUnnormalized], uv, 0.0).r * color.a; + primitive_color = color * coverage; + } } break; } @@ -558,10 +625,12 @@ void rasterize(uniform RasterizeConstants constants, uint3 thread_id: SV_Dispatc // does it blend? accum.rgba = primitive_color.rgba + accum.rgba * (1.0 - primitive_color.a); - } - } + } while (bitmap != 0); + + i++; + } while (i != hi); - ui_layer.Store(thread_id.xy, accum); + ui_layer.Store(position, accum); #endif } diff --git a/title/shark-shaders/shaders/radix_sort.slang b/title/shark-shaders/shaders/radix_sort.slang index 120569e..d7d3f07 100644 --- a/title/shark-shaders/shaders/radix_sort.slang +++ b/title/shark-shaders/shaders/radix_sort.slang @@ -34,10 +34,10 @@ static const uint WAVE_COUNT = RADIX_GROUP_SIZE / WAVE_SIZE; struct UpsweepConstants { uint shift; uint _pad; - uint *finished_buffer; - uint *count_buffer; - uint *src_buffer; - uint *spine_buffer; + Ptr finished_buffer; + Ptr count_buffer; + Ptr src_buffer; + Ptr spine_buffer; }; groupshared uint histogram[RADIX_DIGITS]; @@ -49,7 +49,7 @@ groupshared uint sums[WAVE_COUNT]; [shader("compute")] [require(spvGroupNonUniformBallot, spvGroupNonUniformArithmetic)] [numthreads(RADIX_GROUP_SIZE, 1, 1)] -void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint3 thread_id_in_group: SV_GroupThreadID) { +void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint thread_index_in_group: SV_GroupIndex) { let shift = constants.shift; let count = constants.count_buffer[0]; @@ -58,46 +58,57 @@ void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThr // Clear local histogram. // Assumes RADIX_GROUP_SIZE == RADIX_DIGITS - histogram[thread_id_in_group.x] = 0; + histogram[thread_index_in_group] = 0; + // Ensure we've finished clearing the LDS histogram. GroupMemoryBarrierWithGroupSync(); if (is_last_group_in_dispatch) { for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) { - const uint src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_id_in_group.x; - if (src_index < count) { - const uint value = constants.src_buffer[src_index]; - const uint digit = (value >> shift) & RADIX_MASK; - InterlockedAdd(histogram[digit], 1); - } + let src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_index_in_group; + // This will count out-of-bounds values into the last histogram + // bucket, but since it's only happening for the last group in a + // dispatch, and because it's the last bucket, it won't affect the + // results. + let value = src_index < count ? constants.src_buffer[src_index] : 0xffffffff; + let digit = (value >> shift) & RADIX_MASK; + InterlockedAdd(histogram[digit], 1); } } else { for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) { - const uint src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_id_in_group.x; - const uint value = constants.src_buffer[src_index]; - const uint digit = (value >> shift) & RADIX_MASK; + let src_index = group_id.x * WorkgroupSize().x * RADIX_ITEMS_PER_THREAD + i * RADIX_DIGITS + thread_index_in_group; + let value = constants.src_buffer[src_index]; + let digit = (value >> shift) & RADIX_MASK; InterlockedAdd(histogram[digit], 1); } } + // Ensure we've finished updating the histogram in LDS. GroupMemoryBarrierWithGroupSync(); // Scatter to the spine, this is a striped layout so we can efficiently // calculate the prefix sum. Re-calculate how many workgroups we dispatched // to determine the stride we need to write at. - constants.spine_buffer[(thread_id_in_group.x * dispatch_group_count) + group_id.x] = histogram[thread_id_in_group.x]; + // + // Note the spine buffer size is rounded up so there's no need for bounds + // checking. + constants.spine_buffer[(thread_index_in_group * dispatch_group_count) + group_id.x] = histogram[thread_index_in_group]; + // Ensure the spine has been written before we increment the finished + // atomic counter. DeviceMemoryBarrierWithGroupSync(); // Store whether we're the last-executing group in LDS. This contrasts with // the 'static' `is_last_group_in_dispatch`, which represents the group with // the largest group_id in the dispatch. - if (thread_id_in_group.x == 0) { + if (thread_index_in_group == 0) { var old_value = 0; InterlockedAdd(*constants.finished_buffer, 1, old_value); is_last_group_dynamic = old_value == dispatch_group_count - 1; } + // Ensure all waves read the value of `is_last_group_dynamic` that we just + // filled. GroupMemoryBarrierWithGroupSync(); // Only the last-executing group needs to continue, it will mop up the spine @@ -106,20 +117,21 @@ void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThr return; } - // Reset for the next pass. - InterlockedExchange(*constants.finished_buffer, 0); + // Reset for the next pass, this can be a simple store as there's a barrier + // between passes, and we are the only group executing at this point. + *constants.finished_buffer = 0; - let wave_id = thread_id_in_group.x / WaveGetLaneCount(); + let wave_index = thread_index_in_group / WaveGetLaneCount(); carry = 0; for (uint i = 0; i < dispatch_group_count; i++) { // Load values and calculate partial sums. - let value = constants.spine_buffer[i * RADIX_DIGITS + thread_id_in_group.x]; + let value = constants.spine_buffer[i * RADIX_DIGITS + thread_index_in_group]; let sum = WaveActiveSum(value); let scan = WavePrefixSum(value); if (WaveIsFirstLane()) { - sums[wave_id] = sum; + sums[wave_index] = sum; } // Even though we read and write from the spine, this can be a group @@ -135,19 +147,20 @@ void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThr let carry_in = carry; // Scan partials. - if (thread_id_in_group.x < WAVE_COUNT) { - sums[thread_id_in_group.x] = WavePrefixSum(sums[thread_id_in_group.x]); + if (thread_index_in_group < WAVE_COUNT) { + sums[thread_index_in_group] = WavePrefixSum(sums[thread_index_in_group]); } + // Make sure we've finished turning the partial sums into a prefix sum. GroupMemoryBarrierWithGroupSync(); // Write out the final prefix sum, combining the carry-in, wave sums, // and local scan. - constants.spine_buffer[i * RADIX_DIGITS + thread_id_in_group.x] = carry_in + sums[wave_id] + scan; + constants.spine_buffer[i * RADIX_DIGITS + thread_index_in_group] = carry_in + sums[wave_index] + scan; // `sums` in LDS now contains partials, so we need to also add the wave // sum to get an inclusive prefix sum for the next iteration. - if (wave_id == WAVE_COUNT - 1 && WaveIsFirstLane()) { + if (wave_index == WAVE_COUNT - 1 && WaveIsFirstLane()) { InterlockedAdd(carry, sums[WAVE_COUNT - 1] + sum); } } @@ -156,10 +169,10 @@ void upsweep(uniform UpsweepConstants constants, uint3 thread_id: SV_DispatchThr struct DownsweepConstants { uint shift; uint _pad; - uint *count_buffer; - uint *spine_buffer; - uint *src_buffer; - uint *dst_buffer; + Ptr count_buffer; + Ptr spine_buffer; + Ptr src_buffer; + Ptr dst_buffer; } groupshared uint spine[RADIX_DIGITS]; @@ -168,77 +181,75 @@ groupshared uint match_masks[WAVE_COUNT][RADIX_DIGITS]; [shader("compute")] [require(spvGroupNonUniformBallot)] [numthreads(RADIX_GROUP_SIZE, 1, 1)] -void downsweep(uniform DownsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint3 thread_id_in_group: SV_GroupThreadID) { +void downsweep(uniform DownsweepConstants constants, uint3 thread_id: SV_DispatchThreadID, uint3 group_id: SV_GroupID, uint thread_index_in_group: SV_GroupIndex) { let shift = constants.shift; let count = constants.count_buffer[0]; let dispatch_group_count = radix_sort::CalculateGroupCountForItemCount(count); let is_last_group_in_dispatch = group_id.x == dispatch_group_count - 1; - let wave_id = thread_id_in_group.x / WaveGetLaneCount(); + let wave_index = thread_index_in_group / WaveGetLaneCount(); // Gather from spine buffer into LDS. - spine[thread_id_in_group.x] = constants.spine_buffer[thread_id_in_group.x * dispatch_group_count + group_id.x]; + spine[thread_index_in_group] = constants.spine_buffer[thread_index_in_group * dispatch_group_count + group_id.x]; if (is_last_group_in_dispatch) { for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) { // Clear shared memory and load values from src buffer. for (uint j = 0; j < WAVE_COUNT; j++) { - match_masks[j][thread_id_in_group.x] = 0; + match_masks[j][thread_index_in_group] = 0; } GroupMemoryBarrierWithGroupSync(); - let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_id_in_group.x; + let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_index_in_group; let value = index < count ? constants.src_buffer[index] : 0xffffffff; let digit = (value >> shift) & RADIX_MASK; - InterlockedOr(match_masks[wave_id][digit], 1 << WaveGetLaneIndex()); + InterlockedOr(match_masks[wave_index][digit], 1 << WaveGetLaneIndex()); GroupMemoryBarrierWithGroupSync(); - uint peer_scan = 0; + var peer_scan = 0; for (uint j = 0; j < WAVE_COUNT; j++) { - if (j < wave_id) { - peer_scan += countbits(match_masks[j][digit]); - } + peer_scan += j < wave_index ? countbits(match_masks[j][digit]) : 0; } - peer_scan += countbits(match_masks[wave_id][digit] & WaveLtMask().x); + peer_scan += countbits(match_masks[wave_index][digit] & WaveLtMask().x); if (index < count) { constants.dst_buffer[spine[digit] + peer_scan] = value; } - GroupMemoryBarrierWithGroupSync(); + if (i != RADIX_ITEMS_PER_THREAD - 1) { + GroupMemoryBarrierWithGroupSync(); - // Increment the spine with the counts for the workgroup we just - // wrote out. - for (uint i = 0; i < WAVE_COUNT; i++) { - InterlockedAdd(spine[thread_id_in_group.x], countbits(match_masks[i][thread_id_in_group.x])); + // Increment the spine with the counts for the workgroup we just + // wrote out. + for (uint i = 0; i < WAVE_COUNT; i++) { + InterlockedAdd(spine[thread_index_in_group], countbits(match_masks[i][thread_index_in_group])); + } } } } else { for (uint i = 0; i < RADIX_ITEMS_PER_THREAD; i++) { // Clear shared memory and load values from src buffer. for (uint j = 0; j < WAVE_COUNT; j++) { - match_masks[j][thread_id_in_group.x] = 0; + match_masks[j][thread_index_in_group] = 0; } GroupMemoryBarrierWithGroupSync(); - let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_id_in_group.x; + let index = group_id.x * RADIX_ITEMS_PER_GROUP + i * RADIX_DIGITS + thread_index_in_group; let value = constants.src_buffer[index]; let digit = (value >> shift) & RADIX_MASK; - InterlockedOr(match_masks[wave_id][digit], 1 << WaveGetLaneIndex()); + InterlockedOr(match_masks[wave_index][digit], 1 << WaveGetLaneIndex()); GroupMemoryBarrierWithGroupSync(); - uint peer_scan = 0; + var peer_scan = 0; for (uint j = 0; j < WAVE_COUNT; j++) { - if (j < wave_id) { - peer_scan += countbits(match_masks[j][digit]); - } + peer_scan += j < wave_index ? countbits(match_masks[j][digit]) : 0; } - peer_scan += countbits(match_masks[wave_id][digit] & WaveLtMask().x); + peer_scan += countbits(match_masks[wave_index][digit] & WaveLtMask().x); constants.dst_buffer[spine[digit] + peer_scan] = value; @@ -248,7 +259,7 @@ void downsweep(uniform DownsweepConstants constants, uint3 thread_id: SV_Dispatc // Increment the spine with the counts for the workgroup we just // wrote out. for (uint i = 0; i < WAVE_COUNT; i++) { - InterlockedAdd(spine[thread_id_in_group.x], countbits(match_masks[i][thread_id_in_group.x])); + InterlockedAdd(spine[thread_index_in_group], countbits(match_masks[i][thread_index_in_group])); } } } diff --git a/title/shark-shaders/src/pipelines.rs b/title/shark-shaders/src/pipelines.rs index 6c3fef1..0ae7369 100644 --- a/title/shark-shaders/src/pipelines.rs +++ b/title/shark-shaders/src/pipelines.rs @@ -176,8 +176,12 @@ pub struct BasicConstants<'a> { #[repr(C)] pub struct Draw2dClearConstants<'a> { + pub tile_resolution_x: u32, + pub tile_resolution_y: u32, pub finished_buffer_address: BufferAddress<'a>, pub coarse_buffer_address: BufferAddress<'a>, + pub tile_mask_buffer_address: BufferAddress<'a>, + pub tile_dispatch_buffer_address: BufferAddress<'a>, } #[repr(C)] @@ -198,7 +202,7 @@ pub struct Draw2dScatterConstants<'a> { pub struct Draw2dSortConstants<'a> { pub coarse_buffer_len: u32, pub _pad: u32, - pub indirect_dispatch_buffer_address: BufferAddress<'a>, + pub sort_dispatch_buffer_address: BufferAddress<'a>, pub coarse_buffer_address: BufferAddress<'a>, } @@ -213,6 +217,8 @@ pub struct Draw2dResolveConstants<'a> { pub coarse_buffer_address: BufferAddress<'a>, pub fine_buffer_address: BufferAddress<'a>, pub tile_buffer_address: BufferAddress<'a>, + pub tile_mask_buffer_address: BufferAddress<'a>, + pub tile_dispatch_buffer_address: BufferAddress<'a>, } #[repr(C)] @@ -232,7 +238,7 @@ pub struct Draw2dRasterizeConstants<'a> { pub struct CompositeConstants<'a> { pub tile_resolution_x: u32, pub tile_resolution_y: u32, - pub tile_buffer_address: BufferAddress<'a>, + pub tile_mask_buffer_address: BufferAddress<'a>, } #[repr(C)] diff --git a/title/shark/src/draw.rs b/title/shark/src/draw.rs index 37dac90..05db1ba 100644 --- a/title/shark/src/draw.rs +++ b/title/shark/src/draw.rs @@ -750,16 +750,14 @@ impl<'gpu> DrawState<'gpu> { ], ); - let tile_buffer = gpu.request_transient_buffer( + let tile_mask_buffer = gpu.request_transient_buffer( frame, thread_token, BufferUsageFlags::STORAGE, - self.tile_resolution_x as usize - * self.tile_resolution_y as usize - * std::mem::size_of::() - * 2, + (self.tile_resolution_x.div_ceil(32) * self.tile_resolution_y) as usize + * std::mem::size_of::(), ); - let tile_buffer_address = gpu.get_buffer_address(tile_buffer.to_arg()); + let tile_mask_buffer_address = gpu.get_buffer_address(tile_mask_buffer.to_arg()); // Render UI { @@ -769,6 +767,17 @@ impl<'gpu> DrawState<'gpu> { microshades::PURPLE_RGBA_F32[3], ); + let tile_buffer = gpu.request_transient_buffer( + frame, + thread_token, + BufferUsageFlags::STORAGE, + self.tile_resolution_x as usize + * self.tile_resolution_y as usize + * std::mem::size_of::() + * 4, + ); + let tile_buffer_address = gpu.get_buffer_address(tile_buffer.to_arg()); + let draw_buffer = gpu.request_transient_buffer_with_data( frame, thread_token, @@ -800,7 +809,14 @@ impl<'gpu> DrawState<'gpu> { COARSE_BUFFER_LEN * std::mem::size_of::(), ); - let indirect_dispatch_buffer = gpu.request_transient_buffer( + let sort_dispatch_buffer = gpu.request_transient_buffer( + frame, + thread_token, + BufferUsageFlags::INDIRECT, + 3 * std::mem::size_of::(), + ); + + let tile_dispatch_buffer = gpu.request_transient_buffer( frame, thread_token, BufferUsageFlags::INDIRECT, @@ -832,8 +848,10 @@ impl<'gpu> DrawState<'gpu> { let scissor_buffer_address = gpu.get_buffer_address(scissor_buffer.to_arg()); let glyph_buffer_address = gpu.get_buffer_address(glyph_buffer.to_arg()); let coarse_buffer_address = gpu.get_buffer_address(coarse_buffer.to_arg()); - let indirect_dispatch_buffer_address = - gpu.get_buffer_address(indirect_dispatch_buffer.to_arg()); + let sort_dispatch_buffer_address = + gpu.get_buffer_address(sort_dispatch_buffer.to_arg()); + let tile_dispatch_buffer_address = + gpu.get_buffer_address(tile_dispatch_buffer.to_arg()); let finished_buffer_address = gpu.get_buffer_address(finished_buffer.to_arg()); let tmp_buffer_address = gpu.get_buffer_address(tmp_buffer.to_arg()); let spine_buffer_address = gpu.get_buffer_address(spine_buffer.to_arg()); @@ -845,8 +863,12 @@ impl<'gpu> DrawState<'gpu> { ShaderStageFlags::COMPUTE, 0, &Draw2dClearConstants { + tile_resolution_x: self.tile_resolution_x, + tile_resolution_y: self.tile_resolution_y, finished_buffer_address, coarse_buffer_address, + tile_mask_buffer_address, + tile_dispatch_buffer_address, }, ); gpu.cmd_dispatch(cmd_encoder, 1, 1, 1); @@ -905,7 +927,7 @@ impl<'gpu> DrawState<'gpu> { // -1 due to the count taking up a single slot in the buffer. coarse_buffer_len: COARSE_BUFFER_LEN as u32 - 1, _pad: 0, - indirect_dispatch_buffer_address, + sort_dispatch_buffer_address, coarse_buffer_address, }, ); @@ -951,7 +973,7 @@ impl<'gpu> DrawState<'gpu> { spine_buffer_address, }, ); - gpu.cmd_dispatch_indirect(cmd_encoder, indirect_dispatch_buffer.to_arg(), 0); + gpu.cmd_dispatch_indirect(cmd_encoder, sort_dispatch_buffer.to_arg(), 0); gpu.cmd_barrier( cmd_encoder, @@ -981,7 +1003,7 @@ impl<'gpu> DrawState<'gpu> { spine_buffer_address, }, ); - gpu.cmd_dispatch_indirect(cmd_encoder, indirect_dispatch_buffer.to_arg(), 0); + gpu.cmd_dispatch_indirect(cmd_encoder, sort_dispatch_buffer.to_arg(), 0); gpu.cmd_barrier( cmd_encoder, @@ -1012,6 +1034,8 @@ impl<'gpu> DrawState<'gpu> { coarse_buffer_address, fine_buffer_address: tmp_buffer_address, tile_buffer_address, + tile_mask_buffer_address, + tile_dispatch_buffer_address, }, ); gpu.cmd_dispatch( @@ -1025,7 +1049,7 @@ impl<'gpu> DrawState<'gpu> { cmd_encoder, Some(&GlobalBarrier { prev_access: &[Access::ComputeWrite], - next_access: &[Access::ComputeOtherRead], + next_access: &[Access::ComputeOtherRead, Access::IndirectBuffer], }), &[], ); @@ -1047,12 +1071,7 @@ impl<'gpu> DrawState<'gpu> { tile_buffer_address, }, ); - gpu.cmd_dispatch( - cmd_encoder, - self.width.div_ceil(8), - self.height.div_ceil(8), - 1, - ); + gpu.cmd_dispatch_indirect(cmd_encoder, tile_dispatch_buffer.to_arg(), 0); gpu.cmd_end_debug_marker(cmd_encoder); } @@ -1097,7 +1116,7 @@ impl<'gpu> DrawState<'gpu> { &CompositeConstants { tile_resolution_x: self.tile_resolution_x, tile_resolution_y: self.tile_resolution_y, - tile_buffer_address, + tile_mask_buffer_address, }, ); gpu.cmd_dispatch( diff --git a/title/shark/src/main.rs b/title/shark/src/main.rs index fe4652f..281f992 100644 --- a/title/shark/src/main.rs +++ b/title/shark/src/main.rs @@ -384,9 +384,11 @@ pub fn main() { let y = height / 2.0 + w * c; ui_state.push_scissor(vec2(x - w, y - h), vec2(x + w, y + h), true); - ui_state.rect( - 0.0, 0.0, width, height, 0.0, [0.0; 4], 0xffffffff, 0xffffffff, - ); + for _ in 0..200 { + ui_state.rect( + 0.0, 0.0, width, height, 0.0, [0.0; 4], 0x10101010, 0x10101010, + ); + } ui_state.pop_scissor(); ui_state.push_scissor(vec2(x - w, y - h), vec2(x + w, y + h), true); @@ -423,10 +425,27 @@ pub fn main() { y - 200.0, 400.0, 400.0, - 100.0, + 50.0, [100.0, 50.0, 25.0, 0.0], 0x33333333, - microshades::BLUE_RGBA8[4].rotate_right(8), + (microshades::BLUE_RGBA8[4] >> 8) | 0xff00_0000, + ); + } + + for i in 0..50 { + let (s, c) = sin_cos_pi_f32(game_state.time * 0.1 + i as f32 * 0.04); + + let x = width / 2.0 + w * 0.5 * s; + let y = height / 2.0 + w * 0.5 * c; + ui_state.rect( + x - 200.0, + y - 200.0, + 400.0, + 400.0, + 10.0, + [10.0, 10.0, 10.0, 10.0], + 0xffff0000, + 0x0, ); } -- 2.51.1