Skip to content

[msl-out] Fix ReadZeroSkipWrite bounds check mode for pointer arguments #7323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ holding the result.
[msl]: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
[all-atom]: crate::valid::Capabilities::SHADER_INT64_ATOMIC_ALL_OPS

## Pointer-typed bounds-checked expressions and OOB locals

MSL (unlike HLSL and GLSL) has native support for pointer-typed function
arguments. When the [`BoundsCheckPolicy`] is `ReadZeroSkipWrite` and an
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: It'd be nice to doc-link ReadZeroSkipWrite here, but I see that some existing commentary has also not linked it.

out-of-bounds index expression is used for such an argument, our strategy is to
pass a pointer to a dummy variable. These dummy variables are called "OOB
locals". We emit at most one OOB local per function for each type, since all
expressions producing a result of that type can share the same OOB local. (Note
that the OOB local mechanism is not actually implementing "skip write", nor even
"read zero" in some cases of read-after-write, but doing so would require
additional effort and the difference is unlikely to matter.)

[`BoundsCheckPolicy`]: crate::proc::BoundsCheckPolicy

*/

use alloc::{
Expand Down
269 changes: 162 additions & 107 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ use crate::{
arena::{Handle, HandleSet},
back::{self, Baked},
common,
proc::{self, index, NameKey, TypeResolution},
proc::{
self,
index::{self, BoundsCheck},
NameKey, TypeResolution,
},
valid, FastHashMap, FastHashSet,
};

Expand Down Expand Up @@ -599,11 +603,34 @@ impl crate::Type {
}
}

#[derive(Clone, Copy)]
enum FunctionOrigin {
Handle(Handle<crate::Function>),
EntryPoint(proc::EntryPointIndex),
}

trait NameKeyExt {
fn local(origin: FunctionOrigin, local_handle: Handle<crate::LocalVariable>) -> NameKey {
match origin {
FunctionOrigin::Handle(handle) => NameKey::FunctionLocal(handle, local_handle),
FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointLocal(idx, local_handle),
}
}

/// Return the name key for a local variable used by ReadZeroSkipWrite bounds-check
/// policy when it needs to produce a pointer-typed result for an OOB access. These
/// are unique per accessed type, so the second argument is a type handle. See docs
/// for [`crate::back::msl`].
fn oob_local_for_type(origin: FunctionOrigin, ty: Handle<crate::Type>) -> NameKey {
match origin {
FunctionOrigin::Handle(handle) => NameKey::FunctionOobLocal(handle, ty),
FunctionOrigin::EntryPoint(idx) => NameKey::EntryPointOobLocal(idx, ty),
}
}
}

impl NameKeyExt for NameKey {}

/// A level of detail argument.
///
/// When [`BoundsCheckPolicy::Restrict`] applies to an [`ImageLoad`] access, we
Expand Down Expand Up @@ -681,6 +708,7 @@ impl<'a> ExpressionContext<'a> {
.choose_policy(pointer, &self.module.types, self.info)
}

/// See docs for [`proc::index::access_needs_check`].
fn access_needs_check(
&self,
base: Handle<crate::Expression>,
Expand All @@ -695,6 +723,19 @@ impl<'a> ExpressionContext<'a> {
)
}

/// See docs for [`proc::index::bounds_check_iter`].
fn bounds_check_iter(
&self,
chain: Handle<crate::Expression>,
) -> impl Iterator<Item = BoundsCheck> + '_ {
index::bounds_check_iter(chain, self.module, self.function, self.info)
}

/// See docs for [`proc::index::oob_local_types`].
fn oob_local_types(&self) -> FastHashSet<Handle<crate::Type>> {
index::oob_local_types(self.module, self.function, self.info, self.policies)
}

