Skip to content

Commit

Permalink
[tint] Early evaluation errors for subgroupShuffle
Browse files Browse the repository at this point in the history
This covers the functions subgroupShuffle, subgroupShuffleUp,
subgroupShuffleDown, and subgroupShuffleXor.

There is a CTS in the works:
gpuweb/cts#4065


Bug: 380862306
Change-Id: I0077557f62b4140bcbdd8601cbe6bc0a1933cf56
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/217074
Reviewed-by: dan sinclair <[email protected]>
Commit-Queue: Peter McNeeley <[email protected]>
  • Loading branch information
Peter McNeeley authored and Dawn LUCI CQ committed Nov 28, 2024
1 parent d26ee47 commit 2f6b9b5
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 0 deletions.
69 changes: 69 additions & 0 deletions src/tint/lang/wgsl/resolver/builtin_validation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,75 @@ TEST_F(ResolverBuiltinValidationTest, WorkgroupUniformLoad_AtomicInStruct) {
R"(error: workgroupUniformLoad must not be called with an argument that contains an atomic type)");
}

TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleLaneArgMustBeNonNeg) {
Enable(wgsl::Extension::kSubgroups);
Func("func", tint::Empty, ty.u32(),
Vector{
Return(Call("subgroupShuffle", 1_u, Expr(Source{{12, 34}}, -1_i))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(12:34 error: the sourceLaneIndex argument of subgroupShuffle must be greater than or equal to zero)");
}

TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleLaneArgMustLessThan128Signed) {
Enable(wgsl::Extension::kSubgroups);
Func("func", tint::Empty, ty.u32(),
Vector{
Return(Call("subgroupShuffle", 1_u, Expr(Source{{12, 34}}, 128_i))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(12:34 error: the sourceLaneIndex argument of subgroupShuffle must be less than 128)");
}

TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleLaneArgMustLessThan128) {
Enable(wgsl::Extension::kSubgroups);
Func("func", tint::Empty, ty.u32(),
Vector{
Return(Call("subgroupShuffle", 1_u, Expr(Source{{12, 34}}, 128_u))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(12:34 error: the sourceLaneIndex argument of subgroupShuffle must be less than 128)");
}

TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleUpDeltaArgMustLessThan128) {
Enable(wgsl::Extension::kSubgroups);
Func("func", tint::Empty, ty.u32(),
Vector{
Return(Call("subgroupShuffleUp", 1_u, Expr(Source{{12, 34}}, 128_u))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: the delta argument of subgroupShuffleUp must be less than 128)");
}

TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleDownDeltaArgMustLessThan128) {
Enable(wgsl::Extension::kSubgroups);
Func("func", tint::Empty, ty.u32(),
Vector{
Return(Call("subgroupShuffleDown", 1_u, Expr(Source{{12, 34}}, 128_u))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: the delta argument of subgroupShuffleDown must be less than 128)");
}

TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleXorMaskArgMustLessThan128) {
Enable(wgsl::Extension::kSubgroups);
Func("func", tint::Empty, ty.u32(),
Vector{
Return(Call("subgroupShuffleXor", 1_u, Expr(Source{{12, 34}}, 128_u))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: the mask argument of subgroupShuffleXor must be less than 128)");
}

TEST_F(ResolverBuiltinValidationTest, SubgroupBallotWithoutExtension) {
// fn func { return subgroupBallot(true); }
Func("func", tint::Empty, ty.vec4<u32>(),
Expand Down
8 changes: 8 additions & 0 deletions src/tint/lang/wgsl/resolver/resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2491,6 +2491,14 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr,
return nullptr;
}
break;
case wgsl::BuiltinFn::kSubgroupShuffle:
case wgsl::BuiltinFn::kSubgroupShuffleUp:
case wgsl::BuiltinFn::kSubgroupShuffleDown:
case wgsl::BuiltinFn::kSubgroupShuffleXor:
if (!validator_.SubgroupShuffleFunction(fn, call)) {
return nullptr;
}
break;

case wgsl::BuiltinFn::kQuadBroadcast:
if (!validator_.QuadBroadcast(call)) {
Expand Down
55 changes: 55 additions & 0 deletions src/tint/lang/wgsl/resolver/validator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1888,6 +1888,61 @@ bool Validator::BuiltinCall(const sem::Call* call) const {
return true;
}

bool Validator::SubgroupShuffleFunction(wgsl::BuiltinFn fn, const sem::Call* call) const {
auto* builtin = call->Target()->As<sem::BuiltinFn>();
if (!builtin) {
return false;
}

TINT_ASSERT(call->Arguments().Length() == 2);
auto* id = call->Arguments()[1];
auto* constant_value = id->ConstantValue();

if (!constant_value) {
// Non const values are allowed as parameters.
return true;
}

// User friendly param name.
std::string paramName = "sourceLaneIndex";
switch (fn) {
case wgsl::BuiltinFn::kSubgroupShuffleXor:
paramName = "mask";
break;
case wgsl::BuiltinFn::kSubgroupShuffleUp:
case wgsl::BuiltinFn::kSubgroupShuffleDown:
paramName = "delta";
break;
default:
break;
}

if (id->Type()->IsSignedIntegerScalar() && constant_value->ValueAs<i32>() < 0) {
AddError(id->Declaration()->source)
<< "the " << paramName << " argument of " << builtin->str()
<< " must be greater than or equal to zero";
return false;
}

if (id->Type()->IsSignedIntegerScalar() &&
constant_value->ValueAs<i32>() >= tint::internal_limits::kMaxSubgroupSize) {
AddError(id->Declaration()->source)
<< "the " << paramName << " argument of " << builtin->str() << " must be less than "
<< tint::internal_limits::kMaxSubgroupSize;
return false;
}

if (id->Type()->IsUnsignedIntegerScalar() &&
constant_value->ValueAs<u32>() >= tint::internal_limits::kMaxSubgroupSize) {
AddError(id->Declaration()->source)
<< "the " << paramName << " argument of " << builtin->str() << " must be less than "
<< tint::internal_limits::kMaxSubgroupSize;
return false;
}

return true;
}

bool Validator::TextureBuiltinFn(const sem::Call* call) const {
auto* builtin = call->Target()->As<sem::BuiltinFn>();
if (!builtin) {
Expand Down
6 changes: 6 additions & 0 deletions src/tint/lang/wgsl/resolver/validator.h
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,12 @@ class Validator {
/// @returns true on success, false otherwise
bool ArrayConstructor(const ast::CallExpression* ctor, const sem::Array* arr_type) const;

/// Validates a subgroupShuffle builtin functions including Up,Down, and Xor.
/// @param fn the builtin call type
/// @param call the builtin call to validate
/// @returns true on success, false otherwise
bool SubgroupShuffleFunction(wgsl::BuiltinFn fn, const sem::Call* call) const;

/// Validates a texture builtin function
/// @param call the builtin call to validate
/// @returns true on success, false otherwise
Expand Down

0 comments on commit 2f6b9b5

Please sign in to comment.