Skip to content

TODO: Load Baking #7799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: trunk
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion naga-cli/src/bin/naga.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ struct Args {
#[argh(switch)]
keep_coordinate_space: bool,

/// force loop bounding if the backend supports it.
#[argh(switch)]
force_loop_bounding: bool,

/// in dot output, include only the control flow graph
#[argh(switch)]
dot_cfg_only: bool,
Expand Down Expand Up @@ -427,6 +431,10 @@ fn run() -> anyhow::Result<()> {
block_ctx_dump_prefix: args.block_ctx_dir.clone().map(std::path::PathBuf::from),
};

params.spv_out.force_loop_bounding = args.force_loop_bounding;
params.msl.force_loop_bounding = args.force_loop_bounding;
params.hlsl.force_loop_bounding = args.force_loop_bounding;

params.entry_point.clone_from(&args.entry_point);
if let Some(ref version) = args.profile {
params.glsl.version = version.0;
Expand Down Expand Up @@ -706,7 +714,7 @@ fn write_output(
let file = fs::File::create(output_path)?;
bincode::serialize_into(file, module)?;
}
"metal" => {
"metal" | "msl" => {
use naga::back::msl;

let mut options = params.msl.clone();
Expand Down
166 changes: 166 additions & 0 deletions naga/src/back/load_baking.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
use core::{cell::Cell, mem};
use std::vec::Vec;

use crate::{ir, valid, FastHashMap, FastHashSet, Handle};

enum VariableState {
Uninitialized,
Initialized,
Written,
ReRead,
}

impl VariableState {
fn write(&mut self) {
match self {
VariableState::Uninitialized => *self = VariableState::Initialized,
VariableState::Initialized => *self = VariableState::Written,
VariableState::ReRead => {}
VariableState::Written => {}
}
}

fn read(&mut self) {
match self {
VariableState::Uninitialized => unreachable!(),
VariableState::Initialized => {}
VariableState::Written => *self = VariableState::ReRead,
VariableState::ReRead => {}
}
}
}

struct LoadBaker<'a> {
pub module: &'a ir::Module,
pub function: &'a ir::Function,
pub function_info: &'a valid::FunctionInfo,

requires_bake: FastHashSet<Handle<ir::Expression>>,

states: FastHashMap<Handle<ir::Expression>, (usize, VariableState)>,

loop_levels: Vec<usize>,
}

impl<'a> LoadBaker<'a> {
pub fn new(
module: &'a ir::Module,
function: &'a ir::Function,
function_info: &'a valid::FunctionInfo,
) -> Self {
let requires_bake = FastHashSet::default();
let states = FastHashMap::default();

Self {
module,
function,
function_info,
requires_bake,
states,
loop_levels: Vec::new(),
}
}

pub fn evaluate(&mut self) {
let current_depth = ScopedDepth::new();
self.evaluate_block(&current_depth, &self.function.body);
}

fn evaluate_block(&mut self, depth: &ScopedDepth, block: &'a ir::Block) {
let _guard = depth.enter();
for statement in block {
match *statement {
ir::Statement::Store { pointer, value } => {
self.evaluate_expression(depth, value);
self.register_write(depth, pointer);
}
ir::Statement::Emit(ref range) => {
for expression in range.clone() {
self.evaluate_expression(depth, expression);
}
}
ir::Statement::Block(ref block) => {
self.evaluate_block(&depth, block);
}
ir::Statement::Loop {
ref body,
ref continuing,
break_if,
} => {
self.loop_levels.push(depth.get());

self.evaluate_block(&depth, body);
self.evaluate_block(&depth, continuing);
if let Some(condition) = break_if {
self.evaluate_expression(depth, condition);
}

self.loop_levels.pop();
}
_ => {}
}
}
}

fn evaluate_expression(&mut self, depth: &ScopedDepth, expression: Handle<ir::Expression>) {
let expression = &self.function.expressions[expression];

match *expression {
ir::Expression::Load { pointer } => {
let exp = &self.function.expressions[pointer];

let is_local_variable = matches!(exp, ir::Expression::LocalVariable(_));

if is_local_variable {
self.register_load(depth, pointer);
}
}
_ => {}
}
}

fn register_load(&mut self, depth: &ScopedDepth, pointer: Handle<ir::Expression>) {
// Register the load as happening at the depth of the variable.
if let Some((_, state)) = self.states.get_mut(&pointer) {
state.read();
return;
}
}

fn register_write(&mut self, depth: &ScopedDepth, pointer: Handle<ir::Expression>) {}
}

struct ScopedDepth {
current: Cell<usize>,
}

impl ScopedDepth {
pub fn new() -> Self {
Self {
current: Cell::new(0),
}
}

pub fn enter(&self) -> DepthGuard<'_> {
self.current.set(self.current.get() + 1);
DepthGuard {
manager: self,
depth: self.current.get(),
}
}

pub fn get(&self) -> usize {
self.current.get()
}
}

struct DepthGuard<'a> {
manager: &'a ScopedDepth,
depth: usize,
}

impl Drop for DepthGuard<'_> {
fn drop(&mut self) {
self.manager.current.set(self.manager.current.get() - 1);
}
}
2 changes: 2 additions & 0 deletions naga/src/back/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub mod pipeline_constants;
#[cfg(any(hlsl_out, glsl_out))]
mod continue_forward;

mod load_baking;

