From d0dedab9124db77595be3bceec2f337db25a0a7c Mon Sep 17 00:00:00 2001 From: arcadeperfect Date: Wed, 4 Dec 2024 12:53:08 -0500 Subject: [PATCH] ping pong --- assets/shaders/generate_circle.wgsl | 7 +- assets/shaders/second_pass.wgsl | 19 +++ src/main.rs | 245 ++++++++++++++++++++-------- 3 files changed, 200 insertions(+), 71 deletions(-) create mode 100644 assets/shaders/second_pass.wgsl diff --git a/assets/shaders/generate_circle.wgsl b/assets/shaders/generate_circle.wgsl index 03d4d28..23446f2 100644 --- a/assets/shaders/generate_circle.wgsl +++ b/assets/shaders/generate_circle.wgsl @@ -57,8 +57,11 @@ struct Params { noise_scale: f32, noise_amplitude: f32, } +// @group(0) @binding(0) var params: Params; +// @group(0) @binding(1) var texture: texture_storage_2d; @group(0) @binding(0) var params: Params; -@group(0) @binding(1) var texture: texture_storage_2d; +@group(0) @binding(2) var output_texture: texture_storage_2d; + // Changed to 8x8 workgroup size - better for most GPUs @compute @workgroup_size(8, 8) @@ -100,7 +103,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { pos = vec2f(mag, 0.0); let edge = vec2f(r, 0.0); - textureStore(texture, upos, vec4(v , dist, distance(pos, edge), 1.)); + textureStore(output_texture, upos, vec4(v , dist, distance(pos, edge), 1.)); } diff --git a/assets/shaders/second_pass.wgsl b/assets/shaders/second_pass.wgsl new file mode 100644 index 0000000..8f69354 --- /dev/null +++ b/assets/shaders/second_pass.wgsl @@ -0,0 +1,19 @@ +struct Params { + dimensions: u32, + radius: f32, + noise_seed: u32, + noise_scale: f32, + noise_amplitude: f32, +} + +@group(0) @binding(0) var params: Params; +@group(0) @binding(1) var input_texture: texture_storage_2d; +@group(0) @binding(2) var output_texture: texture_storage_2d; + +@compute @workgroup_size(8, 8) +fn main(@builtin(global_invocation_id) global_id: vec3){ + let upos = vec2(i32(global_id.x), i32(global_id.y)); + let value = textureLoad(input_texture, upos); + 
textureStore(output_texture, upos, vec4(1.0 - value.x, 1.0 - value.y, 1.0 - value.z, value.w)); +} + diff --git a/src/main.rs b/src/main.rs index 57badcf..6b09c54 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,7 +24,8 @@ use bytemuck::bytes_of; use gui::ParamsChanged; /// This example uses a shader source file from the assets subdirectory -const SHADER_ASSET_PATH: &str = "shaders/generate_circle.wgsl"; +const SHADER1_ASSET_PATH: &str = "shaders/generate_circle.wgsl"; +const SHADER2_ASSET_PATH: &str = "shaders/second_pass.wgsl"; const NOISE_SHADER_HANDLE: Handle = Handle::weak_from_u128(13378847158248049035); const VECTOR_SHADER_HANDLE: Handle = Handle::weak_from_u128(23378847158248049035); @@ -69,11 +70,11 @@ fn main() { } fn update_uniform_buffer( - gpu_buffer_bind_group: Option>, + bind_groups: Option>, render_queue: Res, params: Res, ) { - if let Some(bind_group) = gpu_buffer_bind_group { + if let Some(bind_group) = bind_groups { render_queue.write_buffer(&bind_group.uniform_buffer, 0, bytemuck::bytes_of(&*params)); } } @@ -82,9 +83,7 @@ fn update_uniform_buffer( struct GpuReadbackPlugin; impl Plugin for GpuReadbackPlugin { fn build(&self, app: &mut App) { - // let asset_server = app.world().resource::(); - // let _noise_shader: Handle = asset_server.load("shaders/noise.wgsl"); - // let _main_shader: Handle = asset_server.load("shaders/generate_circle.wgsl"); + // Load the noise shader first as an internal asset load_internal_asset!( app, @@ -98,38 +97,41 @@ impl Plugin for GpuReadbackPlugin { "../assets/shaders/utils/utils.wgsl", Shader::from_wgsl ); - - // Load the main shader that imports the noise shader - // load_internal_asset!( - // app, - // MAIN_SHADER_HANDLE, - // "../assets/shaders/generate_circle.wgsl", - // Shader::from_wgsl - // ); } fn finish(&self, app: &mut App) { let render_app = app.sub_app_mut(RenderApp); - render_app.init_resource::().add_systems( + render_app.init_resource::() + .add_systems( Render, ( update_uniform_buffer, - 
prepare_bind_group + prepare_bind_groups .in_set(RenderSet::PrepareBindGroups) // We don't need to recreate the bind group every frame - .run_if(not(resource_exists::)), + .run_if(not(resource_exists::)), ), ); - render_app - .world_mut() - .resource_mut::() - .add_node(ComputeNodeLabel, ComputeNode::default()); + + let mut render_graph = render_app.world_mut().resource_mut::(); + + render_graph.add_node(ComputeNodeLabel1, ComputeNode{ pass_index: 0}); + render_graph.add_node(ComputeNodeLabel2, ComputeNode{ pass_index: 1}); + render_graph.add_node_edge(ComputeNodeLabel1, ComputeNodeLabel2); + + // render_app + // .world_mut() + // .resource_mut::() + // .add_node(ComputeNodeLabel, ComputeNode::default()); render_app.add_event::(); } } #[derive(Resource, ExtractResource, Clone)] -struct ReadbackImage(Handle); +struct ReadbackImage { + ping: Handle, + pong: Handle, +} fn setup(mut commands: Commands, mut images: ResMut>) { commands.spawn((Camera2d::default(),)); @@ -148,31 +150,41 @@ fn setup(mut commands: Commands, mut images: ResMut>) { // TextureFormat::R32Uint, RenderAssetUsages::RENDER_WORLD, ); + // We also need to enable the COPY_SRC, as well as STORAGE_BINDING so we can use it in the // compute shader image.texture_descriptor.usage |= TextureUsages::COPY_SRC | TextureUsages::STORAGE_BINDING; let image = images.add(image); + let mut create_image = || { + let mut image = Image::new_fill( + size, + TextureDimension::D2, + &[0, 0, 0, 0], + TextureFormat::Rgba32Float, + // TextureFormat::R32Uint, + RenderAssetUsages::RENDER_WORLD, + ); + image.texture_descriptor.usage |= TextureUsages::COPY_SRC | TextureUsages::STORAGE_BINDING; + images.add(image) + }; + + let ping = create_image(); + let pong = create_image(); + commands.spawn(Readback::texture(ping.clone())); + // Spawn the readback components. For each frame, the data will be read back from the GPU // asynchronously and trigger the `ReadbackComplete` event on this entity. 
Despawn the entity // to stop reading back the data. // Textures can also be read back from the GPU. Pay careful attention to the format of the // texture, as it will affect how the data is interpreted. - commands.spawn(Readback::texture(image.clone())).observe( - |trigger: Trigger| { - // You probably want to interpret the data as a color rather than a `ShaderType`, - // but in this case we know the data is a single channel storage texture, so we can - // interpret it as a `Vec` - let data: Vec = trigger.event().to_shader_type(); - // info!("Image {:?}", data); - }, - ); + // commands.spawn(Readback::texture(image.clone())); commands.spawn(( // Sprite::from_image(image.clone()), Sprite { - image: image.clone(), + image: ping.clone(), custom_size: Some(Vec2::splat(1000.0)), ..Default::default() }, @@ -180,56 +192,79 @@ fn setup(mut commands: Commands, mut images: ResMut>) { )); // This is just a simple way to pass the image handle to the render app for our compute node - commands.insert_resource(ReadbackImage(image)); + // commands.insert_resource(ReadbackImage(image)); + commands.insert_resource(ReadbackImage { + ping: ping.clone(), + pong: pong, + }); } #[derive(Resource)] -struct GpuBufferBindGroup { - bind_group: BindGroup, +struct GpuBufferBindGroups { + first_pass: BindGroup, + second_pass: BindGroup, uniform_buffer: Buffer, } +// #[derive(Resource)] +// struct GpuBufferBindGroup { +// bind_group: BindGroup, +// uniform_buffer: Buffer, +// } -fn prepare_bind_group( +fn prepare_bind_groups( mut commands: Commands, - pipeline: Res, + pipeline: Res, render_device: Res, image: Res, images: Res>, params: Res, render_queue: Res, ) { - let uniform_buffer = render_device.create_buffer(&BufferDescriptor { label: Some("uniform"), size: std::mem::size_of::() as u64, usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST, mapped_at_creation: false, }); - + render_queue.write_buffer(&uniform_buffer, 0, bytes_of(&*params)); - - let image = images.get(&image.0).unwrap(); - 
let bind_group = render_device.create_bind_group( + + let ping_image = images.get(&image.ping).unwrap(); + let pong_image = images.get(&image.pong).unwrap(); + + let first_pass = render_device.create_bind_group( None, &pipeline.layout, &BindGroupEntries::sequential(( uniform_buffer.as_entire_buffer_binding(), - image.texture_view.into_binding(), + ping_image.texture_view.into_binding(), + pong_image.texture_view.into_binding(), )), ); - commands.insert_resource(GpuBufferBindGroup { - bind_group, + let second_pass = render_device.create_bind_group( + None, + &pipeline.layout, + &BindGroupEntries::sequential(( + uniform_buffer.as_entire_buffer_binding(), + pong_image.texture_view.into_binding(), + ping_image.texture_view.into_binding(), + )), + ); + commands.insert_resource(GpuBufferBindGroups { + first_pass, + second_pass, uniform_buffer, }); } #[derive(Resource)] -struct ComputePipeline { +struct ComputePipelines { layout: BindGroupLayout, - pipeline: CachedComputePipelineId, + first_pass: CachedComputePipelineId, + second_pass: CachedComputePipelineId, } -impl FromWorld for ComputePipeline { +impl FromWorld for ComputePipelines { fn from_world(world: &mut World) -> Self { let render_device = world.resource::(); let layout = render_device.create_bind_group_layout( @@ -238,33 +273,89 @@ impl FromWorld for ComputePipeline { ShaderStages::COMPUTE, ( uniform_buffer::(false), + texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::ReadOnly), texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::WriteOnly), ), ), ); - let shader = world.load_asset(SHADER_ASSET_PATH); + let shader1 = world.load_asset(SHADER1_ASSET_PATH); let pipeline_cache = world.resource::(); - let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { - label: Some("GPU readback compute shader".into()), + let first_pass = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("First pass".into()), + layout: 
vec![layout.clone()], + push_constant_ranges: Vec::new(), + shader: shader1.clone(), + shader_defs: Vec::new(), + entry_point: "main".into(), + zero_initialize_workgroup_memory: false, + }); + let shader2 = world.load_asset(SHADER2_ASSET_PATH); + let second_pass = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { + label: Some("Second pass".into()), layout: vec![layout.clone()], push_constant_ranges: Vec::new(), - shader: shader.clone(), + shader: shader2.clone(), shader_defs: Vec::new(), entry_point: "main".into(), zero_initialize_workgroup_memory: false, }); - ComputePipeline { layout, pipeline } + + ComputePipelines { + layout, + first_pass, + second_pass, + } } } +// #[derive(Resource)] +// struct ComputePipeline { +// layout: BindGroupLayout, +// pipeline: CachedComputePipelineId, +// } + +// impl FromWorld for ComputePipeline { +// fn from_world(world: &mut World) -> Self { +// let render_device = world.resource::(); +// let layout = render_device.create_bind_group_layout( +// None, +// &BindGroupLayoutEntries::sequential( +// ShaderStages::COMPUTE, +// ( +// uniform_buffer::(false), +// texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::WriteOnly), +// texture_storage_2d(TextureFormat::Rgba32Float, StorageTextureAccess::WriteOnly), +// ), +// ), +// ); +// let shader = world.load_asset(SHADER1_ASSET_PATH); + +// let pipeline_cache = world.resource::(); +// let pipeline = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { +// label: Some("GPU readback compute shader".into()), +// layout: vec![layout.clone()], +// push_constant_ranges: Vec::new(), +// shader: shader.clone(), +// shader_defs: Vec::new(), +// entry_point: "main".into(), +// zero_initialize_workgroup_memory: false, +// }); +// ComputePipeline { layout, pipeline } +// } +// } + /// Label to identify the node in the render graph #[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)] -struct ComputeNodeLabel; +struct ComputeNodeLabel1; 
+#[derive(Debug, Hash, PartialEq, Eq, Clone, RenderLabel)] +struct ComputeNodeLabel2; /// The node that will execute the compute shader #[derive(Default)] -struct ComputeNode {} +struct ComputeNode { + pass_index: u32, +} impl render_graph::Node for ComputeNode { fn run( &self, @@ -273,22 +364,38 @@ impl render_graph::Node for ComputeNode { world: &World, ) -> Result<(), render_graph::NodeRunError> { let pipeline_cache = world.resource::(); - let pipeline = world.resource::(); - let bind_group = world.resource::(); - - if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) { - let mut pass = - render_context - .command_encoder() - .begin_compute_pass(&ComputePassDescriptor { - label: Some("GPU readback compute pass"), - ..default() - }); - - pass.set_bind_group(0, &bind_group.bind_group, &[]); - pass.set_pipeline(init_pipeline); + let pipelines = world.resource::(); + let bind_groups = world.resource::(); + + let (pipeline_id, bind_group) = match self.pass_index { + 0 => (pipelines.first_pass, &bind_groups.first_pass), + 1 => (pipelines.second_pass, &bind_groups.second_pass), + _ => return Ok(()), + }; + + if let Some(pipeline) = pipeline_cache.get_compute_pipeline(pipeline_id) { + let mut pass = render_context + .command_encoder() + .begin_compute_pass(&ComputePassDescriptor::default()); + + pass.set_bind_group(0, bind_group, &[]); + pass.set_pipeline(pipeline); pass.dispatch_workgroups(BUFFER_LEN as u32, BUFFER_LEN as u32, 1); } + + // if let Some(init_pipeline) = pipeline_cache.get_compute_pipeline(pipeline.pipeline) { + // let mut pass = + // render_context + // .command_encoder() + // .begin_compute_pass(&ComputePassDescriptor { + // label: Some("GPU readback compute pass"), + // ..default() + // }); + + // pass.set_bind_group(0, &bind_group.bind_group, &[]); + // pass.set_pipeline(init_pipeline); + // pass.dispatch_workgroups(BUFFER_LEN as u32, BUFFER_LEN as u32, 1); + // } Ok(()) } }