Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bevy_gpu_compute/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub use bevy_gpu_compute_macro::wgsl_input_array;
pub use bevy_gpu_compute_macro::wgsl_output_array;
pub use bevy_gpu_compute_macro::wgsl_output_vec;
pub use bevy_gpu_compute_macro::wgsl_shader_module;
pub use bevy_gpu_compute_macro::wgsl_shader_module_library;

//helpers when writing the shader module:
pub use bevy_gpu_compute_core::MaxOutputLengths;
Expand Down
9 changes: 9 additions & 0 deletions bevy_gpu_compute_core/src/rust/library_import.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use crate::wgsl::shader_module::user_defined_portion::WgslShaderModuleUserPortion;

pub fn merge_libraries_into_wgsl_module(user_module: &mut WgslShaderModuleUserPortion, library_modules: &mut Vec<WgslShaderModuleUserPortion>) {
for library in library_modules.iter_mut() {
user_module.helper_functions.append(&mut library.helper_functions);
user_module.static_consts.append(&mut library.static_consts);
user_module.helper_types.append(&mut library.helper_types);
}
}
2 changes: 2 additions & 0 deletions bevy_gpu_compute_core/src/rust/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod in_out_metadata;
mod iter_space_dimmensions;
mod library_import;
mod max_output_lengths;
mod type_erased_array_input_data;
mod type_erased_config_input_data;
Expand All @@ -8,6 +9,7 @@ mod type_safe_api_helpers;

