diff --git a/index copy.html b/index copy.html new file mode 100644 index 0000000..96d3047 --- /dev/null +++ b/index copy.html @@ -0,0 +1,76 @@ + + + + + + Bevy Compute Shader + + + + + +
+
+ + alexharding.ooo +     + bluesky + +

Experiment with compute shader

+ +
+
+

Github

+
+ + +
+
+ + + + \ No newline at end of file diff --git a/index.html b/index.html index 1f2188c..493b81d 100644 --- a/index.html +++ b/index.html @@ -3,7 +3,7 @@ - Bevy Compute Shader + Compute Shader @@ -11,15 +11,26 @@
-

Experiment with compute shader

- + + alexharding.ooo +     + bluesky + +

Bevy Compute Shader Thing


Yeah wasm bindgen always makes it scroll to the bottom. I'm not ready to look into this. Scroll up ↑ 𓆩♡𓆪

Github


- - +

Side project to learn some compute shader stuff.

+

Wasm bindgen always makes it scroll to the bottom. I'm not ready to look into this. Scroll up ↑ 𓆩♡𓆪

+ + alexharding.ooo +     + bluesky + +
+
diff --git a/src/gui.rs b/src/gui.rs index e83cac6..d5a6790 100644 --- a/src/gui.rs +++ b/src/gui.rs @@ -1,7 +1,7 @@ use bevy::prelude::*; use bevy_egui::{egui, EguiContexts}; -use crate::ParamsUniform; +use crate::{ParamsUniform, ShaderConfigurator}; #[derive(Event)] pub struct ParamsChanged { @@ -22,6 +22,7 @@ fn ui_system( mut contexts: EguiContexts, // mut param_events: EventWriter, mut params: ResMut, + mut shader_configurator: ResMut, ) { // let mut radius = params.radius; @@ -37,7 +38,14 @@ fn ui_system( ui.add(egui::Slider::new(&mut params.noise_scale, 0.0..=2.).text("scale")); ui.add(egui::Slider::new(&mut params.noise_offset, 0.0..=20.).text("offset")); ui.add(egui::Slider::new(&mut params.warp_amount, 0.0..=0.2).text("warp amount")); - ui.add(egui::Slider::new(&mut params.warp_scale, 1.0..=20.).text("warp scale")); + ui.add(egui::Slider::new(&mut params.warp_scale, 1.0..=20.).text("warp scale")); + ui.horizontal(|ui| { + ui.label("war iterations"); + ui.add( + egui::DragValue::new(&mut shader_configurator.shader_configs[1].iterations) + .range(0..=50), + ); + }); }); }); } diff --git a/src/main.rs b/src/main.rs index 40cd21b..66d884c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,37 +1,30 @@ -//! Simple example demonstrating the use of the [`Readback`] component to read back data from the GPU -//! using both a storage buffer and texture. - use bevy::{ asset::load_internal_asset, prelude::*, render::{ - extract_resource::{ExtractResource, ExtractResourcePlugin}, - gpu_readback::{Readback, ReadbackComplete}, + extract_resource::{self, ExtractResource, ExtractResourcePlugin}, render_asset::{RenderAssetUsages, RenderAssets}, render_graph::{self, RenderGraph, RenderLabel}, render_resource::{binding_types::texture_storage_2d, *}, renderer::{RenderContext, RenderDevice, RenderQueue}, - storage::{GpuShaderStorageBuffer, ShaderStorageBuffer}, texture::GpuImage, Render, RenderApp, RenderSet, }, + utils::HashMap, }; - -mod gui; - use binding_types::uniform_buffer; use bytemuck::bytes_of; -use gui::ParamsChanged; +mod gui; -const SHADER1_HANDLE: Handle = Handle::weak_from_u128(13378847158248049035); -const SHADER2_HANDLE: Handle = Handle::weak_from_u128(23378847158248049035); -const SHADER3_HANDLE: Handle = Handle::weak_from_u128(33378847158248049035); -const NOISE_SHADER_HANDLE: Handle = Handle::weak_from_u128(14378847158248049035); -const VECTOR_SHADER_HANDLE: Handle = Handle::weak_from_u128(25378847158248049035); +const GENERATE_CIRCLE_HANDLE: Handle = Handle::weak_from_u128(13378847158248049035); +const DOMAIN_WARP_HANDLE: Handle = Handle::weak_from_u128(23378847158248049035); +const EXTRACT_HANDLE: Handle = Handle::weak_from_u128(33378847158248049035); +const UTIL_NOISE_SHADER_HANDLE: Handle = Handle::weak_from_u128(14378847158248049035); +const UTIL_VECTOR_SHADER_HANDLE: Handle = Handle::weak_from_u128(25378847158248049035); // The length of the buffer sent to the gpu -const BUFFER_LEN: usize = 1000; +const BUFFER_LEN: usize = 1024; #[derive(Resource, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable, ExtractResource, ShaderType)] #[repr(C)] @@ -67,8 +60,9 @@ fn main() { .add_plugins(( DefaultPlugins, GpuReadbackPlugin, - ExtractResourcePlugin::::default(), + ExtractResourcePlugin::::default(), ExtractResourcePlugin::::default(), + // ExtractResourcePlugin::::default(), gui::GuiPlugin, )) .insert_resource(ClearColor(Color::BLACK)) @@ -86,114 +80,158 @@ fn update_uniform_buffer( } } -// We need a plugin to organize all the systems and render node required for this example +#[derive(Debug, Clone)] +struct ShaderConfig { + shader_handle: Handle, + iterations: u32, +} + +#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)] +enum ComputeNodeLabel { + Compute1, + Compute2, + Compute3, + Final, +} + struct GpuReadbackPlugin; impl Plugin for GpuReadbackPlugin { fn build(&self, app: &mut App) { - - // let asset_server = app.world().resource::(); - // let asset_path = asset_server. - - // Load the noise shader first as an internal asset + let shader_configs = vec![ + ShaderConfig { + shader_handle: GENERATE_CIRCLE_HANDLE, + iterations: 1, + }, + ShaderConfig { + shader_handle: DOMAIN_WARP_HANDLE, + iterations: 5, + }, + ShaderConfig { + shader_handle: DOMAIN_WARP_HANDLE, + iterations: 1, + }, + ]; + + app.insert_resource(ShaderConfigurator { shader_configs }); + app.add_plugins(ExtractResourcePlugin::::default()); + load_internal_asset!( app, - NOISE_SHADER_HANDLE, + UTIL_NOISE_SHADER_HANDLE, "shaders/utils/noise.wgsl", Shader::from_wgsl ); load_internal_asset!( app, - VECTOR_SHADER_HANDLE, + UTIL_VECTOR_SHADER_HANDLE, "shaders/utils/utils.wgsl", Shader::from_wgsl ); load_internal_asset!( app, - SHADER1_HANDLE, + GENERATE_CIRCLE_HANDLE, "shaders/generate_circle.wgsl", Shader::from_wgsl ); load_internal_asset!( app, - SHADER2_HANDLE, + DOMAIN_WARP_HANDLE, "shaders/domain_warp.wgsl", Shader::from_wgsl ); load_internal_asset!( app, - SHADER3_HANDLE, - "shaders/3rd_pass.wgsl", + EXTRACT_HANDLE, + "shaders/extract.wgsl", Shader::from_wgsl ); - } fn finish(&self, app: &mut App) { + + let shader_configs = app.world().resource::().clone(); + let render_app = app.sub_app_mut(RenderApp); - render_app.init_resource::() - .add_systems( + + + render_app.insert_resource(shader_configs); + + render_app.init_resource::().add_systems( Render, ( update_uniform_buffer, prepare_bind_groups .in_set(RenderSet::PrepareBindGroups) - // We don't need to recreate the bind group every frame .run_if(not(resource_exists::)), + prepare_bind_group_selection + .in_set(RenderSet::PrepareBindGroups) + .after(prepare_bind_groups), ), ); let mut render_graph = render_app.world_mut().resource_mut::(); - - render_graph.add_node(ComputeNodeLabel1, ComputeNode{ pass_index: 0}); - render_graph.add_node(ComputeNodeLabel2, ComputeNode{ pass_index: 1}); - render_graph.add_node(ComputeNodeLabel3, ComputeNode{ pass_index: 2}); - render_graph.add_node_edge(ComputeNodeLabel1, ComputeNodeLabel2); - render_graph.add_node_edge(ComputeNodeLabel2, ComputeNodeLabel3); - + // Add compute nodes + render_graph.add_node( + ComputeNodeLabel::Compute1, + ComputeNode { + pipeline_index: 0, + is_final: false, + }, + ); + render_graph.add_node( + ComputeNodeLabel::Compute2, + ComputeNode { + pipeline_index: 1, + is_final: false, + }, + ); + render_graph.add_node( + ComputeNodeLabel::Compute3, + ComputeNode { + pipeline_index: 2, + is_final: false, + }, + ); + + // Add final pass + render_graph.add_node( + ComputeNodeLabel::Final, + ComputeNode { + pipeline_index: 0, + is_final: true, + }, + ); + + // Add edges between nodes + render_graph.add_node_edge(ComputeNodeLabel::Compute1, ComputeNodeLabel::Compute2); + render_graph.add_node_edge(ComputeNodeLabel::Compute2, ComputeNodeLabel::Compute3); + render_graph.add_node_edge(ComputeNodeLabel::Compute3, ComputeNodeLabel::Final); } } #[derive(Resource, ExtractResource, Clone)] -struct ReadbackImage { - ping: Handle, - pong: Handle, +struct ImageBufferContainer { + buffer_a: Handle, + buffer_b: Handle, result: Handle, } fn setup(mut commands: Commands, mut images: ResMut>) { commands.spawn((Camera2d::default(),)); - // Create a storage texture with some data let size = Extent3d { width: BUFFER_LEN as u32, height: BUFFER_LEN as u32, ..default() }; - let mut image = Image::new_fill( - size, - TextureDimension::D2, - &[0, 0, 0, 0], - TextureFormat::Rgba32Float, - // TextureFormat::R32Uint, - RenderAssetUsages::RENDER_WORLD, - ); - - // We also need to enable the COPY_SRC, as well as STORAGE_BINDING so we can use it in the - // compute shader - image.texture_descriptor.usage |= - TextureUsages::COPY_SRC | - TextureUsages::COPY_DST | - TextureUsages::STORAGE_BINDING | - TextureUsages::TEXTURE_BINDING; - let mut create_image = || { let mut image = Image::new_fill( size, TextureDimension::D2, - &[0, 0, 0, 0], + &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], TextureFormat::Rgba32Float, RenderAssetUsages::RENDER_WORLD, ); @@ -201,18 +239,9 @@ fn setup(mut commands: Commands, mut images: ResMut>) { images.add(image) }; - let ping = create_image(); - let pong = create_image(); + let buffer_a = create_image(); + let buffer_b = create_image(); let result = create_image(); - // commands.spawn(Readback::texture(pong.clone())); - - // Spawn the readback components. For each frame, the data will be read back from the GPU - // asynchronously and trigger the `ReadbackComplete` event on this entity. Despawn the entity - // to stop reading back the data. - - // Textures can also be read back from the GPU. Pay careful attention to the format of the - // texture, as it will affect how the data is interpreted. - // commands.spawn(Readback::texture(pong.clone())); commands.spawn(( Sprite { @@ -223,31 +252,35 @@ fn setup(mut commands: Commands, mut images: ResMut>) { Transform::from_xyz(0.0, 0.5, 0.0).with_scale(Vec3::splat(1.0)), )); - // This is just a simple way to pass the image handle to the render app for our compute node - // commands.insert_resource(ReadbackImage(image)); - commands.insert_resource(ReadbackImage { - ping: ping, - pong: pong, - result: result, + commands.insert_resource(ImageBufferContainer { + buffer_a, + buffer_b, + result, }); } #[derive(Resource)] struct GpuBufferBindGroups { - first_pass: BindGroup, - second_pass: BindGroup, - third_pass: BindGroup, + bind_groups: Vec, + final_pass_a: BindGroup, + final_pass_b: BindGroup, uniform_buffer: Buffer, } +#[derive(Resource)] +struct BindGroupSelection { + // node_bind_groups: Vec, // Index of bind group to use for each node + selectors: HashMap>, + final_pass: u32, +} fn prepare_bind_groups( mut commands: Commands, pipeline: Res, render_device: Res, - image: Res, + buffer_container: Res, images: Res>, - params: Res, + params_res: Res, render_queue: Res, ) { let uniform_buffer = render_device.create_buffer(&BufferDescriptor { @@ -257,59 +290,78 @@ fn prepare_bind_groups( mapped_at_creation: false, }); - render_queue.write_buffer(&uniform_buffer, 0, bytes_of(&*params)); + render_queue.write_buffer(&uniform_buffer, 0, bytes_of(&*params_res)); - let ping_image = images.get(&image.ping).unwrap(); - let pong_image = images.get(&image.pong).unwrap(); - let result_image = images.get(&image.result).unwrap(); + let image_a = images.get(&buffer_container.buffer_a).unwrap(); + let image_b = images.get(&buffer_container.buffer_b).unwrap(); + let result_image = images.get(&buffer_container.result).unwrap(); - let first_pass = render_device.create_bind_group( - None, - &pipeline.layout, - &BindGroupEntries::sequential(( - uniform_buffer.as_entire_buffer_binding(), - ping_image.texture_view.into_binding(), - pong_image.texture_view.into_binding(), - )), - ); - let second_pass = render_device.create_bind_group( + let bind_groups = vec![ + // A -> B + render_device.create_bind_group( + None, + &pipeline.layout, + &BindGroupEntries::sequential(( + uniform_buffer.as_entire_buffer_binding(), + image_a.texture_view.into_binding(), + image_b.texture_view.into_binding(), + )), + ), + // B -> A + render_device.create_bind_group( + None, + &pipeline.layout, + &BindGroupEntries::sequential(( + uniform_buffer.as_entire_buffer_binding(), + image_b.texture_view.into_binding(), + image_a.texture_view.into_binding(), + )), + ), + ]; + + let extract_a = render_device.create_bind_group( None, &pipeline.layout, &BindGroupEntries::sequential(( uniform_buffer.as_entire_buffer_binding(), - pong_image.texture_view.into_binding(), - ping_image.texture_view.into_binding(), + image_a.texture_view.into_binding(), + result_image.texture_view.into_binding(), )), ); - let third_pass = render_device.create_bind_group( + let extract_b = render_device.create_bind_group( None, &pipeline.layout, &BindGroupEntries::sequential(( uniform_buffer.as_entire_buffer_binding(), - ping_image.texture_view.into_binding(), + image_b.texture_view.into_binding(), result_image.texture_view.into_binding(), )), ); - commands.insert_resource(GpuBufferBindGroups { - first_pass, - second_pass, - third_pass, + bind_groups, + final_pass_a: extract_a, + final_pass_b: extract_b, uniform_buffer, + // iteration: 0, }); } +#[derive(Resource, Clone, ExtractResource)] +struct ShaderConfigurator { + shader_configs: Vec, +} + #[derive(Resource)] struct ComputePipelines { layout: BindGroupLayout, - first_pass: CachedComputePipelineId, - second_pass: CachedComputePipelineId, - third_pass: CachedComputePipelineId, + pipeline_configs: Vec, + final_pass: CachedComputePipelineId, } impl FromWorld for ComputePipelines { fn from_world(world: &mut World) -> Self { + let shader_configurator = world.resource::(); let render_device = world.resource::(); let layout = render_device.create_bind_group_layout( None, @@ -322,36 +374,32 @@ impl FromWorld for ComputePipelines { ), ), ); - // let shader1 = world.load_asset(SHADER1_ASSET_PATH); - // let shader1 = let pipeline_cache = world.resource::(); - let first_pass = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { - label: Some("First pass".into()), - layout: vec![layout.clone()], - push_constant_ranges: Vec::new(), - // shader: shader1.clone(), - shader: SHADER1_HANDLE, - shader_defs: Vec::new(), - entry_point: "main".into(), - zero_initialize_workgroup_memory: false, - }); - // let shader2 = world.load_asset(SHADER2_ASSET_PATH); - let second_pass = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { - label: Some("Second pass".into()), - layout: vec![layout.clone()], - push_constant_ranges: Vec::new(), - shader: SHADER2_HANDLE, - shader_defs: Vec::new(), - entry_point: "main".into(), - zero_initialize_workgroup_memory: false, - }); - let third_pass = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { - label: Some("Third pass".into()), + let shader_configs = shader_configurator.shader_configs.clone(); + + // Create pipeline for each shader with its iteration count + let mut pipeline_configs = Vec::new(); + for config in shader_configs { + let pipeline_id = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("compute".into()), + layout: vec![layout.clone()], + push_constant_ranges: Vec::new(), + shader: config.shader_handle, + shader_defs: Vec::new(), + entry_point: "main".into(), + zero_initialize_workgroup_memory: false, + }); + + pipeline_configs.push(pipeline_id); + } + + let final_pass = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("Final pass".into()), layout: vec![layout.clone()], push_constant_ranges: Vec::new(), - shader: SHADER3_HANDLE, + shader: EXTRACT_HANDLE, shader_defs: Vec::new(), entry_point: "main".into(), zero_initialize_workgroup_memory: false, @@ -359,27 +407,45 @@ impl FromWorld for ComputePipelines { ComputePipelines { layout, - first_pass, - second_pass, - third_pass, + pipeline_configs, + final_pass, } } } +fn prepare_bind_group_selection(mut commands: Commands, pipelines: Res, shader_configurator: Res) { + let mut selectors = HashMap::new(); + let mut total_iterations = 0; + let mut node: u32 = 0; -/// Label to identify the node in the render graph -#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)] -struct ComputeNodeLabel1; -#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)] -struct ComputeNodeLabel2; -#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)] -struct ComputeNodeLabel3; -/// The node that will execute the compute shader + + for _ in &pipelines.pipeline_configs { + let mut node_selections = Vec::new(); + + let i = shader_configurator.shader_configs[node as usize].iterations; + for _ in 0..i { + node_selections.push(total_iterations % 2); + total_iterations += 1; + } + selectors.insert(node, node_selections); + node += 1; + } + + let final_pass = total_iterations % 2; + + commands.insert_resource(BindGroupSelection { + selectors, + final_pass, + }); +} + #[derive(Default)] struct ComputeNode { - pass_index: u32, + pipeline_index: usize, + is_final: bool, } + impl render_graph::Node for ComputeNode { fn run( &self, @@ -390,33 +456,64 @@ impl render_graph::Node for ComputeNode { let pipeline_cache = world.resource::(); let pipelines = world.resource::(); let bind_groups = world.resource::(); + let encoder = render_context.command_encoder(); + let selectors = world.resource::(); + let shader_configurator = world.resource::(); + - let (pipeline_id, bind_group) = match self.pass_index { - 0 => (pipelines.first_pass, &bind_groups.first_pass), - 1 => (pipelines.second_pass, &bind_groups.second_pass), - 2 => (pipelines.third_pass, &bind_groups.third_pass), - _ => return Ok(()), - }; - println!("Running pass {}", self.pass_index); - if let Some(pipeline) = pipeline_cache.get_compute_pipeline(pipeline_id) { - println!("Pipeline ready for pass {}", self.pass_index); - let mut pass = render_context - .command_encoder() - .begin_compute_pass(&ComputePassDescriptor::default()); - - pass.set_bind_group(0, bind_group, &[]); - pass.set_pipeline(pipeline); - pass.dispatch_workgroups(BUFFER_LEN as u32, BUFFER_LEN as u32, 1); - - - - }else { - println!("Pipeline not ready for pass {}", self.pass_index); + if self.is_final { + if let Some(pipeline) = pipeline_cache.get_compute_pipeline(pipelines.final_pass) { + encoder.push_debug_group("Final pass"); + + { + let group = if selectors.final_pass == 0 { + &bind_groups.final_pass_a + } else { + &bind_groups.final_pass_b + }; + + let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default()); + pass.set_bind_group(0, group, &[]); + pass.set_pipeline(pipeline); + pass.dispatch_workgroups( + ((BUFFER_LEN + 15) / 16) as u32, + ((BUFFER_LEN + 15) / 16) as u32, + 1, + ); + } + encoder.pop_debug_group(); + } + } else { + let pipeline_id = pipelines.pipeline_configs[self.pipeline_index]; + + if let Some(pipeline) = pipeline_cache.get_compute_pipeline(pipeline_id) { + let iters = shader_configurator.shader_configs[self.pipeline_index].iterations; + println!("new node"); + for iteration in 0..iters { + println!("iters: {}", iters); + encoder.push_debug_group(&format!( + "Compute pass {} iteration {}", + self.pipeline_index, iteration + )); + + { + let node = self.pipeline_index as u32; + let selection = selectors.selectors[&node][iteration as usize]; + let mut pass = + encoder.begin_compute_pass(&ComputePassDescriptor::default()); + pass.set_bind_group(0, &bind_groups.bind_groups[selection as usize], &[]); + pass.set_pipeline(pipeline); + pass.dispatch_workgroups( + ((BUFFER_LEN + 15) / 16) as u32, + ((BUFFER_LEN + 15) / 16) as u32, + 1, + ); + } + encoder.pop_debug_group(); + } + } } - - - Ok(()) } } diff --git a/src/shaders/domain_warp.wgsl b/src/shaders/domain_warp.wgsl index d9a9742..d5ac130 100644 --- a/src/shaders/domain_warp.wgsl +++ b/src/shaders/domain_warp.wgsl @@ -24,8 +24,7 @@ fn sample_with_offset(pos: vec2, offset: vec2) -> vec4 { ); return textureLoad(input_texture, new_pos); } - -@compute @workgroup_size(8, 8) +@compute @workgroup_size(16, 16) fn main(@builtin(global_invocation_id) global_id: vec3) { let x = global_id.x; let y = global_id.y; @@ -36,13 +35,12 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { } -// let upos = vec2(i32(x), i32(y)); -// textureStore(output_texture, upos, vec4f(0.0, 0.0, 1.0, 1.0)); // Solid blue - -// } - - let upos = vec2(i32(x), i32(y)); + + // // Just output solid red to verify shader is running + // textureStore(output_texture, upos, vec4(1.0, 0.0, 0.0, 1.0)); + + let dim = f32(params.dimensions); // Convert position to 0-1 range for noise generation diff --git a/src/shaders/3rd_pass.wgsl b/src/shaders/extract.wgsl similarity index 80% rename from src/shaders/3rd_pass.wgsl rename to src/shaders/extract.wgsl index a7f5b5d..8504784 100644 --- a/src/shaders/3rd_pass.wgsl +++ b/src/shaders/extract.wgsl @@ -13,7 +13,7 @@ struct Params { @group(0) @binding(1) var input_texture: texture_storage_2d; @group(0) @binding(2) var output_texture: texture_storage_2d; -@compute @workgroup_size(8, 8) +@compute @workgroup_size(16, 16) fn main(@builtin(global_invocation_id) global_id: vec3) { let x = global_id.x; let y = global_id.y; @@ -26,6 +26,5 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { let current = textureLoad(input_texture, upos); - // Write solid red to verify shader is writing - textureStore(output_texture, upos, vec4f(1-current.r, current.g, current.b, 1.0)); + textureStore(output_texture, upos, vec4f(1.- current.x,1.,1.,1.)); } diff --git a/src/shaders/generate_circle.wgsl b/src/shaders/generate_circle.wgsl index b871695..9259c28 100644 --- a/src/shaders/generate_circle.wgsl +++ b/src/shaders/generate_circle.wgsl @@ -14,11 +14,12 @@ struct Params { // @group(0) @binding(0) var params: Params; // @group(0) @binding(1) var texture: texture_storage_2d; @group(0) @binding(0) var params: Params; +@group(0) @binding(1) var input_texture: texture_storage_2d; @group(0) @binding(2) var output_texture: texture_storage_2d; // Changed to 8x8 workgroup size - better for most GPUs -@compute @workgroup_size(8, 8) +@compute @workgroup_size(16, 16) fn main(@builtin(global_invocation_id) global_id: vec3) { let x = global_id.x; let y = global_id.y;