diff --git a/bevy_gpu_compute/src/prelude.rs b/bevy_gpu_compute/src/prelude.rs index ff9504e..b36cb11 100644 --- a/bevy_gpu_compute/src/prelude.rs +++ b/bevy_gpu_compute/src/prelude.rs @@ -4,6 +4,7 @@ pub use bevy_gpu_compute_macro::wgsl_input_array; pub use bevy_gpu_compute_macro::wgsl_output_array; pub use bevy_gpu_compute_macro::wgsl_output_vec; pub use bevy_gpu_compute_macro::wgsl_shader_module; +pub use bevy_gpu_compute_macro::wgsl_shader_module_library; //helpers when writing the shader module: pub use bevy_gpu_compute_core::MaxOutputLengths; diff --git a/bevy_gpu_compute_core/src/rust/library_import.rs b/bevy_gpu_compute_core/src/rust/library_import.rs new file mode 100755 index 0000000..56055cd --- /dev/null +++ b/bevy_gpu_compute_core/src/rust/library_import.rs @@ -0,0 +1,9 @@ +use crate::wgsl::shader_module::user_defined_portion::WgslShaderModuleUserPortion; + +pub fn merge_libraries_into_wgsl_module(user_module: &mut WgslShaderModuleUserPortion, library_modules: &mut Vec) { + for library in library_modules.iter_mut() { + user_module.helper_functions.append(&mut library.helper_functions); + user_module.static_consts.append(&mut library.static_consts); + user_module.helper_types.append(&mut library.helper_types); + } +} diff --git a/bevy_gpu_compute_core/src/rust/mod.rs b/bevy_gpu_compute_core/src/rust/mod.rs index 9e19c9d..45bbf44 100644 --- a/bevy_gpu_compute_core/src/rust/mod.rs +++ b/bevy_gpu_compute_core/src/rust/mod.rs @@ -1,5 +1,6 @@ mod in_out_metadata; mod iter_space_dimmensions; +mod library_import; mod max_output_lengths; mod type_erased_array_input_data; mod type_erased_config_input_data; @@ -8,6 +9,7 @@ mod type_safe_api_helpers; pub use in_out_metadata::*; pub use iter_space_dimmensions::*; +pub use library_import::*; pub use max_output_lengths::*; pub use type_erased_array_input_data::*; pub use type_erased_config_input_data::*; diff --git a/bevy_gpu_compute_core/src/wgsl/shader_module/derived_portion.rs b/bevy_gpu_compute_core/src/wgsl/shader_module/derived_portion.rs index 55e9d3b..8158870 100644 --- a/bevy_gpu_compute_core/src/wgsl/shader_module/derived_portion.rs +++ b/bevy_gpu_compute_core/src/wgsl/shader_module/derived_portion.rs @@ -93,7 +93,7 @@ mod tests { #[test] fn test_wgsl_shader_module_library_portion_from_user_portion() { - let user_portion = WgslShaderModuleUserPortion { static_consts: vec![WgslConstAssignment { code: WgslShaderModuleSectionCode { wgsl_code: "const example_module_const : u32 = 42;".to_string() } }], helper_types: vec![], uniforms: vec![WgslType { name: ShaderCustomTypeName::new("Uniforms"), code: WgslShaderModuleSectionCode { wgsl_code: "struct Uniforms { time : f32, resolution : vec2 < f32 > , }".to_string() } }], input_arrays: vec![WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Position"), code: WgslShaderModuleSectionCode { wgsl_code: "alias Position = array < f32, 2 > ;".to_string() } } }, WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Radius") , code: WgslShaderModuleSectionCode { wgsl_code: "alias Radius = f32;".to_string() } }}], output_arrays: vec![WgslOutputArray { item_type: WgslType { name: ShaderCustomTypeName::new("CollisionResult"), code: WgslShaderModuleSectionCode { wgsl_code: "struct CollisionResult { entity1 : u32, entity2 : u32, }".to_string() } }, atomic_counter_name: Some("collisionresult_counter".to_string()) }], helper_functions: vec![WgslFunction { name: "calculate_distance_squared".to_string(), code: WgslShaderModuleSectionCode { wgsl_code: "fn calculate_distance_squared(p1 : array < f32, 2 > , p2 : array < f32, 2 >)\n-> f32\n{\n let dx = p1 [0] - p2 [0]; let dy = p1 [1] - p2 [1]; return dx * dx + dy *\n dy;\n}".to_string() } }], main_function: Some(WgslFunction { name: "main".to_owned(), code: WgslShaderModuleSectionCode { wgsl_code: "fn main(@builtin(global_invocation_id) iter_pos: vec3)\n{\n let current_entity = iter_pos.x; let other_entity = iter_pos.y; if\n current_entity >= POSITION_INPUT_ARRAY_LENGTH || other_entity >=\n POSITION_INPUT_ARRAY_LENGTH || current_entity == other_entity ||\n current_entity >= other_entity { return; } let current_radius =\n radius_input_array [current_entity]; let other_radius = radius_input_array\n [other_entity]; if current_radius <= 0.0 || other_radius <= 0.0\n { return; } let current_pos = position_input_array [current_entity]; let\n other_pos = position_input_array [other_entity]; let dist_squared =\n calculate_distance_squared(current_pos, other_pos); let radius_sum =\n current_radius + other_radius; if dist_squared < radius_sum * radius_sum\n {\n {\n let collisionresult_output_array_index =\n atomicAdd(& collisionresult_counter, 1u); if\n collisionresult_output_array_index <\n COLLISIONRESULT_OUTPUT_ARRAY_LENGTH\n {\n collisionresult_output_array\n [collisionresult_output_array_index] = CollisionResult\n { entity1 : current_entity, entity2 : other_entity, };\n }\n };\n }\n}".to_owned() } }), binding_numbers_by_variable_name: Some(HashMap::from([(String::from("uniforms"), 0), (String::from("position_input_array"), 1), (String::from("radius_input_array"), 2), (String::from("collisionresult_output_array"), 3), (String::from("collisionresult_counter"), 4)])) + let user_portion = WgslShaderModuleUserPortion { static_consts: vec![WgslConstAssignment { code: WgslShaderModuleSectionCode { wgsl_code: "const example_module_const : u32 = 42;".to_string() } }], helper_types: vec![], uniforms: vec![WgslType { name: ShaderCustomTypeName::new("Uniforms"), code: WgslShaderModuleSectionCode { wgsl_code: "struct Uniforms { time : f32, resolution : vec2 < f32 > , }".to_string() } }], input_arrays: vec![WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Position"), code: WgslShaderModuleSectionCode { wgsl_code: "alias Position = array < f32, 2 > ;".to_string() } } }, WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Radius") , code: WgslShaderModuleSectionCode { wgsl_code: "alias Radius = f32;".to_string() } }}], output_arrays: vec![WgslOutputArray { item_type: WgslType { name: ShaderCustomTypeName::new("CollisionResult"), code: WgslShaderModuleSectionCode { wgsl_code: "struct CollisionResult { entity1 : u32, entity2 : u32, }".to_string() } }, atomic_counter_name: Some("collisionresult_counter".to_string()) }], helper_functions: vec![WgslFunction { name: "calculate_distance_squared".to_string(), code: WgslShaderModuleSectionCode { wgsl_code: "fn calculate_distance_squared(p1 : array < f32, 2 > , p2 : array < f32, 2 >)\n-> f32\n{\n let dx = p1 [0] - p2 [0]; let dy = p1 [1] - p2 [1]; return dx * dx + dy *\n dy;\n}".to_string() } }], main_function: Some(WgslFunction { name: "main".to_owned(), code: WgslShaderModuleSectionCode { wgsl_code: "fn main(@builtin(global_invocation_id) iter_pos: vec3)\n{\n let current_entity = iter_pos.x; let other_entity = iter_pos.y; if\n current_entity >= POSITION_INPUT_ARRAY_LENGTH || other_entity >=\n POSITION_INPUT_ARRAY_LENGTH || current_entity == other_entity ||\n current_entity >= other_entity { return; } let current_radius =\n radius_input_array [current_entity]; let other_radius = radius_input_array\n [other_entity]; if current_radius <= 0.0 || other_radius <= 0.0\n { return; } let current_pos = position_input_array [current_entity]; let\n other_pos = position_input_array [other_entity]; let dist_squared =\n calculate_distance_squared(current_pos, other_pos); let radius_sum =\n current_radius + other_radius; if dist_squared < radius_sum * radius_sum\n {\n {\n let collisionresult_output_array_index =\n atomicAdd(& collisionresult_counter, 1u); if\n collisionresult_output_array_index <\n COLLISIONRESULT_OUTPUT_ARRAY_LENGTH\n {\n collisionresult_output_array\n [collisionresult_output_array_index] = CollisionResult\n { entity1 : current_entity, entity2 : other_entity, };\n }\n };\n }\n}".to_owned() } }), binding_numbers_by_variable_name: Some(HashMap::from([(String::from("uniforms"), 0), (String::from("position_input_array"), 1), (String::from("radius_input_array"), 2), (String::from("collisionresult_output_array"), 3), (String::from("collisionresult_counter"), 4)])), use_statements: vec![], }; let expected_wgsl_code = "const example_module_const : u32 = 42; diff --git a/bevy_gpu_compute_core/src/wgsl/shader_module/user_defined_portion.rs b/bevy_gpu_compute_core/src/wgsl/shader_module/user_defined_portion.rs index da7b220..c9e0704 100644 --- a/bevy_gpu_compute_core/src/wgsl/shader_module/user_defined_portion.rs +++ b/bevy_gpu_compute_core/src/wgsl/shader_module/user_defined_portion.rs @@ -26,6 +26,7 @@ pub struct WgslShaderModuleUserPortion { /// look for any attempt to ASSIGN to the value of "global_id.x", "global_id.y", or "global_id.z" or just "global_id" and throw an error pub main_function: Option, pub binding_numbers_by_variable_name: Option>, + pub use_statements: Vec, } impl WgslShaderModuleUserPortion { pub fn empty() -> Self { @@ -38,6 +39,7 @@ impl WgslShaderModuleUserPortion { helper_functions: vec![], main_function: None, binding_numbers_by_variable_name: None, + use_statements: vec![], } } } diff --git a/bevy_gpu_compute_core/src/wgsl/shader_sections/import.rs b/bevy_gpu_compute_core/src/wgsl/shader_sections/import.rs new file mode 100755 index 0000000..80e6c3f --- /dev/null +++ b/bevy_gpu_compute_core/src/wgsl/shader_sections/import.rs @@ -0,0 +1,4 @@ +#[derive(Clone, Debug, PartialEq)] +pub struct WgslImport { + pub path: String, +} diff --git a/bevy_gpu_compute_core/src/wgsl/shader_sections/mod.rs b/bevy_gpu_compute_core/src/wgsl/shader_sections/mod.rs index a7c26cf..15fb823 100644 --- a/bevy_gpu_compute_core/src/wgsl/shader_sections/mod.rs +++ b/bevy_gpu_compute_core/src/wgsl/shader_sections/mod.rs @@ -2,6 +2,7 @@ mod code; mod const_assignment; mod custom_type; mod function; +mod import; mod input_array; mod output_array; mod wgpu_binding; @@ -11,6 +12,7 @@ pub use code::*; pub use const_assignment::*; pub use custom_type::*; pub use function::*; +pub use import::*; pub use input_array::*; pub use output_array::*; pub use wgpu_binding::*; diff --git a/bevy_gpu_compute_macro/src/lib.rs b/bevy_gpu_compute_macro/src/lib.rs index 1e82085..5edfc0e 100644 --- a/bevy_gpu_compute_macro/src/lib.rs +++ b/bevy_gpu_compute_macro/src/lib.rs @@ -60,7 +60,16 @@ pub fn wgsl_shader_module(_attr: TokenStream, item: TokenStream) -> TokenStream set_dummy(item.clone().into()); let module = parse_macro_input!(item as syn::ItemMod); let compiler_pipeline = CompilerPipeline::default(); - compiler_pipeline.compile(module).into() + compiler_pipeline.compile(module, true).into() +} + +#[proc_macro_attribute] +#[proc_macro_error] +pub fn wgsl_shader_module_library(_attr: TokenStream, item: TokenStream) -> TokenStream { + set_dummy(item.clone().into()); + let module = parse_macro_input!(item as syn::ItemMod); + let compiler_pipeline = CompilerPipeline::default(); + compiler_pipeline.compile(module, false).into() } /// used to help this library figure out what to do with user-defined types diff --git a/bevy_gpu_compute_macro/src/pipeline/compilation_metadata.rs b/bevy_gpu_compute_macro/src/pipeline/compilation_metadata.rs index 8135334..9790b99 100644 --- a/bevy_gpu_compute_macro/src/pipeline/compilation_metadata.rs +++ b/bevy_gpu_compute_macro/src/pipeline/compilation_metadata.rs @@ -1,8 +1,12 @@ -use crate::pipeline::phases::custom_type_collector::custom_type::CustomType; +use crate::pipeline::phases::{ + custom_type_collector::custom_type::CustomType, user_import_collector::user_import::UserImport, +}; use bevy_gpu_compute_core::wgsl::shader_module::user_defined_portion::WgslShaderModuleUserPortion; use proc_macro2::TokenStream; pub struct CompilationMetadata { + pub user_imports: Option>, + pub main_func_required: bool, pub custom_types: Option>, pub wgsl_module_user_portion: Option, pub typesafe_buffer_builders: Option, diff --git a/bevy_gpu_compute_macro/src/pipeline/compilation_unit.rs b/bevy_gpu_compute_macro/src/pipeline/compilation_unit.rs index 90cc65a..79abc61 100644 --- a/bevy_gpu_compute_macro/src/pipeline/compilation_unit.rs +++ b/bevy_gpu_compute_macro/src/pipeline/compilation_unit.rs @@ -3,7 +3,10 @@ use proc_macro2::TokenStream; use super::{ compilation_metadata::CompilationMetadata, - phases::custom_type_collector::custom_type::CustomType, + phases::{ + custom_type_collector::custom_type::CustomType, + user_import_collector::user_import::UserImport, + }, }; pub struct CompilationUnit { @@ -15,19 +18,24 @@ pub struct CompilationUnit { } impl CompilationUnit { - pub fn new(original_rust_module: syn::ItemMod) -> Self { + pub fn new(original_rust_module: syn::ItemMod, main_func_required: bool) -> Self { CompilationUnit { original_rust_module, rust_module_for_cpu: None, rust_module_for_gpu: None, compiled_tokens: None, metadata: CompilationMetadata { + user_imports: None, custom_types: None, wgsl_module_user_portion: None, typesafe_buffer_builders: None, + main_func_required, }, } } + pub fn main_func_required(&self) -> bool { + self.metadata.main_func_required + } pub fn rust_module_for_gpu(&self) -> &syn::ItemMod { if self.rust_module_for_gpu.is_none() { panic!("rust_module_for_gpu is not set"); @@ -40,6 +48,15 @@ impl CompilationUnit { } self.rust_module_for_cpu.as_ref().unwrap() } + pub fn set_user_imports(&mut self, user_imports: Vec) { + self.metadata.user_imports = Some(user_imports); + } + pub fn user_imports(&self) -> &Vec { + if self.metadata.user_imports.is_none() { + panic!("user_imports is not set"); + } + self.metadata.user_imports.as_ref().unwrap() + } pub fn set_custom_types(&mut self, custom_types: Vec) { self.metadata.custom_types = Some(custom_types); } diff --git a/bevy_gpu_compute_macro/src/pipeline/lib.rs b/bevy_gpu_compute_macro/src/pipeline/lib.rs index 446830f..188754e 100644 --- a/bevy_gpu_compute_macro/src/pipeline/lib.rs +++ b/bevy_gpu_compute_macro/src/pipeline/lib.rs @@ -7,6 +7,7 @@ use super::phases::{ module_for_rust_usage_cleaner::compiler_phase::ModuleForRustUsageCleaner, non_mutating_tree_validation::compiler_phase::NonMutatingTreeValidation, typesafe_buffer_builders_generator::compiler_phase::TypesafeBufferBuildersGenerator, + user_import_collector::compiler_phase::UserImportCollector, wgsl_helper_transformer::compiler_phase::WgslHelperTransformer, }; use crate::pipeline::compilation_unit::CompilationUnit; @@ -20,6 +21,7 @@ impl Default for CompilerPipeline { Self { phases: vec![ Box::new(NonMutatingTreeValidation {}), + Box::new(UserImportCollector {}), Box::new(CustomTypeCollector {}), Box::new(TypesafeBufferBuildersGenerator {}), Box::new(WgslHelperTransformer {}), @@ -31,8 +33,8 @@ impl Default for CompilerPipeline { } } impl CompilerPipeline { - pub fn compile(&self, module: syn::ItemMod) -> TokenStream { - let mut unit = CompilationUnit::new(module); + pub fn compile(&self, module: syn::ItemMod, main_func_required: bool) -> TokenStream { + let mut unit = CompilationUnit::new(module, main_func_required); for phase in &self.phases { phase.execute(&mut unit); } diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/generate_required_imports.rs b/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/generate_required_imports.rs index b80f3e1..9b1f51f 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/generate_required_imports.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/generate_required_imports.rs @@ -3,6 +3,7 @@ use quote::quote; pub fn generate_required_imports() -> TokenStream { quote! { + use super::*; use bevy_gpu_compute_core::wgsl::shader_sections::*; //todo, make this less brittle, how? use bevy_gpu_compute_core::wgsl::shader_custom_type_name::*; use bevy_gpu_compute_core::wgsl_helpers::*; diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/per_component_expansion.rs b/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/per_component_expansion.rs index 5ee3bd4..b63254f 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/per_component_expansion.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/per_component_expansion.rs @@ -1,9 +1,9 @@ use std::collections::HashMap; -use bevy_gpu_compute_core::{ - wgsl::shader_custom_type_name::ShaderCustomTypeName, - wgsl::shader_sections::{ - WgslConstAssignment, WgslFunction, WgslInputArray, WgslOutputArray, +use bevy_gpu_compute_core::wgsl::{ + shader_custom_type_name::ShaderCustomTypeName, + shader_sections::{ + WgslConstAssignment, WgslFunction, WgslImport, WgslInputArray, WgslOutputArray, WgslShaderModuleSectionCode, WgslType, }, }; @@ -88,6 +88,12 @@ impl ToStructInitializer { } ) } + + pub fn wgsl_import(c: &WgslImport) -> TokenStream { + let i: TokenStream = c.path.parse().unwrap(); + quote!(#i) + } + pub fn hash_map(c: &HashMap) -> TokenStream { let entries: TokenStream = c .iter() diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/shader_module_object.rs b/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/shader_module_object.rs index 18f76bd..9d34121 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/shader_module_object.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/shader_module_object.rs @@ -76,9 +76,18 @@ pub fn generate_shader_module_object( .unwrap(), ); + let library_modules: TokenStream = wgsl_shader_module + .use_statements + .iter() + .map(|use_statement| { + let ts = ToStructInitializer::wgsl_import(use_statement); + quote!(#ts,) + }) + .collect(); + quote!( pub fn parsed() -> WgslShaderModuleUserPortion { - WgslShaderModuleUserPortion { + let mut user_portion = WgslShaderModuleUserPortion { static_consts: [ #static_consts ] @@ -104,15 +113,20 @@ pub fn generate_shader_module_object( .into(), main_function: #main_function, binding_numbers_by_variable_name: Some(#bindings_map), - } + use_statements: [].into(), + }; + merge_libraries_into_wgsl_module(&mut user_portion, &mut [ + #library_modules + ].into()); + user_portion } ) } #[cfg(test)] mod test { - use proc_macro_error::abort; use proc_macro2::{Span, TokenStream}; + use proc_macro_error::abort; #[test] pub fn test_parse_str() { diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/unaltered_module.rs b/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/unaltered_module.rs index 4b9801f..1d4711e 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/unaltered_module.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/final_structure_generator/unaltered_module.rs @@ -24,6 +24,7 @@ pub fn generate_unaltered_module(original_module: &ItemMod) -> TokenStream { quote! { #[allow(dead_code, unused_variables, unused_imports)] mod #new_ident { + use super::*; #content_combined } } diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/compiler_phase.rs b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/compiler_phase.rs index 8d9ebbc..c4662a9 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/compiler_phase.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/compiler_phase.rs @@ -6,8 +6,12 @@ pub struct GpuResourceMngmntAndWgslGenerator; impl CompilerPhase for GpuResourceMngmntAndWgslGenerator { fn execute(&self, input: &mut CompilationUnit) { - let (shader_module, custom_types) = - parse_shader_module_for_gpu(input.rust_module_for_gpu(), input.custom_types()); + let (shader_module, custom_types) = parse_shader_module_for_gpu( + input.rust_module_for_gpu(), + input.custom_types(), + input.main_func_required(), + input.user_imports(), + ); input.set_wgsl_module_user_portion(shader_module); input.set_custom_types(custom_types); } diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/imports.rs b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/imports.rs new file mode 100755 index 0000000..5f3370d --- /dev/null +++ b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/imports.rs @@ -0,0 +1,21 @@ +use bevy_gpu_compute_core::wgsl::shader_sections::WgslImport; + +use crate::pipeline::phases::user_import_collector::user_import::UserImport; + +pub fn generate_user_imports_for_wgsl_module_def( + user_imports: &Vec, +) -> Vec { + let mut out = vec![]; + for import in user_imports.iter() { + let mut segments: Vec = import.path.iter().map(|ident| ident.to_string()).collect(); + segments.push("parsed()".to_string()); + let mut path = segments.join("::"); + + if import.has_leading_colon { + path = format!("::{path}"); + } + + out.push(WgslImport { path }); + } + out +} diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/lib.rs b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/lib.rs index 6f91d68..a4e74b2 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/lib.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/lib.rs @@ -1,26 +1,33 @@ use bevy_gpu_compute_core::wgsl::shader_module::user_defined_portion::WgslShaderModuleUserPortion; use crate::pipeline::phases::custom_type_collector::custom_type::CustomType; +use crate::pipeline::phases::user_import_collector::user_import::UserImport; use super::constants::extract_constants; use super::divide_custom_types::generate_helper_types_inputs_and_outputs_for_wgsl_module_def; use super::helper_functions::extract_helper_functions; +use super::imports::generate_user_imports_for_wgsl_module_def; use super::main_function::parse_main_function; /// This will also change custom_types pub fn parse_shader_module_for_gpu( rust_module_transformed_for_gpu: &syn::ItemMod, custom_types: &Vec, + main_func_required: bool, + user_imports: &Vec, ) -> (WgslShaderModuleUserPortion, Vec) { let mut out_module: WgslShaderModuleUserPortion = WgslShaderModuleUserPortion::empty(); - out_module.main_function = Some(parse_main_function( - rust_module_transformed_for_gpu, - custom_types, - )); + if main_func_required { + out_module.main_function = Some(parse_main_function( + rust_module_transformed_for_gpu, + custom_types, + )); + } out_module.static_consts = extract_constants(rust_module_transformed_for_gpu, custom_types); out_module.helper_functions = extract_helper_functions(rust_module_transformed_for_gpu, custom_types); let new_custom_types = generate_helper_types_inputs_and_outputs_for_wgsl_module_def(custom_types, &mut out_module); + out_module.use_statements = generate_user_imports_for_wgsl_module_def(user_imports); (out_module, new_custom_types) } diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/mod.rs b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/mod.rs index 35c096a..f8842c0 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/mod.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/mod.rs @@ -2,6 +2,7 @@ pub mod compiler_phase; mod constants; mod divide_custom_types; mod helper_functions; +mod imports; mod lib; mod main_function; pub mod to_wgsl_syntax; diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/mod.rs b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/mod.rs index 8e290d9..ff78700 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/mod.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/mod.rs @@ -9,6 +9,7 @@ use proc_macro2::{Span, TokenStream}; use quote::ToTokens; use remove_attributes::remove_attributes; use remove_pub_from_struct_def::PubRemover; +use remove_use_stmts::UseStmtRemover; use syn::{File, parse, visit::Visit, visit_mut::VisitMut}; use r#type::TypeToWgslTransformer; use type_def::TypeDefToWgslTransformer; @@ -64,6 +65,7 @@ mod implicit_to_explicit_return; mod local_var; pub mod remove_attributes; mod remove_pub_from_struct_def; +mod remove_use_stmts; mod r#type; mod type_def; mod wgsl_builtin_constructors; @@ -87,7 +89,7 @@ pub fn convert_file_to_wgsl( ); abort!(Span::call_site(), message); }; - + UseStmtRemover {}.visit_file_mut(&mut file); PubRemover {}.visit_file_mut(&mut file); TypeToWgslTransformer { custom_types }.visit_file_mut(&mut file); ArrayToWgslTransformer {}.visit_file_mut(&mut file); diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/remove_pub_from_struct_def.rs b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/remove_pub_from_struct_def.rs index f9da5e1..7eb84fa 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/remove_pub_from_struct_def.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/remove_pub_from_struct_def.rs @@ -12,4 +12,19 @@ impl VisitMut for PubRemover { syn::visit_mut::visit_field_mut(self, i); i.vis = Visibility::Inherited; } + + fn visit_item_type_mut(&mut self, i: &mut syn::ItemType) { + syn::visit_mut::visit_item_type_mut(self, i); + i.vis = Visibility::Inherited; + } + + fn visit_item_struct_mut(&mut self, i: &mut syn::ItemStruct) { + syn::visit_mut::visit_item_struct_mut(self, i); + i.vis = Visibility::Inherited; + } + + fn visit_item_const_mut(&mut self, i: &mut syn::ItemConst) { + syn::visit_mut::visit_item_const_mut(self, i); + i.vis = Visibility::Inherited; + } } diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/remove_use_stmts.rs b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/remove_use_stmts.rs new file mode 100644 index 0000000..54b80fa --- /dev/null +++ b/bevy_gpu_compute_macro/src/pipeline/phases/gpu_resource_mngmnt_and_wgsl_generator/to_wgsl_syntax/remove_use_stmts.rs @@ -0,0 +1,14 @@ +pub struct UseStmtRemover {} + +use quote::quote; +use syn::{Item, visit_mut::VisitMut}; + +impl VisitMut for UseStmtRemover { + fn visit_item_mut(&mut self, i: &mut Item) { + syn::visit_mut::visit_item_mut(self, i); + if let Item::Use(use_stmt) = i { + // remove the use statement + *i = Item::Verbatim(quote! {}) + } + } +} diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/mod.rs b/bevy_gpu_compute_macro/src/pipeline/phases/mod.rs index 47c51d0..a27b1d0 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/mod.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/mod.rs @@ -5,4 +5,5 @@ pub mod gpu_resource_mngmnt_and_wgsl_generator; pub mod module_for_rust_usage_cleaner; pub mod non_mutating_tree_validation; pub mod typesafe_buffer_builders_generator; +pub mod user_import_collector; pub mod wgsl_helper_transformer; diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/module_for_rust_usage_cleaner/compiler_phase.rs b/bevy_gpu_compute_macro/src/pipeline/phases/module_for_rust_usage_cleaner/compiler_phase.rs index e14573d..1ab7bdb 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/module_for_rust_usage_cleaner/compiler_phase.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/module_for_rust_usage_cleaner/compiler_phase.rs @@ -12,7 +12,9 @@ pub struct ModuleForRustUsageCleaner; impl CompilerPhase for ModuleForRustUsageCleaner { fn execute(&self, input: &mut CompilationUnit) { let mut m = input.rust_module_for_cpu().clone(); - mutate_main_function_for_cpu_usage(input.wgsl_module_user_portion(), &mut m); + if input.main_func_required() { + mutate_main_function_for_cpu_usage(input.wgsl_module_user_portion(), &mut m); + } remove_internal_attributes(&mut m); make_types_pod(&mut m); make_types_public(&mut m); diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/compiler_phase.rs b/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/compiler_phase.rs index 32151e6..88dad6c 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/compiler_phase.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/compiler_phase.rs @@ -2,7 +2,6 @@ use crate::pipeline::{compilation_unit::CompilationUnit, phases::compiler_phase: use super::validate_no_doc_comments::validate_no_doc_comments; use super::validate_no_iter_pos_assignments::validate_no_iter_pos_assignments; -use super::validate_use_statements::validate_use_statements; /// any sort of input validation that can be done on the original tree that doesn't require mutation pub struct NonMutatingTreeValidation; @@ -11,6 +10,5 @@ impl CompilerPhase for NonMutatingTreeValidation { fn execute(&self, input: &mut CompilationUnit) { validate_no_doc_comments(input.original_rust_module()); validate_no_iter_pos_assignments(input.original_rust_module()); - validate_use_statements(input.original_rust_module()); } } diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/mod.rs b/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/mod.rs index 9577b8f..7e62487 100644 --- a/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/mod.rs +++ b/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/mod.rs @@ -1,4 +1,3 @@ pub mod compiler_phase; mod validate_no_doc_comments; mod validate_no_iter_pos_assignments; -mod validate_use_statements; diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/validate_use_statements.rs b/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/validate_use_statements.rs deleted file mode 100644 index b26f49b..0000000 --- a/bevy_gpu_compute_macro/src/pipeline/phases/non_mutating_tree_validation/validate_use_statements.rs +++ /dev/null @@ -1,54 +0,0 @@ -use proc_macro_error::abort; -use quote::ToTokens; -use syn::{Item, ItemMod, ItemUse, spanned::Spanned, visit::Visit}; - -const VALID_USE_STATEMENT_PATHS: [&str; 3] = - ["wgsl_helpers", "bevy_gpu_compute", "bevy_gpu_compute_macro"]; - -pub fn validate_use_statements(original_rust_module: &ItemMod) { - let mut handler = UseStatementHandler {}; - handler.visit_item_mod(original_rust_module); -} - -struct UseStatementHandler {} - -impl Visit<'_> for UseStatementHandler { - fn visit_item(&mut self, i: &Item) { - syn::visit::visit_item(self, i); - if let Item::Use(use_stmt) = i { - validate_use_statement(use_stmt); - } - } -} - -fn validate_use_statement(use_stmt: &ItemUse) { - let mut single_handler = SingleUseStatementHandler { found: false }; - single_handler.visit_item_use(use_stmt); - if !single_handler.found { - let message = format!( - "Invalid use statement: {:?}. You are only allowed to import from one of these crates: {}", - use_stmt.to_token_stream().to_string(), - VALID_USE_STATEMENT_PATHS.join(", ") - ); - abort!(use_stmt.span(), message); - } -} - -struct SingleUseStatementHandler { - found: bool, -} - -impl Visit<'_> for SingleUseStatementHandler { - fn visit_use_path(&mut self, i: &syn::UsePath) { - syn::visit::visit_use_path(self, i); - if VALID_USE_STATEMENT_PATHS.contains(&i.ident.to_string().as_str()) { - self.found = true; - } - } - fn visit_use_name(&mut self, i: &syn::UseName) { - syn::visit::visit_use_name(self, i); - if VALID_USE_STATEMENT_PATHS.contains(&i.ident.to_string().as_str()) { - self.found = true; - } - } -} diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/collect.rs b/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/collect.rs new file mode 100644 index 0000000..b4ef535 --- /dev/null +++ b/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/collect.rs @@ -0,0 +1,90 @@ +// find all user declared imports, and make a list of them + +use syn::visit::Visit; + +use super::user_import::UserImport; + +#[derive(Default)] +struct UserImportCollector { + use_statements: Vec, +} + +const BUILT_IN_USE_STATEMENT_PATHS: [&str; 3] = + ["wgsl_helpers", "bevy_gpu_compute", "bevy_gpu_compute_macro"]; + +impl<'ast> Visit<'ast> for UserImportCollector { + fn visit_item_use(&mut self, i: &'ast syn::ItemUse) { + syn::visit::visit_item_use(self, i); + let mut built_in_handler = BuiltInUseStatementHandler::default(); + built_in_handler.visit_item_use(i); + + if built_in_handler.found { + return; + } + + let leading_colon = i.leading_colon.is_some(); + let path = traverse_use_tree(&i.tree); + self.use_statements + .push(UserImport::new(leading_colon, path)); + } +} + +#[derive(Default)] +struct BuiltInUseStatementHandler { + found: bool, +} + +impl Visit<'_> for BuiltInUseStatementHandler { + fn visit_use_path(&mut self, i: &syn::UsePath) { + syn::visit::visit_use_path(self, i); + if BUILT_IN_USE_STATEMENT_PATHS.contains(&i.ident.to_string().as_str()) { + self.found = true; + } + } + fn visit_use_name(&mut self, i: &syn::UseName) { + syn::visit::visit_use_name(self, i); + if BUILT_IN_USE_STATEMENT_PATHS.contains(&i.ident.to_string().as_str()) { + self.found = true; + } + } +} + +fn traverse_use_tree(use_tree: &syn::UseTree) -> Vec { + let mut handler = UseTreeHandler::default(); + handler.visit_use_tree(use_tree); + + if !handler.is_glob { + panic!("Only use globs are allowed (e.g. `use foo::bar::*;`)"); + } + + handler.path +} + +#[derive(Default)] +struct UseTreeHandler { + is_glob: bool, + path: Vec, +} + +impl<'ast> Visit<'ast> for UseTreeHandler { + fn visit_use_tree(&mut self, i: &'ast syn::UseTree) { + match i { + syn::UseTree::Path(path) => self.path.push(path.ident.clone()), + syn::UseTree::Name(_) => panic!("Only use globs are allowed (e.g. `use foo::bar::*;`)"), + syn::UseTree::Rename(_) => { + panic!("Use renames are unsupported (e.g. `use foo as bar;`)") + } + syn::UseTree::Glob(_) => self.is_glob = true, + syn::UseTree::Group(_) => { + panic!("Use groups are unsupported (e.g. `use foo::{{bar, baz}};`)") + } + }; + syn::visit::visit_use_tree(self, i); + } +} + +pub fn collect_user_imports(original_rust_module: &syn::ItemMod) -> Vec { + let mut import_collector = UserImportCollector::default(); + import_collector.visit_item_mod(original_rust_module); + import_collector.use_statements +} diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/compiler_phase.rs b/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/compiler_phase.rs new file mode 100644 index 0000000..881b937 --- /dev/null +++ b/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/compiler_phase.rs @@ -0,0 +1,11 @@ +use crate::pipeline::phases::user_import_collector::collect::collect_user_imports; +use crate::pipeline::{compilation_unit::CompilationUnit, phases::compiler_phase::CompilerPhase}; + +pub struct UserImportCollector; + +impl CompilerPhase for UserImportCollector { + fn execute(&self, input: &mut CompilationUnit) { + let user_imports = collect_user_imports(input.original_rust_module()); + input.set_user_imports(user_imports); + } +} diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/mod.rs b/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/mod.rs new file mode 100755 index 0000000..8fd5e49 --- /dev/null +++ b/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/mod.rs @@ -0,0 +1,3 @@ +pub mod collect; +pub mod compiler_phase; +pub mod user_import; diff --git a/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/user_import.rs b/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/user_import.rs new file mode 100644 index 0000000..7fc0da9 --- /dev/null +++ b/bevy_gpu_compute_macro/src/pipeline/phases/user_import_collector/user_import.rs @@ -0,0 +1,13 @@ +pub struct UserImport { + pub has_leading_colon: bool, + pub path: Vec, +} + +impl UserImport { + pub fn new(has_leading_colon: bool, path: Vec) -> Self { + Self { + has_leading_colon, + path, + } + } +} diff --git a/bevy_gpu_compute_macro/tests/components.rs b/bevy_gpu_compute_macro/tests/components.rs index c0a3615..f7e98ee 100644 --- a/bevy_gpu_compute_macro/tests/components.rs +++ b/bevy_gpu_compute_macro/tests/components.rs @@ -510,7 +510,7 @@ fn test_entire_collision_shader() { } let t2 = collision_shader::parsed(); - let user_portion = WgslShaderModuleUserPortion { static_consts: vec![WgslConstAssignment { code: WgslShaderModuleSectionCode { wgsl_code: "const EXAMPLE_MODULE_CONST : u32 = 42;".to_string() } }], helper_types: vec![], uniforms: vec![WgslType { name: ShaderCustomTypeName::new("Uniforms"), code: WgslShaderModuleSectionCode { wgsl_code: "struct Uniforms { time : f32, resolution : vec2 < f32 > , }".to_string() } }], input_arrays: vec![WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Position"), code: WgslShaderModuleSectionCode { wgsl_code: "alias Position = array < f32, 2 > ;".to_string() } } }, WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Radius") , code: WgslShaderModuleSectionCode { wgsl_code: "alias Radius = f32;".to_string() } } }], output_arrays: vec![WgslOutputArray { item_type: WgslType { name: ShaderCustomTypeName::new("CollisionResult"), code: WgslShaderModuleSectionCode { wgsl_code: "struct CollisionResult { entity1 : u32, entity2 : u32, }".to_string() } }, atomic_counter_name: Some("collisionresult_counter".to_string()) }], helper_functions: vec![WgslFunction { name: "calculate_distance_squared".to_string(), code: WgslShaderModuleSectionCode { wgsl_code: "fn calculate_distance_squared(p1 : array < f32, 2 > , p2 : array < f32, 2 >)\n-> f32\n{\n let dx = p1 [0] - p2 [0]; let dy = p1 [1] - p2 [1]; return dx * dx + dy *\n dy;\n}".to_string() } }], main_function: Some(WgslFunction { name: "main".to_owned(), code: WgslShaderModuleSectionCode { wgsl_code: "fn main(@builtin(global_invocation_id) iter_pos: vec3)\n{\n let current_entity = iter_pos.x; let other_entity = iter_pos.y; if\n current_entity >= POSITION_INPUT_ARRAY_LENGTH || other_entity >=\n POSITION_INPUT_ARRAY_LENGTH || current_entity == other_entity ||\n current_entity >= other_entity { return; } let current_radius =\n radius_input_array [current_entity]; let other_radius = radius_input_array\n [other_entity]; if current_radius <= 0.0 || other_radius <= 0.0\n { return; } let current_pos = position_input_array [current_entity]; let\n other_pos = position_input_array [other_entity]; let dist_squared =\n calculate_distance_squared(current_pos, other_pos); let radius_sum =\n current_radius + other_radius; if dist_squared < radius_sum * radius_sum\n {\n {\n let collisionresult_output_array_index =\n atomicAdd(& collisionresult_counter, 1u); if\n collisionresult_output_array_index <\n COLLISIONRESULT_OUTPUT_ARRAY_LENGTH\n {\n collisionresult_output_array\n [collisionresult_output_array_index] = + let user_portion = WgslShaderModuleUserPortion { static_consts: vec![WgslConstAssignment { code: WgslShaderModuleSectionCode { wgsl_code: "const EXAMPLE_MODULE_CONST : u32 = 42;".to_string() } }], helper_types: vec![], uniforms: vec![WgslType { name: ShaderCustomTypeName::new("Uniforms"), code: WgslShaderModuleSectionCode { wgsl_code: "struct Uniforms { time : f32, resolution : vec2 < f32 > , }".to_string() } }], input_arrays: vec![WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Position"), code: WgslShaderModuleSectionCode { wgsl_code: "alias Position = array < f32, 2 > ;".to_string() } } }, WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Radius") , code: WgslShaderModuleSectionCode { wgsl_code: "alias Radius = f32;".to_string() } } }], output_arrays: vec![WgslOutputArray { item_type: WgslType { name: ShaderCustomTypeName::new("CollisionResult"), code: WgslShaderModuleSectionCode { wgsl_code: "struct CollisionResult { entity1 : u32, entity2 : u32, }".to_string() } }, atomic_counter_name: Some("collisionresult_counter".to_string()) }], helper_functions: vec![WgslFunction { name: "calculate_distance_squared".to_string(), code: WgslShaderModuleSectionCode { wgsl_code: "fn calculate_distance_squared(p1 : array < f32, 2 > , p2 : array < f32, 2 >)\n-> f32\n{\n let dx = p1 [0] - p2 [0]; let dy = p1 [1] - p2 [1]; return dx * dx + dy *\n dy;\n}".to_string() } }], use_statements: vec![], main_function: Some(WgslFunction { name: "main".to_owned(), code: WgslShaderModuleSectionCode { wgsl_code: "fn main(@builtin(global_invocation_id) iter_pos: vec3)\n{\n let current_entity = iter_pos.x; let other_entity = iter_pos.y; if\n current_entity >= POSITION_INPUT_ARRAY_LENGTH || other_entity >=\n POSITION_INPUT_ARRAY_LENGTH || current_entity == other_entity ||\n current_entity >= other_entity { return; } let current_radius =\n radius_input_array [current_entity]; let other_radius = radius_input_array\n [other_entity]; if current_radius <= 0.0 || other_radius <= 0.0\n { return; } let current_pos = position_input_array [current_entity]; let\n other_pos = position_input_array [other_entity]; let dist_squared =\n calculate_distance_squared(current_pos, other_pos); let radius_sum =\n current_radius + other_radius; if dist_squared < radius_sum * radius_sum\n {\n {\n let collisionresult_output_array_index =\n atomicAdd(& collisionresult_counter, 1u); if\n collisionresult_output_array_index <\n COLLISIONRESULT_OUTPUT_ARRAY_LENGTH\n {\n collisionresult_output_array\n [collisionresult_output_array_index] = CollisionResult(current_entity, other_entity);\n }\n };\n }\n}".to_owned() } }), binding_numbers_by_variable_name: Some(HashMap::from([ ("uniforms".to_string(), 1),