pub use in_out_metadata::*;
pub use iter_space_dimmensions::*;
pub use library_import::*;
pub use max_output_lengths::*;
pub use type_erased_array_input_data::*;
pub use type_erased_config_input_data::*;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ mod tests {

#[test]
fn test_wgsl_shader_module_library_portion_from_user_portion() {
let user_portion = WgslShaderModuleUserPortion { static_consts: vec![WgslConstAssignment { code: WgslShaderModuleSectionCode { wgsl_code: "const example_module_const : u32 = 42;".to_string() } }], helper_types: vec![], uniforms: vec![WgslType { name: ShaderCustomTypeName::new("Uniforms"), code: WgslShaderModuleSectionCode { wgsl_code: "struct Uniforms { time : f32, resolution : vec2 < f32 > , }".to_string() } }], input_arrays: vec![WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Position"), code: WgslShaderModuleSectionCode { wgsl_code: "alias Position = array < f32, 2 > ;".to_string() } } }, WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Radius") , code: WgslShaderModuleSectionCode { wgsl_code: "alias Radius = f32;".to_string() } }}], output_arrays: vec![WgslOutputArray { item_type: WgslType { name: ShaderCustomTypeName::new("CollisionResult"), code: WgslShaderModuleSectionCode { wgsl_code: "struct CollisionResult { entity1 : u32, entity2 : u32, }".to_string() } }, atomic_counter_name: Some("collisionresult_counter".to_string()) }], helper_functions: vec![WgslFunction { name: "calculate_distance_squared".to_string(), code: WgslShaderModuleSectionCode { wgsl_code: "fn calculate_distance_squared(p1 : array < f32, 2 > , p2 : array < f32, 2 >)\n-> f32\n{\n let dx = p1 [0] - p2 [0]; let dy = p1 [1] - p2 [1]; return dx * dx + dy *\n dy;\n}".to_string() } }], main_function: Some(WgslFunction { name: "main".to_owned(), code: WgslShaderModuleSectionCode { wgsl_code: "fn main(@builtin(global_invocation_id) iter_pos: vec3<u32>)\n{\n let current_entity = iter_pos.x; let other_entity = iter_pos.y; if\n current_entity >= POSITION_INPUT_ARRAY_LENGTH || other_entity >=\n POSITION_INPUT_ARRAY_LENGTH || current_entity == other_entity ||\n current_entity >= other_entity { return; } let current_radius =\n radius_input_array [current_entity]; let other_radius = radius_input_array\n [other_entity]; if current_radius <= 0.0 || other_radius <= 0.0\n { return; } let current_pos = position_input_array [current_entity]; let\n other_pos = position_input_array [other_entity]; let dist_squared =\n calculate_distance_squared(current_pos, other_pos); let radius_sum =\n current_radius + other_radius; if dist_squared < radius_sum * radius_sum\n {\n {\n let collisionresult_output_array_index =\n atomicAdd(& collisionresult_counter, 1u); if\n collisionresult_output_array_index <\n COLLISIONRESULT_OUTPUT_ARRAY_LENGTH\n {\n collisionresult_output_array\n [collisionresult_output_array_index] = CollisionResult\n { entity1 : current_entity, entity2 : other_entity, };\n }\n };\n }\n}".to_owned() } }), binding_numbers_by_variable_name: Some(HashMap::from([(String::from("uniforms"), 0), (String::from("position_input_array"), 1), (String::from("radius_input_array"), 2), (String::from("collisionresult_output_array"), 3), (String::from("collisionresult_counter"), 4)]))
let user_portion = WgslShaderModuleUserPortion { static_consts: vec![WgslConstAssignment { code: WgslShaderModuleSectionCode { wgsl_code: "const example_module_const : u32 = 42;".to_string() } }], helper_types: vec![], uniforms: vec![WgslType { name: ShaderCustomTypeName::new("Uniforms"), code: WgslShaderModuleSectionCode { wgsl_code: "struct Uniforms { time : f32, resolution : vec2 < f32 > , }".to_string() } }], input_arrays: vec![WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Position"), code: WgslShaderModuleSectionCode { wgsl_code: "alias Position = array < f32, 2 > ;".to_string() } } }, WgslInputArray { item_type: WgslType { name: ShaderCustomTypeName::new("Radius") , code: WgslShaderModuleSectionCode { wgsl_code: "alias Radius = f32;".to_string() } }}], output_arrays: vec![WgslOutputArray { item_type: WgslType { name: ShaderCustomTypeName::new("CollisionResult"), code: WgslShaderModuleSectionCode { wgsl_code: "struct CollisionResult { entity1 : u32, entity2 : u32, }".to_string() } }, atomic_counter_name: Some("collisionresult_counter".to_string()) }], helper_functions: vec![WgslFunction { name: "calculate_distance_squared".to_string(), code: WgslShaderModuleSectionCode { wgsl_code: "fn calculate_distance_squared(p1 : array < f32, 2 > , p2 : array < f32, 2 >)\n-> f32\n{\n let dx = p1 [0] - p2 [0]; let dy = p1 [1] - p2 [1]; return dx * dx + dy *\n dy;\n}".to_string() } }], main_function: Some(WgslFunction { name: "main".to_owned(), code: WgslShaderModuleSectionCode { wgsl_code: "fn main(@builtin(global_invocation_id) iter_pos: vec3<u32>)\n{\n let current_entity = iter_pos.x; let other_entity = iter_pos.y; if\n current_entity >= POSITION_INPUT_ARRAY_LENGTH || other_entity >=\n POSITION_INPUT_ARRAY_LENGTH || current_entity == other_entity ||\n current_entity >= other_entity { return; } let current_radius =\n radius_input_array [current_entity]; let other_radius = radius_input_array\n [other_entity]; if current_radius <= 0.0 || other_radius <= 0.0\n { return; } let current_pos = position_input_array [current_entity]; let\n other_pos = position_input_array [other_entity]; let dist_squared =\n calculate_distance_squared(current_pos, other_pos); let radius_sum =\n current_radius + other_radius; if dist_squared < radius_sum * radius_sum\n {\n {\n let collisionresult_output_array_index =\n atomicAdd(& collisionresult_counter, 1u); if\n collisionresult_output_array_index <\n COLLISIONRESULT_OUTPUT_ARRAY_LENGTH\n {\n collisionresult_output_array\n [collisionresult_output_array_index] = CollisionResult\n { entity1 : current_entity, entity2 : other_entity, };\n }\n };\n }\n}".to_owned() } }), binding_numbers_by_variable_name: Some(HashMap::from([(String::from("uniforms"), 0), (String::from("position_input_array"), 1), (String::from("radius_input_array"), 2), (String::from("collisionresult_output_array"), 3), (String::from("collisionresult_counter"), 4)])), use_statements: vec![],
};

let expected_wgsl_code = "const example_module_const : u32 = 42;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct WgslShaderModuleUserPortion {
/// look for any attempt to ASSIGN to the value of "global_id.x", "global_id.y", or "global_id.z" or just "global_id" and throw an error
pub main_function: Option<WgslFunction>,
pub binding_numbers_by_variable_name: Option<HashMap<String, u32>>,
pub use_statements: Vec<WgslImport>,
}
impl WgslShaderModuleUserPortion {
pub fn empty() -> Self {
Expand All @@ -38,6 +39,7 @@ impl WgslShaderModuleUserPortion {
helper_functions: vec![],
main_function: None,
binding_numbers_by_variable_name: None,
use_statements: vec![],
}
}
}
4 changes: 4 additions & 0 deletions bevy_gpu_compute_core/src/wgsl/shader_sections/import.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#[derive(Clone, Debug, PartialEq)]
pub struct WgslImport {
pub path: String,
}
2 changes: 2 additions & 0 deletions bevy_gpu_compute_core/src/wgsl/shader_sections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod code;
mod const_assignment;
mod custom_type;
mod function;
mod import;
mod input_array;
mod output_array;
mod wgpu_binding;
Expand All @@ -11,6 +12,7 @@ pub use code::*;
pub use const_assignment::*;
pub use custom_type::*;
pub use function::*;
pub use import::*;
pub use input_array::*;
pub use output_array::*;
pub use wgpu_binding::*;
Expand Down
11 changes: 10 additions & 1 deletion bevy_gpu_compute_macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,16 @@ pub fn wgsl_shader_module(_attr: TokenStream, item: TokenStream) -> TokenStream
set_dummy(item.clone().into());
let module = parse_macro_input!(item as syn::ItemMod);
let compiler_pipeline = CompilerPipeline::default();
compiler_pipeline.compile(module).into()
compiler_pipeline.compile(module, true).into()
}