/// Names of vector components.
pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
/// Indent for backends.
Expand Down
40 changes: 40 additions & 0 deletions naga/tests/in/spv/load-elimination.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 21
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %18 "main"
OpExecutionMode %18 LocalSize 1 1 1
OpName %5 "sink"
OpName %11 "a"
OpName %16 "b"
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%4 = OpConstant %3 0
%6 = OpTypePointer Private %3
%5 = OpVariable %6 Private %4
%9 = OpTypeFunction %2
%10 = OpConstant %3 2
%12 = OpTypePointer Function %3
%13 = OpConstantNull %3
%8 = OpFunction %2 None %9
%7 = OpLabel
%11 = OpVariable %12 Function %13
OpBranch %14
%14 = OpLabel
%15 = OpLoad %3 %5
OpStore %11 %15
%16 = OpLoad %3 %11
OpStore %11 %10
OpStore %5 %16
OpReturn
OpFunctionEnd
%18 = OpFunction %2 None %9
%17 = OpLabel
OpBranch %19
%19 = OpLabel
%20 = OpFunctionCall %2 %8
OpReturn
OpFunctionEnd
19 changes: 19 additions & 0 deletions naga/tests/in/spv/test.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[vk::binding(0, 0)]
RWStructuredBuffer<uint32_t> data;

[shader("compute")]
void main()
{
for (uint32_t i = 0; i < 4; i++)
{
if (data[i] == 1)
{
continue;
}
if (data[i] == 2)
{
break;
}
data[i] = i * 2; // Example operation: double each element
}
}
Binary file added naga/tests/in/spv/test.spv
Binary file not shown.
77 changes: 77 additions & 0 deletions naga/tests/in/spv/test.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
; SPIR-V
; Version: 1.5
; Generator: Khronos Slang Compiler; 0
; Bound: 39
; Schema: 0
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main" %data
OpExecutionMode %main LocalSize 1 1 1
OpSource Slang 1
OpName %RWStructuredBuffer "RWStructuredBuffer"
OpName %data "data"
OpName %main "main"
OpDecorate %_runtimearr_uint ArrayStride 4
OpDecorate %RWStructuredBuffer Block
OpMemberDecorate %RWStructuredBuffer 0 Offset 0
OpDecorate %data Binding 0
OpDecorate %data DescriptorSet 0
%void = OpTypeVoid
%6 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%bool = OpTypeBool
%uint_4 = OpConstant %uint 4
%int = OpTypeInt 32 1
%int_0 = OpConstant %int 0
%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
%_runtimearr_uint = OpTypeRuntimeArray %uint
%RWStructuredBuffer = OpTypeStruct %_runtimearr_uint
%_ptr_StorageBuffer_RWStructuredBuffer = OpTypePointer StorageBuffer %RWStructuredBuffer
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%data = OpVariable %_ptr_StorageBuffer_RWStructuredBuffer StorageBuffer
%main = OpFunction %void None %6
%17 = OpLabel
OpBranch %18
%18 = OpLabel
%19 = OpPhi %uint %uint_0 %17 %20 %21
OpLoopMerge %22 %21 None
OpBranch %23
%23 = OpLabel
OpSelectionMerge %24 None
OpSwitch %int_0 %25
%25 = OpLabel
%26 = OpULessThan %bool %19 %uint_4
OpSelectionMerge %27 None
OpBranchConditional %26 %27 %28
%28 = OpLabel
OpBranch %22
%27 = OpLabel
%29 = OpAccessChain %_ptr_StorageBuffer_uint %data %int_0 %19
%30 = OpLoad %uint %29
%31 = OpIEqual %bool %30 %uint_1
OpSelectionMerge %32 None
OpBranchConditional %31 %33 %32
%33 = OpLabel
OpBranch %24
%32 = OpLabel
%34 = OpLoad %uint %29
%35 = OpIEqual %bool %34 %uint_2
OpSelectionMerge %36 None
OpBranchConditional %35 %37 %36
%37 = OpLabel
OpBranch %22
%36 = OpLabel
%38 = OpIMul %uint %19 %uint_2
OpStore %29 %38
OpBranch %24
%24 = OpLabel
%20 = OpIAdd %uint %19 %uint_1
OpBranch %21
%21 = OpLabel
OpBranch %18
%22 = OpLabel
OpReturn
OpFunctionEnd
10 changes: 10 additions & 0 deletions naga/tests/in/spv/test.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
targets = "HLSL | METAL | GLSL | WGSL | IR"

[spv]
force_loop_bounding = false

[hlsl]
force_loop_bounding = false

[msl]
force_loop_bounding = false
1 change: 1 addition & 0 deletions naga/tests/in/wgsl/load-elimination.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
targets = "SPIRV | METAL | GLSL | WGSL | HLSL | IR"
15 changes: 15 additions & 0 deletions naga/tests/in/wgsl/load-elimination.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
var<private> sink: u32 = 0;

fn simple() {
var a = sink;
let b = a;

a = 2u;

sink = b;
}

@compute @workgroup_size(1)
fn main() {
simple();
}
10 changes: 10 additions & 0 deletions naga/tests/in/wgsl/test.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
targets = "SPIRV | METAL | GLSL | WGSL| HLSL | IR"

[spv]
force_loop_bounding = false

[hlsl]
force_loop_bounding = false

[msl]
force_loop_bounding = false
11 changes: 11 additions & 0 deletions naga/tests/in/wgsl/test.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@group(0) @binding(0)
var<storage, read_write> data: array<u32>;

@compute @workgroup_size(1)
fn main()
{
for (var i = 0u; i < 4; i++)
{
data[i] = i * 2; // Example operation: double each element
}
}
Loading
Loading