fn get_packed_vec_kind(&self, expr_handle: Handle<crate::Expression>) -> Option<crate::Scalar> {
match self.function.expressions[expr_handle] {
crate::Expression::AccessIndex { base, index } => {
Expand Down Expand Up @@ -902,6 +943,59 @@ impl<W: Write> Writer<W> {
Ok(())
}

/// Writes the local variables of the given function, as well as any extra
/// out-of-bounds locals that are needed.
///
/// The names of the OOB locals are also added to `self.names` at the same
/// time.
fn put_locals(&mut self, context: &ExpressionContext) -> BackendResult {
let oob_local_types = context.oob_local_types();
for &ty in oob_local_types.iter() {
let name_key = NameKey::oob_local_for_type(context.origin, ty);
self.names.insert(name_key, self.namer.call("oob"));
}

for (name_key, ty, init) in context
.function
.local_variables
.iter()
.map(|(local_handle, local)| {
let name_key = NameKey::local(context.origin, local_handle);
(name_key, local.ty, local.init)
})
.chain(oob_local_types.iter().map(|&ty| {
let name_key = NameKey::oob_local_for_type(context.origin, ty);
(name_key, ty, None)
}))
{
let ty_name = TypeContext {
handle: ty,
gctx: context.module.to_ctx(),
names: &self.names,
access: crate::StorageAccess::empty(),
first_time: false,
};
write!(
self.out,
"{}{} {}",
back::INDENT,
ty_name,
self.names[&name_key]
)?;
match init {
Some(value) => {
write!(self.out, " = ")?;
self.put_expression(value, context, true)?;
}
None => {
write!(self.out, " = {{}}")?;
}
};
writeln!(self.out, ";")?;
}
Ok(())
}

fn put_level_of_detail(
&mut self,
level: LevelOfDetail,
Expand Down Expand Up @@ -1660,7 +1754,6 @@ impl<W: Write> Writer<W> {
}

let expression = &context.function.expressions[expr_handle];
log::trace!("expression {:?} = {:?}", expr_handle, expression);
match *expression {
crate::Expression::Literal(_)
| crate::Expression::Constant(_)
Expand Down Expand Up @@ -1696,7 +1789,42 @@ impl<W: Write> Writer<W> {
{
write!(self.out, " ? ")?;
self.put_access_chain(expr_handle, policy, context)?;
write!(self.out, " : DefaultConstructible()")?;
write!(self.out, " : ")?;

if context.resolve_type(base).pointer_space().is_some() {
// We can't just use `DefaultConstructible` if this is a pointer.
// Instead, we create a dummy local variable to serve as pointer
// target if the access is out of bounds.
let result_ty = context.info[expr_handle]
.ty
.inner_with(&context.module.types)
.pointer_base_type();
let result_ty_handle = match result_ty {
Some(TypeResolution::Handle(handle)) => handle,
Some(TypeResolution::Value(_)) => {
// As long as the result of a pointer access expression is
// passed to a function or stored in a let binding, the
// type will be in the arena. If additional uses of
// pointers become valid, this assumption might no longer
// hold. Note that the LHS of a load or store doesn't
// take this path -- there is dedicated code in `put_load`
// and `put_store`.
unreachable!(
"Expected type {result_ty:?} of access through pointer type {base:?} to be in the arena",
);
}
None => {
unreachable!(
"Expected access through pointer type {base:?} to return a pointer, but got {result_ty:?}",
)
}
};
let name_key =
NameKey::oob_local_for_type(context.origin, result_ty_handle);
self.out.write_str(&self.names[&name_key])?;
} else {
write!(self.out, "DefaultConstructible()")?;
}

if !is_scoped {
write!(self.out, ")")?;
Expand Down Expand Up @@ -1736,14 +1864,7 @@ impl<W: Write> Writer<W> {
write!(self.out, "{name}")?;
}
crate::Expression::LocalVariable(handle) => {
let name_key = match context.origin {
FunctionOrigin::Handle(fun_handle) => {
NameKey::FunctionLocal(fun_handle, handle)
}
FunctionOrigin::EntryPoint(ep_index) => {
NameKey::EntryPointLocal(ep_index, handle)
}
};
let name_key = NameKey::local(context.origin, handle);
let name = &self.names[&name_key];
write!(self.out, "{name}")?;
}
Expand Down Expand Up @@ -2647,68 +2768,44 @@ impl<W: Write> Writer<W> {
#[allow(unused_variables)]
fn put_bounds_checks(
&mut self,
mut chain: Handle<crate::Expression>,
chain: Handle<crate::Expression>,
context: &ExpressionContext,
level: back::Level,
prefix: &'static str,
) -> Result<bool, Error> {
let mut check_written = false;

// Iterate over the access chain, handling each expression.
loop {
// Produce a `GuardedIndex`, so we can shared code between the
// `Access` and `AccessIndex` cases.
let (base, guarded_index) = match context.function.expressions[chain] {
crate::Expression::Access { base, index } => {
(base, Some(index::GuardedIndex::Expression(index)))
}
crate::Expression::AccessIndex { base, index } => {
// Don't try to check indices into structs. Validation already took
// care of them, and index::needs_guard doesn't handle that case.
let mut base_inner = context.resolve_type(base);
if let crate::TypeInner::Pointer { base, .. } = *base_inner {
base_inner = &context.module.types[base].inner;
}
match *base_inner {
crate::TypeInner::Struct { .. } => (base, None),
_ => (base, Some(index::GuardedIndex::Known(index))),
}
}
_ => break,
};
// Iterate over the access chain, handling each required bounds check.
for item in context.bounds_check_iter(chain) {
let BoundsCheck {
base,
index,
length,
} = item;

if let Some(index) = guarded_index {
if let Some(length) = context.access_needs_check(base, index) {
if check_written {
write!(self.out, " && ")?;
} else {
write!(self.out, "{level}{prefix}")?;
check_written = true;
}
if check_written {
write!(self.out, " && ")?;
} else {
write!(self.out, "{level}{prefix}")?;
check_written = true;
}

// Check that the index falls within bounds. Do this with a single
// comparison, by casting the index to `uint` first, so that negative
// indices become large positive values.
write!(self.out, "uint(")?;
self.put_index(index, context, true)?;
self.out.write_str(") < ")?;
match length {
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
index::IndexableLength::Dynamic => {
let global =
context.function.originating_global(base).ok_or_else(|| {
Error::GenericValidation(
"Could not find originating global".into(),
)
})?;
write!(self.out, "1 + ")?;
self.put_dynamic_array_max_index(global, context)?
}
}
// Check that the index falls within bounds. Do this with a single
// comparison, by casting the index to `uint` first, so that negative
// indices become large positive values.
write!(self.out, "uint(")?;
self.put_index(index, context, true)?;
self.out.write_str(") < ")?;
match length {
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
index::IndexableLength::Dynamic => {
let global = context.function.originating_global(base).ok_or_else(|| {
Error::GenericValidation("Could not find originating global".into())
})?;
write!(self.out, "1 + ")?;
self.put_dynamic_array_max_index(global, context)?
}
}

chain = base
}

Ok(check_written)
Expand Down Expand Up @@ -5694,28 +5791,7 @@ template <typename A>
result_struct: None,
};

for (local_handle, local) in fun.local_variables.iter() {
let ty_name = TypeContext {
handle: local.ty,
gctx: module.to_ctx(),
names: &self.names,
access: crate::StorageAccess::empty(),
first_time: false,
};
let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)];
write!(self.out, "{}{} {}", back::INDENT, ty_name, local_name)?;
match local.init {
Some(value) => {
write!(self.out, " = ")?;
self.put_expression(value, &context.expression, true)?;
}
None => {
write!(self.out, " = {{}}")?;
}
};
writeln!(self.out, ";")?;
}

self.put_locals(&context.expression)?;
self.update_expressions_to_bake(fun, fun_info, &context.expression);
self.put_block(back::Level(1), &fun.body, &context)?;
writeln!(self.out, "}}")?;
Expand Down Expand Up @@ -6627,28 +6703,7 @@ template <typename A>

// Finally, declare all the local variables that we need
//TODO: we can postpone this till the relevant expressions are emitted
for (local_handle, local) in fun.local_variables.iter() {
let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)];
let ty_name = TypeContext {
handle: local.ty,
gctx: module.to_ctx(),
names: &self.names,
access: crate::StorageAccess::empty(),
first_time: false,
};
write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?;
match local.init {
Some(value) => {
write!(self.out, " = ")?;
self.put_expression(value, &context.expression, true)?;
}
None => {
write!(self.out, " = {{}}")?;
}
};
writeln!(self.out, ";")?;
}

self.put_locals(&context.expression)?;
self.update_expressions_to_bake(fun, fun_info, &context.expression);
self.put_block(back::Level(1), &fun.body, &context)?;
writeln!(self.out, "}}")?;
Expand Down
Loading