#[proc_macro_attribute]
#[proc_macro_error]
pub fn wgsl_shader_module_library(_attr: TokenStream, item: TokenStream) -> TokenStream {
set_dummy(item.clone().into());
let module = parse_macro_input!(item as syn::ItemMod);
let compiler_pipeline = CompilerPipeline::default();
compiler_pipeline.compile(module, false).into()
}

/// used to help this library figure out what to do with user-defined types
Expand Down
6 changes: 5 additions & 1 deletion bevy_gpu_compute_macro/src/pipeline/compilation_metadata.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use crate::pipeline::phases::custom_type_collector::custom_type::CustomType;
use crate::pipeline::phases::{
custom_type_collector::custom_type::CustomType, user_import_collector::user_import::UserImport,
};
use bevy_gpu_compute_core::wgsl::shader_module::user_defined_portion::WgslShaderModuleUserPortion;
use proc_macro2::TokenStream;

pub struct CompilationMetadata {
pub user_imports: Option<Vec<UserImport>>,
pub main_func_required: bool,
pub custom_types: Option<Vec<CustomType>>,
pub wgsl_module_user_portion: Option<WgslShaderModuleUserPortion>,
pub typesafe_buffer_builders: Option<TokenStream>,
Expand Down
21 changes: 19 additions & 2 deletions bevy_gpu_compute_macro/src/pipeline/compilation_unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use proc_macro2::TokenStream;

use super::{
compilation_metadata::CompilationMetadata,
phases::custom_type_collector::custom_type::CustomType,
phases::{
custom_type_collector::custom_type::CustomType,
user_import_collector::user_import::UserImport,
},
};

pub struct CompilationUnit {
Expand All @@ -15,19 +18,24 @@ pub struct CompilationUnit {
}

impl CompilationUnit {
pub fn new(original_rust_module: syn::ItemMod) -> Self {
pub fn new(original_rust_module: syn::ItemMod, main_func_required: bool) -> Self {
CompilationUnit {
original_rust_module,
rust_module_for_cpu: None,
rust_module_for_gpu: None,
compiled_tokens: None,
metadata: CompilationMetadata {
user_imports: None,
custom_types: None,
wgsl_module_user_portion: None,
typesafe_buffer_builders: None,
main_func_required,
},
}
}
pub fn main_func_required(&self) -> bool {
self.metadata.main_func_required
}
pub fn rust_module_for_gpu(&self) -> &syn::ItemMod {
if self.rust_module_for_gpu.is_none() {
panic!("rust_module_for_gpu is not set");
Expand All @@ -40,6 +48,15 @@ impl CompilationUnit {
}
self.rust_module_for_cpu.as_ref().unwrap()
}
pub fn set_user_imports(&mut self, user_imports: Vec<UserImport>) {
self.metadata.user_imports = Some(user_imports);
}
pub fn user_imports(&self) -> &Vec<UserImport> {
if self.metadata.user_imports.is_none() {
panic!("user_imports is not set");
}
self.metadata.user_imports.as_ref().unwrap()
}
pub fn set_custom_types(&mut self, custom_types: Vec<CustomType>) {
self.metadata.custom_types = Some(custom_types);
}
Expand Down
6 changes: 4 additions & 2 deletions bevy_gpu_compute_macro/src/pipeline/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use super::phases::{
module_for_rust_usage_cleaner::compiler_phase::ModuleForRustUsageCleaner,
non_mutating_tree_validation::compiler_phase::NonMutatingTreeValidation,
typesafe_buffer_builders_generator::compiler_phase::TypesafeBufferBuildersGenerator,
user_import_collector::compiler_phase::UserImportCollector,
wgsl_helper_transformer::compiler_phase::WgslHelperTransformer,
};
use crate::pipeline::compilation_unit::CompilationUnit;
Expand All @@ -20,6 +21,7 @@ impl Default for CompilerPipeline {
Self {
phases: vec![
Box::new(NonMutatingTreeValidation {}),
Box::new(UserImportCollector {}),
Box::new(CustomTypeCollector {}),
Box::new(TypesafeBufferBuildersGenerator {}),
Box::new(WgslHelperTransformer {}),
Expand All @@ -31,8 +33,8 @@ impl Default for CompilerPipeline {
}
}
impl CompilerPipeline {
pub fn compile(&self, module: syn::ItemMod) -> TokenStream {
let mut unit = CompilationUnit::new(module);
pub fn compile(&self, module: syn::ItemMod, main_func_required: bool) -> TokenStream {
let mut unit = CompilationUnit::new(module, main_func_required);
for phase in &self.phases {
phase.execute(&mut unit);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use quote::quote;

pub fn generate_required_imports() -> TokenStream {
quote! {
use super::*;
use bevy_gpu_compute_core::wgsl::shader_sections::*; //todo, make this less brittle, how?
use bevy_gpu_compute_core::wgsl::shader_custom_type_name::*;
use bevy_gpu_compute_core::wgsl_helpers::*;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::collections::HashMap;

use bevy_gpu_compute_core::{
wgsl::shader_custom_type_name::ShaderCustomTypeName,
wgsl::shader_sections::{
WgslConstAssignment, WgslFunction, WgslInputArray, WgslOutputArray,
use bevy_gpu_compute_core::wgsl::{
shader_custom_type_name::ShaderCustomTypeName,
shader_sections::{
WgslConstAssignment, WgslFunction, WgslImport, WgslInputArray, WgslOutputArray,
WgslShaderModuleSectionCode, WgslType,
},
};
Expand Down Expand Up @@ -88,6 +88,12 @@ impl ToStructInitializer {
}
)
}

pub fn wgsl_import(c: &WgslImport) -> TokenStream {
let i: TokenStream = c.path.parse().unwrap();
quote!(#i)
}

pub fn hash_map(c: &HashMap<String, u32>) -> TokenStream {
let entries: TokenStream = c
.iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,18 @@ pub fn generate_shader_module_object(
.unwrap(),
);

let library_modules: TokenStream = wgsl_shader_module
.use_statements
.iter()
.map(|use_statement| {
let ts = ToStructInitializer::wgsl_import(use_statement);
quote!(#ts,)
})
.collect();

quote!(
pub fn parsed() -> WgslShaderModuleUserPortion {
WgslShaderModuleUserPortion {
let mut user_portion = WgslShaderModuleUserPortion {
static_consts: [
#static_consts
]
Expand All @@ -104,15 +113,20 @@ pub fn generate_shader_module_object(
.into(),
main_function: #main_function,
binding_numbers_by_variable_name: Some(#bindings_map),
}
use_statements: [].into(),
};
merge_libraries_into_wgsl_module(&mut user_portion, &mut [
#library_modules
].into());
user_portion
}
)
}

#[cfg(test)]
mod test {
use proc_macro_error::abort;
use proc_macro2::{Span, TokenStream};
use proc_macro_error::abort;

#[test]
pub fn test_parse_str() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub fn generate_unaltered_module(original_module: &ItemMod) -> TokenStream {
quote! {
#[allow(dead_code, unused_variables, unused_imports)]
mod #new_ident {
use super::*;
#content_combined
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ pub struct GpuResourceMngmntAndWgslGenerator;

impl CompilerPhase for GpuResourceMngmntAndWgslGenerator {
fn execute(&self, input: &mut CompilationUnit) {
let (shader_module, custom_types) =
parse_shader_module_for_gpu(input.rust_module_for_gpu(), input.custom_types());
let (shader_module, custom_types) = parse_shader_module_for_gpu(
input.rust_module_for_gpu(),
input.custom_types(),
input.main_func_required(),
input.user_imports(),
);
input.set_wgsl_module_user_portion(shader_module);
input.set_custom_types(custom_types);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use bevy_gpu_compute_core::wgsl::shader_sections::WgslImport;

use crate::pipeline::phases::user_import_collector::user_import::UserImport;

pub fn generate_user_imports_for_wgsl_module_def(
user_imports: &Vec<UserImport>,
) -> Vec<WgslImport> {
let mut out = vec![];
for import in user_imports.iter() {
let mut segments: Vec<String> = import.path.iter().map(|ident| ident.to_string()).collect();
segments.push("parsed()".to_string());
let mut path = segments.join("::");

if import.has_leading_colon {
path = format!("::{path}");
}

out.push(WgslImport { path });
}
out
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
use bevy_gpu_compute_core::wgsl::shader_module::user_defined_portion::WgslShaderModuleUserPortion;

use crate::pipeline::phases::custom_type_collector::custom_type::CustomType;
use crate::pipeline::phases::user_import_collector::user_import::UserImport;

use super::constants::extract_constants;
use super::divide_custom_types::generate_helper_types_inputs_and_outputs_for_wgsl_module_def;
use super::helper_functions::extract_helper_functions;
use super::imports::generate_user_imports_for_wgsl_module_def;
use super::main_function::parse_main_function;

/// This will also change custom_types
pub fn parse_shader_module_for_gpu(
rust_module_transformed_for_gpu: &syn::ItemMod,
custom_types: &Vec<CustomType>,
main_func_required: bool,
user_imports: &Vec<UserImport>,
) -> (WgslShaderModuleUserPortion, Vec<CustomType>) {
let mut out_module: WgslShaderModuleUserPortion = WgslShaderModuleUserPortion::empty();
out_module.main_function = Some(parse_main_function(
rust_module_transformed_for_gpu,
custom_types,
));
if main_func_required {
out_module.main_function = Some(parse_main_function(
rust_module_transformed_for_gpu,
custom_types,
));
}
out_module.static_consts = extract_constants(rust_module_transformed_for_gpu, custom_types);
out_module.helper_functions =
extract_helper_functions(rust_module_transformed_for_gpu, custom_types);
let new_custom_types =
generate_helper_types_inputs_and_outputs_for_wgsl_module_def(custom_types, &mut out_module);
out_module.use_statements = generate_user_imports_for_wgsl_module_def(user_imports);
(out_module, new_custom_types)
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod compiler_phase;
mod constants;
mod divide_custom_types;
mod helper_functions;
mod imports;
mod lib;
mod main_function;
pub mod to_wgsl_syntax;
Loading
Loading