From 2f6b9b59f51e2f964e1fc639aadd97b8b7b6167f Mon Sep 17 00:00:00 2001 From: Peter McNeeley Date: Thu, 28 Nov 2024 20:19:42 +0000 Subject: [PATCH] [tint] Early evaluation errors for subgroupShuffle This covers the functions subgroupShuffle, subgroupShuffleUp, subgroupShuffleDown, and subgroupShuffleXor. There is a CTS in the works: https://github.com/gpuweb/cts/pull/4065/ Bug: 380862306 Change-Id: I0077557f62b4140bcbdd8601cbe6bc0a1933cf56 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/217074 Reviewed-by: dan sinclair Commit-Queue: Peter McNeeley --- .../wgsl/resolver/builtin_validation_test.cc | 69 +++++++++++++++++++ src/tint/lang/wgsl/resolver/resolver.cc | 8 +++ src/tint/lang/wgsl/resolver/validator.cc | 55 +++++++++++++++ src/tint/lang/wgsl/resolver/validator.h | 6 ++ 4 files changed, 138 insertions(+) diff --git a/src/tint/lang/wgsl/resolver/builtin_validation_test.cc b/src/tint/lang/wgsl/resolver/builtin_validation_test.cc index ade78ee41a8..01bf7a1d48b 100644 --- a/src/tint/lang/wgsl/resolver/builtin_validation_test.cc +++ b/src/tint/lang/wgsl/resolver/builtin_validation_test.cc @@ -834,6 +834,75 @@ TEST_F(ResolverBuiltinValidationTest, WorkgroupUniformLoad_AtomicInStruct) { R"(error: workgroupUniformLoad must not be called with an argument that contains an atomic type)"); } +TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleLaneArgMustBeNonNeg) { + Enable(wgsl::Extension::kSubgroups); + Func("func", tint::Empty, ty.u32(), + Vector{ + Return(Call("subgroupShuffle", 1_u, Expr(Source{{12, 34}}, -1_i))), + }); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + R"(12:34 error: the sourceLaneIndex argument of subgroupShuffle must be greater than or equal to zero)"); +} + +TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleLaneArgMustLessThan128Signed) { + Enable(wgsl::Extension::kSubgroups); + Func("func", tint::Empty, ty.u32(), + Vector{ + Return(Call("subgroupShuffle", 1_u, Expr(Source{{12, 34}}, 128_i))), + }); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + R"(12:34 error: the sourceLaneIndex argument of subgroupShuffle must be less than 128)"); +} + +TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleLaneArgMustLessThan128) { + Enable(wgsl::Extension::kSubgroups); + Func("func", tint::Empty, ty.u32(), + Vector{ + Return(Call("subgroupShuffle", 1_u, Expr(Source{{12, 34}}, 128_u))), + }); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + R"(12:34 error: the sourceLaneIndex argument of subgroupShuffle must be less than 128)"); +} + +TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleUpDeltaArgMustLessThan128) { + Enable(wgsl::Extension::kSubgroups); + Func("func", tint::Empty, ty.u32(), + Vector{ + Return(Call("subgroupShuffleUp", 1_u, Expr(Source{{12, 34}}, 128_u))), + }); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + R"(12:34 error: the delta argument of subgroupShuffleUp must be less than 128)"); +} + +TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleDownDeltaArgMustLessThan128) { + Enable(wgsl::Extension::kSubgroups); + Func("func", tint::Empty, ty.u32(), + Vector{ + Return(Call("subgroupShuffleDown", 1_u, Expr(Source{{12, 34}}, 128_u))), + }); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + R"(12:34 error: the delta argument of subgroupShuffleDown must be less than 128)"); +} + +TEST_F(ResolverBuiltinValidationTest, SubgroupShuffleXorMaskArgMustLessThan128) { + Enable(wgsl::Extension::kSubgroups); + Func("func", tint::Empty, ty.u32(), + Vector{ + Return(Call("subgroupShuffleXor", 1_u, Expr(Source{{12, 34}}, 128_u))), + }); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + R"(12:34 error: the mask argument of subgroupShuffleXor must be less than 128)"); +} + TEST_F(ResolverBuiltinValidationTest, SubgroupBallotWithoutExtension) { // fn func { return subgroupBallot(true); } Func("func", tint::Empty, ty.vec4(), diff --git a/src/tint/lang/wgsl/resolver/resolver.cc b/src/tint/lang/wgsl/resolver/resolver.cc index 3bc3f75263f..894f5cb4031 100644 --- a/src/tint/lang/wgsl/resolver/resolver.cc +++ b/src/tint/lang/wgsl/resolver/resolver.cc @@ -2491,6 +2491,14 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, return nullptr; } break; + case wgsl::BuiltinFn::kSubgroupShuffle: + case wgsl::BuiltinFn::kSubgroupShuffleUp: + case wgsl::BuiltinFn::kSubgroupShuffleDown: + case wgsl::BuiltinFn::kSubgroupShuffleXor: + if (!validator_.SubgroupShuffleFunction(fn, call)) { + return nullptr; + } + break; case wgsl::BuiltinFn::kQuadBroadcast: if (!validator_.QuadBroadcast(call)) { diff --git a/src/tint/lang/wgsl/resolver/validator.cc b/src/tint/lang/wgsl/resolver/validator.cc index 8cf0c3dd297..a17be301b35 100644 --- a/src/tint/lang/wgsl/resolver/validator.cc +++ b/src/tint/lang/wgsl/resolver/validator.cc @@ -1888,6 +1888,61 @@ bool Validator::BuiltinCall(const sem::Call* call) const { return true; } +bool Validator::SubgroupShuffleFunction(wgsl::BuiltinFn fn, const sem::Call* call) const { + auto* builtin = call->Target()->As(); + if (!builtin) { + return false; + } + + TINT_ASSERT(call->Arguments().Length() == 2); + auto* id = call->Arguments()[1]; + auto* constant_value = id->ConstantValue(); + + if (!constant_value) { + // Non const values are allowed as parameters. + return true; + } + + // User friendly param name. + std::string paramName = "sourceLaneIndex"; + switch (fn) { + case wgsl::BuiltinFn::kSubgroupShuffleXor: + paramName = "mask"; + break; + case wgsl::BuiltinFn::kSubgroupShuffleUp: + case wgsl::BuiltinFn::kSubgroupShuffleDown: + paramName = "delta"; + break; + default: + break; + } + + if (id->Type()->IsSignedIntegerScalar() && constant_value->ValueAs() < 0) { + AddError(id->Declaration()->source) + << "the " << paramName << " argument of " << builtin->str() + << " must be greater than or equal to zero"; + return false; + } + + if (id->Type()->IsSignedIntegerScalar() && + constant_value->ValueAs() >= tint::internal_limits::kMaxSubgroupSize) { + AddError(id->Declaration()->source) + << "the " << paramName << " argument of " << builtin->str() << " must be less than " + << tint::internal_limits::kMaxSubgroupSize; + return false; + } + + if (id->Type()->IsUnsignedIntegerScalar() && + constant_value->ValueAs() >= tint::internal_limits::kMaxSubgroupSize) { + AddError(id->Declaration()->source) + << "the " << paramName << " argument of " << builtin->str() << " must be less than " + << tint::internal_limits::kMaxSubgroupSize; + return false; + } + + return true; +} + bool Validator::TextureBuiltinFn(const sem::Call* call) const { auto* builtin = call->Target()->As(); if (!builtin) { diff --git a/src/tint/lang/wgsl/resolver/validator.h b/src/tint/lang/wgsl/resolver/validator.h index a46bb8c0f6d..36b34f73b52 100644 --- a/src/tint/lang/wgsl/resolver/validator.h +++ b/src/tint/lang/wgsl/resolver/validator.h @@ -528,6 +528,12 @@ class Validator { /// @returns true on success, false otherwise bool ArrayConstructor(const ast::CallExpression* ctor, const sem::Array* arr_type) const; + /// Validates a subgroupShuffle builtin functions including Up,Down, and Xor. + /// @param fn the builtin call type + /// @param call the builtin call to validate + /// @returns true on success, false otherwise + bool SubgroupShuffleFunction(wgsl::BuiltinFn fn, const sem::Call* call) const; + /// Validates a texture builtin function /// @param call the builtin call to validate /// @returns true on success, false otherwise