Skip to content

Non uniform for everything! #177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use super::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt, SpirvValueKind};
use crate::spirv_type::SpirvType;
use rspirv::spirv::Word;
use rspirv::spirv::{Decoration, Word};
use rustc_codegen_spirv_types::Capability;
use rustc_codegen_ssa::traits::BuilderMethods;
use rustc_errors::ErrorGuaranteed;
use rustc_span::DUMMY_SP;
Expand Down Expand Up @@ -41,11 +42,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
};
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
let u32_ptr = self.type_ptr_to(u32_ty);
let array = array.def(self);
let actual_index = actual_index.def(self);
let ptr = self
.emit()
.in_bounds_access_chain(u32_ptr, None, array.def(self), [actual_index.def(self)])
.in_bounds_access_chain(u32_ptr, None, array, [actual_index])
.unwrap()
.with_type(u32_ptr);
if self.builder.has_capability(Capability::ShaderNonUniform) {
// apply NonUniform to the operation and the index
self.emit()
.decorate(ptr.def(self), Decoration::NonUniform, []);
self.emit()
.decorate(actual_index, Decoration::NonUniform, []);
}
self.load(u32_ty, ptr, Align::ONE)
}

Expand Down Expand Up @@ -233,11 +243,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
};
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
let u32_ptr = self.type_ptr_to(u32_ty);
let array = array.def(self);
let actual_index = actual_index.def(self);
let ptr = self
.emit()
.in_bounds_access_chain(u32_ptr, None, array.def(self), [actual_index.def(self)])
.in_bounds_access_chain(u32_ptr, None, array, [actual_index])
.unwrap()
.with_type(u32_ptr);
if self.builder.has_capability(Capability::ShaderNonUniform) {
// apply NonUniform to the operation and the index
self.emit()
.decorate(ptr.def(self), Decoration::NonUniform, []);
self.emit()
.decorate(actual_index, Decoration::NonUniform, []);
}
self.store(value, ptr, Align::ONE);
Ok(())
}
Expand Down
5 changes: 5 additions & 0 deletions crates/rustc_codegen_spirv/src/linker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,11 @@ pub fn link(
duplicates::remove_duplicate_debuginfo(&mut output);
}

{
let _timer = sess.timer("link_remove_non_uniform");
simple_passes::remove_non_uniform_decorations(sess, &mut output)?;
}

// NOTE(eddyb) SPIR-T pipeline is entirely limited to this block.
{
let (spv_words, module_or_err, lower_from_spv_timer) =
Expand Down
19 changes: 18 additions & 1 deletion crates/rustc_codegen_spirv/src/linker/simple_passes.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{Result, get_name, get_names};
use rspirv::dr::{Block, Function, Module};
use rspirv::spirv::{ExecutionModel, Op, Word};
use rspirv::spirv::{Decoration, ExecutionModel, Op, Word};
use rustc_codegen_spirv_types::Capability;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_session::Session;
use std::iter::once;
Expand Down Expand Up @@ -264,3 +265,19 @@ pub fn check_fragment_insts(sess: &Session, module: &Module) -> Result<()> {
}
}
}

/// Remove all [`Decoration::NonUniform`] if this module does *not* have [`Capability::ShaderNonUniform`].
/// This allows image asm to always declare `NonUniform` and not worry about conditional compilation.
pub fn remove_non_uniform_decorations(_sess: &Session, module: &mut Module) -> Result<()> {
let has_shader_non_uniform_capability = module.capabilities.iter().any(|inst| {
inst.class.opcode == Op::Capability
&& inst.operands[0].unwrap_capability() == Capability::ShaderNonUniform
});
if !has_shader_non_uniform_capability {
module.annotations.retain(|inst| {
!(inst.class.opcode == Op::Decorate
&& inst.operands[1].unwrap_decoration() == Decoration::NonUniform)
});
}
Ok(())
}
Loading