From fd08276bb672e02512a9c459f6d032b199950a24 Mon Sep 17 00:00:00 2001 From: Josh Simmons Date: Thu, 18 Jul 2024 23:19:39 +0200 Subject: [PATCH] shark: Improve shader performance Scalarize all the things! Remove extra passes over the input primitives buffer. Use AABB of entire subgroup to determine which tiles to write. --- .../shaders/display_transform.comp.glsl | 6 +- title/shark-shaders/shaders/primitive_2d.h | 9 +- .../shaders/primitive_2d_bin.comp.glsl | 98 +++++++++---------- .../shaders/primitive_2d_bin_clear.comp.glsl | 12 +-- .../shaders/primitive_2d_rasterize.comp.glsl | 29 +++--- title/shark/src/main.rs | 17 ++-- 6 files changed, 87 insertions(+), 84 deletions(-) diff --git a/title/shark-shaders/shaders/display_transform.comp.glsl b/title/shark-shaders/shaders/display_transform.comp.glsl index be5cb3e..be48b91 100644 --- a/title/shark-shaders/shaders/display_transform.comp.glsl +++ b/title/shark-shaders/shaders/display_transform.comp.glsl @@ -37,9 +37,9 @@ void main() { TilesRead tiles_read = TilesRead(uniforms.tiles); - const uint first = tiles_read.values[tile_base + TILE_BITMAP_RANGE_OFFSET + 0]; - const uint last = tiles_read.values[tile_base + TILE_BITMAP_RANGE_OFFSET + 1]; - if (first <= last) { + const uint lo = tiles_read.values[tile_base + TILE_BITMAP_RANGE_LO_OFFSET]; + const uint hi = tiles_read.values[tile_base + TILE_BITMAP_RANGE_HI_OFFSET]; + if (lo <= hi) { const vec4 ui = imageLoad(ui_layer_read, ivec2(gl_GlobalInvocationID.xy)).rgba; composited = ui.rgb + (composited * (1.0 - ui.a)); } diff --git a/title/shark-shaders/shaders/primitive_2d.h b/title/shark-shaders/shaders/primitive_2d.h index 37ebb1f..010735d 100644 --- a/title/shark-shaders/shaders/primitive_2d.h +++ b/title/shark-shaders/shaders/primitive_2d.h @@ -4,11 +4,12 @@ const uint MAX_PRIMS = 1 << 18; const uint TILE_BITMAP_L1_WORDS = (MAX_PRIMS / 32 / 32); const uint TILE_BITMAP_L0_WORDS = (MAX_PRIMS / 32); const uint TILE_STRIDE = (TILE_BITMAP_L0_WORDS + TILE_BITMAP_L1_WORDS + 2); -const uint TILE_BITMAP_RANGE_OFFSET = 0; -const uint TILE_BITMAP_L1_OFFSET = 2; -const uint TILE_BITMAP_L0_OFFSET = TILE_BITMAP_L1_OFFSET + TILE_BITMAP_L1_WORDS; +const uint TILE_BITMAP_RANGE_LO_OFFSET = 0; +const uint TILE_BITMAP_RANGE_HI_OFFSET = (TILE_BITMAP_RANGE_LO_OFFSET + 1); +const uint TILE_BITMAP_L1_OFFSET = (TILE_BITMAP_RANGE_HI_OFFSET + 1); +const uint TILE_BITMAP_L0_OFFSET = (TILE_BITMAP_L1_OFFSET + TILE_BITMAP_L1_WORDS); -bool test_glyph(uint index, uvec2 tile_min, uvec2 tile_max) { +bool test_glyph(uint index, vec2 tile_min, vec2 tile_max) { const GlyphInstance gi = uniforms.glyph_instances.values[index]; const Glyph gl = uniforms.glyphs.values[gi.index]; const vec2 glyph_min = gi.position + gl.offset_min; diff --git a/title/shark-shaders/shaders/primitive_2d_bin.comp.glsl b/title/shark-shaders/shaders/primitive_2d_bin.comp.glsl index 298f040..0010d71 100644 --- a/title/shark-shaders/shaders/primitive_2d_bin.comp.glsl +++ b/title/shark-shaders/shaders/primitive_2d_bin.comp.glsl @@ -7,80 +7,80 @@ #extension GL_EXT_scalar_block_layout : require #extension GL_EXT_control_flow_attributes : require -#extension GL_KHR_shader_subgroup_vote : require +#extension GL_KHR_shader_subgroup_arithmetic : require #extension GL_KHR_shader_subgroup_ballot : require +#extension GL_KHR_shader_subgroup_vote : require #include "compute_bindings.h" #include "primitive_2d.h" const uint SUBGROUP_SIZE = 64; -const uint NUM_PRIMS_WG = (SUBGROUP_SIZE * 32); +const uint NUM_SUBGROUPS = 16; +const uint NUM_PRIMITIVES_WG = (SUBGROUP_SIZE * NUM_SUBGROUPS); // TODO: Spec constant support for different subgroup sizes. layout (local_size_x = SUBGROUP_SIZE, local_size_y = 1, local_size_z = 1) in; -shared uint bitmap_0[SUBGROUP_SIZE]; - void main() { - const uvec2 bin_coord = gl_GlobalInvocationID.yz; - const uvec2 bin_min = bin_coord * TILE_SIZE * 8; - const uvec2 bin_max = min(bin_min + TILE_SIZE * 8, uniforms.screen_resolution); - - for (uint i = 0; i < NUM_PRIMS_WG; i += gl_SubgroupSize.x) { - const uint prim_index = gl_WorkGroupID.x * NUM_PRIMS_WG + i + gl_SubgroupInvocationID; - bool intersects = false; - if (prim_index < uniforms.num_primitives) { - const GlyphInstance gi = uniforms.glyph_instances.values[prim_index]; + uint word_index = 0; + + for (uint i = 0; i < NUM_PRIMITIVES_WG; i += gl_SubgroupSize.x) { + const uint primitive_index = gl_WorkGroupID.x * NUM_PRIMITIVES_WG + i + gl_SubgroupInvocationID; + + vec2 primitive_min = vec2(99999.9); + vec2 primitive_max = vec2(-99999.9); + + if (primitive_index < uniforms.num_primitives) { + const GlyphInstance gi = uniforms.glyph_instances.values[primitive_index]; const Glyph gl = uniforms.glyphs.values[gi.index]; - const vec2 glyph_min = gi.position + gl.offset_min; - const vec2 glyph_max = gi.position + gl.offset_max; - intersects = !(any(lessThan(bin_max, glyph_min)) || any(greaterThan(bin_min, glyph_max))); + primitive_min = gi.position + gl.offset_min; + primitive_max = gi.position + gl.offset_max; } - const uvec4 ballot = subgroupBallot(intersects); - bitmap_0[i / 32 + 0] = ballot.x; - bitmap_0[i / 32 + 1] = ballot.y; - } - memoryBarrierShared(); + const vec2 primitives_min = subgroupMin(primitive_min); + const vec2 primitives_max = subgroupMax(primitive_max); - const uint x = gl_SubgroupInvocationID.x & 7; - const uint y = gl_SubgroupInvocationID.x >> 3; - const uvec2 tile_coord = gl_GlobalInvocationID.yz * 8 + uvec2(x, y); - const uvec2 tile_min = tile_coord * TILE_SIZE; - const uvec2 tile_max = min(tile_min + TILE_SIZE, uniforms.screen_resolution); + if (any(greaterThan(primitives_min, uniforms.screen_resolution)) || any(lessThan(primitives_max, vec2(0.0)))) { + word_index += 2; + continue; + } - if (all(lessThan(tile_min, tile_max))) { - const uint tile_index = tile_coord.y * uniforms.tile_stride + tile_coord.x; + ivec2 bin_start = ivec2(floor(max(min(primitives_min, uniforms.screen_resolution), 0.0) / TILE_SIZE)); + ivec2 bin_end = ivec2(floor((max(min(primitives_max, uniforms.screen_resolution), 0.0) + (TILE_SIZE - 1)) / TILE_SIZE)); - for (uint i = 0; i < 2; i++) { - uint out_1 = 0; + for (int y = bin_start.y; y < bin_end.y; y++) { + for (int x = bin_start.x; x < bin_end.x; x++) { + const uvec2 bin_coord = uvec2(x, y); + const uint bin_index = bin_coord.y * uniforms.tile_stride + bin_coord.x; + const vec2 bin_min = bin_coord * TILE_SIZE; + const vec2 bin_max = min(bin_min + TILE_SIZE, uniforms.screen_resolution); - for (uint j = 0; j < 32; j++) { - uint out_0 = 0; - uint index_0 = i * 32 + j; - uint word_0 = bitmap_0[index_0]; - while (word_0 != 0) { - const uint bit_0 = findLSB(word_0); - word_0 ^= word_0 & -word_0; + const bool intersects = !(any(lessThan(bin_max, primitive_min)) || any(greaterThan(bin_min, primitive_max))); + const uvec4 ballot = subgroupBallot(intersects); - const uint prim_index = gl_WorkGroupID.x * NUM_PRIMS_WG + index_0 * 32 + bit_0; - if (test_glyph(prim_index, tile_min, tile_max)) { - out_0 |= 1 << bit_0; - } + if (ballot.x == 0 && ballot.y == 0) { + continue; } - if (out_0 != 0) { - out_1 |= 1 << j; - uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_L0_OFFSET + gl_WorkGroupID.x * 64 + index_0] = out_0; + if (ballot.x != 0) { + uniforms.tiles.values[bin_index * TILE_STRIDE + TILE_BITMAP_L0_OFFSET + gl_WorkGroupID.x * 32 + word_index + 0] = ballot.x; + } + + if (ballot.y != 0) { + uniforms.tiles.values[bin_index * TILE_STRIDE + TILE_BITMAP_L0_OFFSET + gl_WorkGroupID.x * 32 + word_index + 1] = ballot.y; } - } - uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_L1_OFFSET + gl_WorkGroupID.x * 2 + i] = out_1; + if (subgroupElect()) { + uniforms.tiles.values[bin_index * TILE_STRIDE + TILE_BITMAP_L1_OFFSET + gl_WorkGroupID.x] |= + (uint(ballot.x != 0) << (word_index + 0)) | + (uint(ballot.y != 0) << (word_index + 1)); - if (out_1 != 0) { - atomicMin(uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_RANGE_OFFSET + 0], gl_WorkGroupID.x * 2 + i); - atomicMax(uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_RANGE_OFFSET + 1], gl_WorkGroupID.x * 2 + i); + atomicMin(uniforms.tiles.values[bin_index * TILE_STRIDE + TILE_BITMAP_RANGE_LO_OFFSET], gl_WorkGroupID.x); + atomicMax(uniforms.tiles.values[bin_index * TILE_STRIDE + TILE_BITMAP_RANGE_HI_OFFSET], gl_WorkGroupID.x); + } } } + + word_index += 2; } } diff --git a/title/shark-shaders/shaders/primitive_2d_bin_clear.comp.glsl b/title/shark-shaders/shaders/primitive_2d_bin_clear.comp.glsl index 5707be1..8cfa4fb 100644 --- a/title/shark-shaders/shaders/primitive_2d_bin_clear.comp.glsl +++ b/title/shark-shaders/shaders/primitive_2d_bin_clear.comp.glsl @@ -17,12 +17,12 @@ layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in; void main() { - if (gl_GlobalInvocationID.x >= uniforms.tile_resolution.x * uniforms.tile_resolution.y) { - return; - } + const uint tile_index = gl_GlobalInvocationID.z * uniforms.tile_stride + gl_GlobalInvocationID.y; - const uint index = gl_GlobalInvocationID.x * TILE_STRIDE + TILE_BITMAP_RANGE_OFFSET; + uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_RANGE_LO_OFFSET] = 0xffffffff; + uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_RANGE_HI_OFFSET] = 0; - uniforms.tiles.values[index + 0] = 0xffffffff; - uniforms.tiles.values[index + 1] = 0; + if (gl_GlobalInvocationID.x < TILE_BITMAP_L1_WORDS) { + uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_L1_OFFSET + gl_GlobalInvocationID.x] = 0; + } } diff --git a/title/shark-shaders/shaders/primitive_2d_rasterize.comp.glsl b/title/shark-shaders/shaders/primitive_2d_rasterize.comp.glsl index 7043983..67f878c 100644 --- a/title/shark-shaders/shaders/primitive_2d_rasterize.comp.glsl +++ b/title/shark-shaders/shaders/primitive_2d_rasterize.comp.glsl @@ -34,25 +34,24 @@ vec3 plasma_quintic(float x) #endif void main() { - const uvec2 tile_coord = gl_WorkGroupID.xy / 4; + const uvec2 tile_coord = gl_WorkGroupID.xy / (TILE_SIZE / gl_WorkGroupSize.xy); const uint tile_index = tile_coord.y * uniforms.tile_stride + tile_coord.x; const uint tile_base = tile_index * TILE_STRIDE; TilesRead tiles_read = TilesRead(uniforms.tiles); - const uint first = tiles_read.values[tile_base + TILE_BITMAP_RANGE_OFFSET + 0]; - const uint last = tiles_read.values[tile_base + TILE_BITMAP_RANGE_OFFSET + 1]; + const uint lo = tiles_read.values[tile_base + TILE_BITMAP_RANGE_LO_OFFSET]; + const uint hi = tiles_read.values[tile_base + TILE_BITMAP_RANGE_HI_OFFSET]; - [[branch]] - if (last < first) { + if (hi < lo) { return; } #if DEBUG_SHOW_TILES == 1 - int count = 0; + uint count = 0; // For each tile, iterate over all words in the L1 bitmap. - for (uint index_l1 = first; index_l1 <= last; index_l1++) { + for (uint index_l1 = lo; index_l1 <= hi; index_l1++) { // For each word, iterate all set bits. uint bitmap_l1 = tiles_read.values[tile_base + TILE_BITMAP_L1_OFFSET + index_l1]; @@ -72,12 +71,20 @@ void main() { const vec3 color = plasma_quintic(float(count) / 100.0); imageStore(ui_layer_write, ivec2(gl_GlobalInvocationID.xy), vec4(color, 1.0)); +#elif DEBUG_SHOW_TILES == 2 + + uint count = hi - lo; + const vec3 color = plasma_quintic(float(count) / 100.0); + imageStore(ui_layer_write, ivec2(gl_GlobalInvocationID.xy), vec4(color, 1.0)); + #else + const vec2 sample_center = gl_GlobalInvocationID.xy + vec2(0.5); + vec4 accum = vec4(0.0); - // For each tile, iterate over all words in the L1 bitmap. - for (uint index_l1 = first; index_l1 <= last; index_l1++) { + // For each tile, iterate over all words in the L1 bitmap. + for (uint index_l1 = lo; index_l1 <= hi; index_l1++) { // For each word, iterate all set bits. uint bitmap_l1 = tiles_read.values[tile_base + TILE_BITMAP_L1_OFFSET + index_l1]; @@ -100,12 +107,12 @@ void main() { const Glyph gl = uniforms.glyphs.values[gi.index]; const vec2 glyph_min = gi.position + gl.offset_min; const vec2 glyph_max = gi.position + gl.offset_max; - const vec2 sample_center = gl_GlobalInvocationID.xy + vec2(0.5); + [[branch]] if (all(greaterThanEqual(sample_center, glyph_min)) && all(lessThanEqual(sample_center, glyph_max))) { const vec2 glyph_size = gl.offset_max - gl.offset_min; - const vec4 color = unpackUnorm4x8(gi.color).bgra; const vec2 uv = mix(gl.atlas_min, gl.atlas_max, (sample_center - glyph_min) / glyph_size) / uniforms.atlas_resolution; + const vec4 color = unpackUnorm4x8(gi.color).bgra; const float coverage = textureLod(sampler2D(glyph_atlas, bilinear_sampler), uv, 0.0).r * color.a; accum.rgb = (coverage * color.rgb) + accum.rgb * (1.0 - coverage); accum.a = coverage + accum.a * (1.0 - coverage); diff --git a/title/shark/src/main.rs b/title/shark/src/main.rs index 3f83795..1671a9e 100644 --- a/title/shark/src/main.rs +++ b/title/shark/src/main.rs @@ -1483,9 +1483,9 @@ impl<'gpu> DrawState<'gpu> { gpu.cmd_dispatch( cmd_encoder, - (self.tile_resolution_y * self.tile_resolution_x + 63) / 64, - 1, - 1, + (num_primitives_1024 + 63) / 64, + self.tile_resolution_x, + self.tile_resolution_y, ); gpu.cmd_barrier( @@ -1499,12 +1499,7 @@ impl<'gpu> DrawState<'gpu> { gpu.cmd_set_pipeline(cmd_encoder, self.bin_pipeline); - gpu.cmd_dispatch( - cmd_encoder, - (num_primitives + 2047) / 2048, - (self.tile_resolution_x + 3) / 4, - (self.tile_resolution_y + 3) / 4, - ); + gpu.cmd_dispatch(cmd_encoder, (num_primitives + 1023) / 1024, 1, 1); gpu.cmd_barrier( cmd_encoder, @@ -1738,8 +1733,8 @@ pub fn main() { for i in 0..80 { let i = i as f32; ui_state.text_fmt( - base_x * 100.0 * scale + 5.0, - base_y * 100.0 * scale + i * 15.0 * scale, + base_x * 100.0 * scale - 5.0, + base_y * 150.0 * scale + i * 15.0 * scale, FontFamily::RobotoRegular, 20.0, format_args!("tick: {:?}", tick_duration), -- 2.49.0