]> git.nega.tv - josh/narcissus/commitdiff
shark: Improve shader performance
authorJosh Simmons <josh@nega.tv>
Thu, 18 Jul 2024 21:19:39 +0000 (23:19 +0200)
committerJosh Simmons <josh@nega.tv>
Thu, 18 Jul 2024 21:19:39 +0000 (23:19 +0200)
Scalarize all the things!
Remove extra passes over the input primitives buffer.
Use AABB of entire subgroup to determine which tiles to write.

title/shark-shaders/shaders/display_transform.comp.glsl
title/shark-shaders/shaders/primitive_2d.h
title/shark-shaders/shaders/primitive_2d_bin.comp.glsl
title/shark-shaders/shaders/primitive_2d_bin_clear.comp.glsl
title/shark-shaders/shaders/primitive_2d_rasterize.comp.glsl
title/shark/src/main.rs

index be5cb3e65f0b24d98eeebf60971ed9ecee2e4a42..be48b9171404d4ee15d62a3dc050344470ed2b22 100644 (file)
@@ -37,9 +37,9 @@ void main() {
 
     TilesRead tiles_read = TilesRead(uniforms.tiles);
 
-    const uint first = tiles_read.values[tile_base + TILE_BITMAP_RANGE_OFFSET + 0];
-    const uint last = tiles_read.values[tile_base + TILE_BITMAP_RANGE_OFFSET + 1];
-    if (first <= last) {
+    const uint lo = tiles_read.values[tile_base + TILE_BITMAP_RANGE_LO_OFFSET];
+    const uint hi = tiles_read.values[tile_base + TILE_BITMAP_RANGE_HI_OFFSET];
+    if (lo <= hi) {
         const vec4 ui = imageLoad(ui_layer_read, ivec2(gl_GlobalInvocationID.xy)).rgba;
         composited = ui.rgb + (composited * (1.0 - ui.a));
     }
index 37ebb1fd2a00e21741403035c6977266473848e9..010735d934a0e9d5a8f2c0a899ca8421073acaaf 100644 (file)
@@ -4,11 +4,12 @@ const uint MAX_PRIMS = 1 << 18;
 const uint TILE_BITMAP_L1_WORDS = (MAX_PRIMS / 32 / 32);
 const uint TILE_BITMAP_L0_WORDS = (MAX_PRIMS / 32);
 const uint TILE_STRIDE = (TILE_BITMAP_L0_WORDS + TILE_BITMAP_L1_WORDS + 2);
-const uint TILE_BITMAP_RANGE_OFFSET = 0;
-const uint TILE_BITMAP_L1_OFFSET = 2;
-const uint TILE_BITMAP_L0_OFFSET = TILE_BITMAP_L1_OFFSET + TILE_BITMAP_L1_WORDS;
+const uint TILE_BITMAP_RANGE_LO_OFFSET = 0;
+const uint TILE_BITMAP_RANGE_HI_OFFSET = (TILE_BITMAP_RANGE_LO_OFFSET + 1);
+const uint TILE_BITMAP_L1_OFFSET = (TILE_BITMAP_RANGE_HI_OFFSET + 1);
+const uint TILE_BITMAP_L0_OFFSET = (TILE_BITMAP_L1_OFFSET + TILE_BITMAP_L1_WORDS);
 
-bool test_glyph(uint index, uvec2 tile_min, uvec2 tile_max) {
+bool test_glyph(uint index, vec2 tile_min, vec2 tile_max) {
     const GlyphInstance gi = uniforms.glyph_instances.values[index];
     const Glyph gl = uniforms.glyphs.values[gi.index];
     const vec2 glyph_min = gi.position + gl.offset_min;
index 298f0408182947e30b2f2b0989cc6ff63a3bfbe7..0010d718504801f9ddb97e9533d67eb861e43791 100644 (file)
@@ -7,80 +7,80 @@
 #extension GL_EXT_scalar_block_layout : require
 #extension GL_EXT_control_flow_attributes : require
 
-#extension GL_KHR_shader_subgroup_vote : require
+#extension GL_KHR_shader_subgroup_arithmetic : require
 #extension GL_KHR_shader_subgroup_ballot : require
+#extension GL_KHR_shader_subgroup_vote : require
 
 #include "compute_bindings.h"
 #include "primitive_2d.h"
 
 const uint SUBGROUP_SIZE = 64;
-const uint NUM_PRIMS_WG = (SUBGROUP_SIZE * 32);
+const uint NUM_SUBGROUPS = 16;
+const uint NUM_PRIMITIVES_WG = (SUBGROUP_SIZE * NUM_SUBGROUPS);
 
 // TODO: Spec constant support for different subgroup sizes.
 layout (local_size_x = SUBGROUP_SIZE, local_size_y = 1, local_size_z = 1) in;
 
-shared uint bitmap_0[SUBGROUP_SIZE];
-
 void main() {
-    const uvec2 bin_coord = gl_GlobalInvocationID.yz;
-    const uvec2 bin_min = bin_coord * TILE_SIZE * 8;
-    const uvec2 bin_max = min(bin_min + TILE_SIZE * 8, uniforms.screen_resolution);
-
-    for (uint i = 0; i < NUM_PRIMS_WG; i += gl_SubgroupSize.x) {
-        const uint prim_index = gl_WorkGroupID.x * NUM_PRIMS_WG + i + gl_SubgroupInvocationID;
-        bool intersects = false;
-        if (prim_index < uniforms.num_primitives) {
-            const GlyphInstance gi = uniforms.glyph_instances.values[prim_index];
+    uint word_index = 0;
+
+    for (uint i = 0; i < NUM_PRIMITIVES_WG; i += gl_SubgroupSize.x) {
+        const uint primitive_index = gl_WorkGroupID.x * NUM_PRIMITIVES_WG + i + gl_SubgroupInvocationID;
+
+        vec2 primitive_min = vec2(99999.9);
+        vec2 primitive_max = vec2(-99999.9);
+
+        if (primitive_index < uniforms.num_primitives) {
+            const GlyphInstance gi = uniforms.glyph_instances.values[primitive_index];
             const Glyph gl = uniforms.glyphs.values[gi.index];
-            const vec2 glyph_min = gi.position + gl.offset_min;
-            const vec2 glyph_max = gi.position + gl.offset_max;
-            intersects = !(any(lessThan(bin_max, glyph_min)) || any(greaterThan(bin_min, glyph_max)));
+            primitive_min = gi.position + gl.offset_min;
+            primitive_max = gi.position + gl.offset_max;
         }
-        const uvec4 ballot = subgroupBallot(intersects);
-        bitmap_0[i / 32 + 0] = ballot.x;
-        bitmap_0[i / 32 + 1] = ballot.y;
-    }
 
-    memoryBarrierShared();
+        const vec2 primitives_min = subgroupMin(primitive_min);
+        const vec2 primitives_max = subgroupMax(primitive_max);
 
-    const uint x = gl_SubgroupInvocationID.x & 7;
-    const uint y = gl_SubgroupInvocationID.x >> 3;
-    const uvec2 tile_coord = gl_GlobalInvocationID.yz * 8 + uvec2(x, y);
-    const uvec2 tile_min = tile_coord * TILE_SIZE;
-    const uvec2 tile_max = min(tile_min + TILE_SIZE, uniforms.screen_resolution);
+        if (any(greaterThan(primitives_min, uniforms.screen_resolution)) || any(lessThan(primitives_max, vec2(0.0)))) {
+            word_index += 2;
+            continue;
+        }
 
-    if (all(lessThan(tile_min, tile_max))) {
-        const uint tile_index = tile_coord.y * uniforms.tile_stride + tile_coord.x;
+        ivec2 bin_start = ivec2(floor(max(min(primitives_min, uniforms.screen_resolution), 0.0) / TILE_SIZE));
+        ivec2 bin_end = ivec2(floor((max(min(primitives_max, uniforms.screen_resolution), 0.0) + (TILE_SIZE - 1)) / TILE_SIZE));
 
-        for (uint i = 0; i < 2; i++) {
-            uint out_1 = 0;
+        for (int y = bin_start.y; y < bin_end.y; y++) {
+            for (int x = bin_start.x; x < bin_end.x; x++) {
+                const uvec2 bin_coord = uvec2(x, y);
+                const uint bin_index = bin_coord.y * uniforms.tile_stride + bin_coord.x;
+                const vec2 bin_min = bin_coord * TILE_SIZE;
+                const vec2 bin_max = min(bin_min + TILE_SIZE, uniforms.screen_resolution);
 
-            for (uint j = 0; j < 32; j++) {
-                uint out_0 = 0;
-                uint index_0 = i * 32 + j;
-                uint word_0 = bitmap_0[index_0];
-                while (word_0 != 0) {
-                    const uint bit_0 = findLSB(word_0);
-                    word_0 ^= word_0 & -word_0;
+                const bool intersects = !(any(lessThan(bin_max, primitive_min)) || any(greaterThan(bin_min, primitive_max)));
+                const uvec4 ballot = subgroupBallot(intersects);
 
-                    const uint prim_index = gl_WorkGroupID.x * NUM_PRIMS_WG + index_0 * 32 + bit_0;
-                    if (test_glyph(prim_index, tile_min, tile_max)) {
-                        out_0 |= 1 << bit_0;
-                    }
+                if (ballot.x == 0 && ballot.y == 0) {
+                    continue;
                 }
 
-                if (out_0 != 0) {
-                    out_1 |= 1 << j;
-                    uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_L0_OFFSET + gl_WorkGroupID.x * 64 + index_0] = out_0;
+                if (ballot.x != 0) {
+                    uniforms.tiles.values[bin_index * TILE_STRIDE + TILE_BITMAP_L0_OFFSET + gl_WorkGroupID.x * 32 + word_index + 0] = ballot.x;
+                }
+
+                if (ballot.y != 0) {
+                    uniforms.tiles.values[bin_index * TILE_STRIDE + TILE_BITMAP_L0_OFFSET + gl_WorkGroupID.x * 32 + word_index + 1] = ballot.y;
                 }
-            }
 
-            uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_L1_OFFSET + gl_WorkGroupID.x * 2 + i] = out_1;
+                if (subgroupElect()) {
+                    uniforms.tiles.values[bin_index * TILE_STRIDE + TILE_BITMAP_L1_OFFSET + gl_WorkGroupID.x] |=
+                        (uint(ballot.x != 0) << (word_index + 0)) |
+                        (uint(ballot.y != 0) << (word_index + 1));
 
-            if (out_1 != 0) {
-                atomicMin(uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_RANGE_OFFSET + 0], gl_WorkGroupID.x * 2 + i);
-                atomicMax(uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_RANGE_OFFSET + 1], gl_WorkGroupID.x * 2 + i);
+                    atomicMin(uniforms.tiles.values[bin_index * TILE_STRIDE + TILE_BITMAP_RANGE_LO_OFFSET], gl_WorkGroupID.x);
+                    atomicMax(uniforms.tiles.values[bin_index * TILE_STRIDE + TILE_BITMAP_RANGE_HI_OFFSET], gl_WorkGroupID.x);
+                }
             }
         }
+
+        word_index += 2;
     }
 }
index 5707be1ef694d5469b991bb1fcd7d308398eaf77..8cfa4fba499e8b9a0ca792b1c0aa615d1d56e28c 100644 (file)
 layout (local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
 
 void main() {
-    if (gl_GlobalInvocationID.x >= uniforms.tile_resolution.x * uniforms.tile_resolution.y) {
-        return;
-    }
+    const uint tile_index = gl_GlobalInvocationID.z * uniforms.tile_stride + gl_GlobalInvocationID.y;
 
-    const uint index = gl_GlobalInvocationID.x * TILE_STRIDE + TILE_BITMAP_RANGE_OFFSET;
+    uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_RANGE_LO_OFFSET] = 0xffffffff;
+    uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_RANGE_HI_OFFSET] = 0;
 
-    uniforms.tiles.values[index + 0] = 0xffffffff;
-    uniforms.tiles.values[index + 1] = 0;
+    if (gl_GlobalInvocationID.x < TILE_BITMAP_L1_WORDS) {
+        uniforms.tiles.values[tile_index * TILE_STRIDE + TILE_BITMAP_L1_OFFSET + gl_GlobalInvocationID.x] = 0;
+    }
 }
index 7043983243ae11fd608c69a9ed192088fef4e9fb..67f878c30fba06604251f95db80b8dc34d36182a 100644 (file)
@@ -34,25 +34,24 @@ vec3 plasma_quintic(float x)
 #endif
 
 void main() {
-    const uvec2 tile_coord = gl_WorkGroupID.xy / 4;
+    const uvec2 tile_coord = gl_WorkGroupID.xy / (TILE_SIZE / gl_WorkGroupSize.xy);
     const uint tile_index = tile_coord.y * uniforms.tile_stride + tile_coord.x;
     const uint tile_base = tile_index * TILE_STRIDE;
 
     TilesRead tiles_read = TilesRead(uniforms.tiles);
 
-    const uint first = tiles_read.values[tile_base + TILE_BITMAP_RANGE_OFFSET + 0];
-    const uint last = tiles_read.values[tile_base + TILE_BITMAP_RANGE_OFFSET + 1];
+    const uint lo = tiles_read.values[tile_base + TILE_BITMAP_RANGE_LO_OFFSET];
+    const uint hi = tiles_read.values[tile_base + TILE_BITMAP_RANGE_HI_OFFSET];
 
-    [[branch]]
-    if (last < first) {
+    if (hi < lo) {
         return;
     }
 
 #if DEBUG_SHOW_TILES == 1
 
-    int count = 0;
+    uint count = 0;
     // For each tile, iterate over all words in the L1 bitmap.
-    for (uint index_l1 = first; index_l1 <= last; index_l1++) {
+    for (uint index_l1 = lo; index_l1 <= hi; index_l1++) {
         // For each word, iterate all set bits.
         uint bitmap_l1 = tiles_read.values[tile_base + TILE_BITMAP_L1_OFFSET + index_l1];
 
@@ -72,12 +71,20 @@ void main() {
     const vec3 color = plasma_quintic(float(count) / 100.0);
     imageStore(ui_layer_write, ivec2(gl_GlobalInvocationID.xy), vec4(color, 1.0));
 
+#elif DEBUG_SHOW_TILES == 2
+
+    uint count = hi - lo;
+    const vec3 color = plasma_quintic(float(count) / 100.0);
+    imageStore(ui_layer_write, ivec2(gl_GlobalInvocationID.xy), vec4(color, 1.0));
+
 #else
 
+    const vec2 sample_center = gl_GlobalInvocationID.xy + vec2(0.5);
+
     vec4 accum = vec4(0.0);
 
-    // For each tile, iterate over all words in the L1 bitmap. 
-    for (uint index_l1 = first; index_l1 <= last; index_l1++) {
+    // For each tile, iterate over all words in the L1 bitmap.
+    for (uint index_l1 = lo; index_l1 <= hi; index_l1++) {
         // For each word, iterate all set bits.
         uint bitmap_l1 = tiles_read.values[tile_base + TILE_BITMAP_L1_OFFSET + index_l1];
 
@@ -100,12 +107,12 @@ void main() {
                 const Glyph gl = uniforms.glyphs.values[gi.index];
                 const vec2 glyph_min = gi.position + gl.offset_min;
                 const vec2 glyph_max = gi.position + gl.offset_max;
-                const vec2 sample_center = gl_GlobalInvocationID.xy + vec2(0.5);
+
                 [[branch]]
                 if (all(greaterThanEqual(sample_center, glyph_min)) && all(lessThanEqual(sample_center, glyph_max))) {
                     const vec2 glyph_size = gl.offset_max - gl.offset_min;
-                    const vec4 color = unpackUnorm4x8(gi.color).bgra;
                     const vec2 uv = mix(gl.atlas_min, gl.atlas_max, (sample_center - glyph_min) / glyph_size) / uniforms.atlas_resolution;
+                    const vec4 color = unpackUnorm4x8(gi.color).bgra;
                     const float coverage = textureLod(sampler2D(glyph_atlas, bilinear_sampler), uv, 0.0).r * color.a;
                     accum.rgb = (coverage * color.rgb) + accum.rgb * (1.0 - coverage);
                     accum.a = coverage + accum.a * (1.0 - coverage);
index 3f83795d8b7c86681ffa8bcb6c9d70e72fe44f65..1671a9e29800017e5ffc7cb1f985af2ed3177ff0 100644 (file)
@@ -1483,9 +1483,9 @@ impl<'gpu> DrawState<'gpu> {
 
                 gpu.cmd_dispatch(
                     cmd_encoder,
-                    (self.tile_resolution_y * self.tile_resolution_x + 63) / 64,
-                    1,
-                    1,
+                    (num_primitives_1024 + 63) / 64,
+                    self.tile_resolution_x,
+                    self.tile_resolution_y,
                 );
 
                 gpu.cmd_barrier(
@@ -1499,12 +1499,7 @@ impl<'gpu> DrawState<'gpu> {
 
                 gpu.cmd_set_pipeline(cmd_encoder, self.bin_pipeline);
 
-                gpu.cmd_dispatch(
-                    cmd_encoder,
-                    (num_primitives + 2047) / 2048,
-                    (self.tile_resolution_x + 3) / 4,
-                    (self.tile_resolution_y + 3) / 4,
-                );
+                gpu.cmd_dispatch(cmd_encoder, (num_primitives + 1023) / 1024, 1, 1);
 
                 gpu.cmd_barrier(
                     cmd_encoder,
@@ -1738,8 +1733,8 @@ pub fn main() {
             for i in 0..80 {
                 let i = i as f32;
                 ui_state.text_fmt(
-                    base_x * 100.0 * scale + 5.0,
-                    base_y * 100.0 * scale + i * 15.0 * scale,
+                    base_x * 100.0 * scale - 5.0,
+                    base_y * 150.0 * scale + i * 15.0 * scale,
                     FontFamily::RobotoRegular,
                     20.0,
                     format_args!("tick: {:?}", tick_duration),