diff --git a/source/val/validate_type.cpp b/source/val/validate_type.cpp index 4d673b4a7fc..ff586bdcc3f 100644 --- a/source/val/validate_type.cpp +++ b/source/val/validate_type.cpp @@ -19,29 +19,40 @@ #include "source/val/instruction.h" #include "source/val/validate.h" #include "source/val/validation_state.h" +#include "spirv/unified1/spirv.h" namespace spvtools { namespace val { namespace { -// True if the integer constant is > 0. |const_words| are words of the -// constant-defining instruction (either OpConstant or -// OpSpecConstant). typeWords are the words of the constant's-type-defining -// OpTypeInt. -bool AboveZero(const std::vector& const_words, - const std::vector& type_words) { - const uint32_t width = type_words[2]; - const bool is_signed = type_words[3] > 0; +// Returns, as an int64_t, the literal value from an OpConstant or the +// default value of an OpSpecConstant, assuming it is an integral type. +// For signed integers, relies the rule that literal value is sign extended +// to fill out to word granularity. Assumes that the constant value +// has +int64_t ConstantLiteralAsInt64(uint32_t width, + const std::vector& const_words) { const uint32_t lo_word = const_words[3]; - if (width > 32) { - // The spec currently doesn't allow integers wider than 64 bits. - const uint32_t hi_word = const_words[4]; // Must exist, per spec. - if (is_signed && (hi_word >> 31)) return false; - return (lo_word | hi_word) > 0; - } else { - if (is_signed && (lo_word >> 31)) return false; - return lo_word > 0; - } + if (width <= 32) return int32_t(lo_word); + assert(width <= 64); + assert(const_words.size() > 4); + const uint32_t hi_word = const_words[4]; // Must exist, per spec. + return static_cast(uint64_t(lo_word) | uint64_t(hi_word) << 32); +} + +// Returns, as an uint64_t, the literal value from an OpConstant or the +// default value of an OpSpecConstant, assuming it is an integral type. +// For signed integers, relies the rule that literal value is sign extended +// to fill out to word granularity. Assumes that the constant value +// has +int64_t ConstantLiteralAsUint64(uint32_t width, + const std::vector& const_words) { + const uint32_t lo_word = const_words[3]; + if (width <= 32) return lo_word; + assert(width <= 64); + assert(const_words.size() > 4); + const uint32_t hi_word = const_words[4]; // Must exist, per spec. + return (uint64_t(lo_word) | uint64_t(hi_word) << 32); } // Validates that type declarations are unique, unless multiple declarations @@ -258,14 +269,33 @@ spv_result_t ValidateTypeArray(ValidationState_t& _, const Instruction* inst) { switch (length->opcode()) { case SpvOpSpecConstant: - case SpvOpConstant: - if (AboveZero(length->words(), const_result_type->words())) break; - // Else fall through! - case SpvOpConstantNull: { + case SpvOpConstant: { + auto& type_words = const_result_type->words(); + const bool is_signed = type_words[3] > 0; + const uint32_t width = type_words[2]; + const int64_t ivalue = ConstantLiteralAsInt64(width, length->words()); + const uint64_t uvalue = ConstantLiteralAsUint64(width, length->words()); + if (ivalue == 0 || (ivalue < 0 && is_signed)) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Length '" << _.getIdName(length_id) + << "' default value must be at least 1: found " << ivalue; + } + if (spvIsWebGPUEnv(_.context()->target_env)) { + // WebGPU has maximum integer width of 32 bits, and max array size + // is one more than the max signed integer representation. + const uint64_t max_permitted = (uint64_t(1) << 31); + if (uvalue > max_permitted) { + return _.diag(SPV_ERROR_INVALID_ID, inst) + << "OpTypeArray Length '" << _.getIdName(length_id) + << "' size exceeds max value " << max_permitted + << " permitted by WebGPU: got " << uvalue; + } + } + } break; + case SpvOpConstantNull: return _.diag(SPV_ERROR_INVALID_ID, inst) << "OpTypeArray Length '" << _.getIdName(length_id) << "' default value must be at least 1."; - } case SpvOpSpecConstantOp: // Assume it's OK, rather than try to evaluate the operation. break; diff --git a/test/val/val_id_test.cpp b/test/val/val_id_test.cpp index 327aef19ec3..2cae7b09f32 100644 --- a/test/val/val_id_test.cpp +++ b/test/val/val_id_test.cpp @@ -749,20 +749,40 @@ TEST_F(ValidateIdWithMessage, OpTypeArrayElementTypeBad) { // Signed or unsigned. enum Signed { kSigned, kUnsigned }; -// Creates an assembly snippet declaring OpTypeArray with the given length. -std::string MakeArrayLength(const std::string& len, Signed isSigned, - int width) { +// Creates an assembly module declaring OpTypeArray with the given length. +std::string MakeArrayLength(const std::string& len, Signed isSigned, int width, + int max_int_width = 64, + bool use_vulkan_memory_model = false) { std::ostringstream ss; ss << R"( OpCapability Shader - OpCapability Linkage - OpCapability Int16 - OpCapability Int64 )"; - ss << "OpMemoryModel Logical GLSL450\n"; + if (use_vulkan_memory_model) { + ss << " OpCapability VulkanMemoryModel\n"; + } + if (width == 16) { + ss << " OpCapability Int16\n"; + } + if (max_int_width > 32) { + ss << "\n OpCapability Int64\n"; + } + if (use_vulkan_memory_model) { + ss << " OpExtension \"SPV_KHR_vulkan_memory_model\"\n"; + ss << "OpMemoryModel Logical Vulkan\n"; + } else { + ss << "OpMemoryModel Logical GLSL450\n"; + } + ss << "OpEntryPoint GLCompute %main \"main\"\n"; + ss << "OpExecutionMode %main LocalSize 1 1 1\n"; ss << " %t = OpTypeInt " << width << (isSigned == kSigned ? " 1" : " 0"); ss << " %l = OpConstant %t " << len; ss << " %a = OpTypeArray %t %l"; + ss << " %void = OpTypeVoid \n" + " %voidfn = OpTypeFunction %void \n" + " %main = OpFunction %void None %voidfn \n" + " %entry = OpLabel\n" + " OpReturn\n" + " OpFunctionEnd\n"; return ss.str(); } @@ -772,7 +792,8 @@ class OpTypeArrayLengthTest : public spvtest::TextToBinaryTestBase<::testing::TestWithParam> { protected: OpTypeArrayLengthTest() - : position_(spv_position_t{0, 0, 0}), + : env_(SPV_ENV_UNIVERSAL_1_0), + position_(spv_position_t{0, 0, 0}), diagnostic_(spvDiagnosticCreate(&position_, "")) {} ~OpTypeArrayLengthTest() { spvDiagnosticDestroy(diagnostic_); } @@ -783,7 +804,7 @@ class OpTypeArrayLengthTest spvDiagnosticDestroy(diagnostic_); diagnostic_ = nullptr; const auto status = - spvValidate(ScopedContext().context, &cbinary, &diagnostic_); + spvValidate(ScopedContext(env_).context, &cbinary, &diagnostic_); if (status != SPV_SUCCESS) { spvDiagnosticPrint(diagnostic_); EXPECT_THAT(std::string(diagnostic_->error), @@ -792,12 +813,15 @@ class OpTypeArrayLengthTest return status; } + protected: + spv_target_env env_; + private: spv_position_t position_; // For creating diagnostic_. spv_diagnostic diagnostic_; }; -TEST_P(OpTypeArrayLengthTest, LengthPositive) { +TEST_P(OpTypeArrayLengthTest, LengthPositiveSmall) { const int width = GetParam(); EXPECT_EQ(SPV_SUCCESS, Val(CompileSuccessfully(MakeArrayLength("1", kSigned, width)))); @@ -814,42 +838,113 @@ TEST_P(OpTypeArrayLengthTest, LengthPositive) { const std::string fpad(width / 4 - 1, 'F'); EXPECT_EQ( SPV_SUCCESS, - Val(CompileSuccessfully(MakeArrayLength("0x7" + fpad, kSigned, width)))); - EXPECT_EQ(SPV_SUCCESS, Val(CompileSuccessfully( - MakeArrayLength("0xF" + fpad, kUnsigned, width)))); + Val(CompileSuccessfully(MakeArrayLength("0x7" + fpad, kSigned, width)))) + << MakeArrayLength("0x7" + fpad, kSigned, width); } TEST_P(OpTypeArrayLengthTest, LengthZero) { const int width = GetParam(); - EXPECT_EQ(SPV_ERROR_INVALID_ID, - Val(CompileSuccessfully(MakeArrayLength("0", kSigned, width)), - "OpTypeArray Length '2\\[%.*\\]' default value must be at " - "least 1.")); - EXPECT_EQ(SPV_ERROR_INVALID_ID, - Val(CompileSuccessfully(MakeArrayLength("0", kUnsigned, width)), - "OpTypeArray Length '2\\[%.*\\]' default value must be at " - "least 1.")); + EXPECT_EQ( + SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("0", kSigned, width)), + "OpTypeArray Length '[0-9]\\[%.*\\]' default value must be at " + "least 1.")); + EXPECT_EQ( + SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("0", kUnsigned, width)), + "OpTypeArray Length '[0-9]\\[%.*\\]' default value must be at " + "least 1.")); } TEST_P(OpTypeArrayLengthTest, LengthNegative) { const int width = GetParam(); - EXPECT_EQ(SPV_ERROR_INVALID_ID, - Val(CompileSuccessfully(MakeArrayLength("-1", kSigned, width)), - "OpTypeArray Length '2\\[%.*\\]' default value must be at " - "least 1.")); - EXPECT_EQ(SPV_ERROR_INVALID_ID, - Val(CompileSuccessfully(MakeArrayLength("-2", kSigned, width)), - "OpTypeArray Length '2\\[%.*\\]' default value must be at " - "least 1.")); - EXPECT_EQ(SPV_ERROR_INVALID_ID, - Val(CompileSuccessfully(MakeArrayLength("-123", kSigned, width)), - "OpTypeArray Length '2\\[%.*\\]' default value must be at " - "least 1.")); + EXPECT_EQ( + SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("-1", kSigned, width)), + "OpTypeArray Length '[0-9]\\[%.*\\]' default value must be at " + "least 1.")); + EXPECT_EQ( + SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("-2", kSigned, width)), + "OpTypeArray Length '[0-9]\\[%.*\\]' default value must be at " + "least 1.")); + EXPECT_EQ( + SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength("-123", kSigned, width)), + "OpTypeArray Length '[0-9]\\[%.*\\]' default value must be at " + "least 1.")); const std::string neg_max = "0x8" + std::string(width / 4 - 1, '0'); - EXPECT_EQ(SPV_ERROR_INVALID_ID, - Val(CompileSuccessfully(MakeArrayLength(neg_max, kSigned, width)), - "OpTypeArray Length '2\\[%.*\\]' default value must be at " - "least 1.")); + EXPECT_EQ( + SPV_ERROR_INVALID_ID, + Val(CompileSuccessfully(MakeArrayLength(neg_max, kSigned, width)), + "OpTypeArray Length '[0-9]\\[%.*\\]' default value must be at " + "least 1.")); +} + +// Returns the string form of an integer of the form 0x80....0 of the +// given bit width. +std::string big_num_ending_0(int bit_width) { + return "0x8" + std::string(bit_width / 4 - 1, '0'); +} + +// Returns the string form of an integer of the form 0x80..001 of the +// given bit width. +std::string big_num_ending_1(int bit_width) { + return "0x8" + std::string(bit_width / 4 - 2, '0') + "1"; +} + +TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding0InVulkan) { + env_ = SPV_ENV_VULKAN_1_0; + const int width = GetParam(); + for (int max_int_width : {32, 64}) { + if (width > max_int_width) { + // Not valid to even make the OpConstant in this case. + continue; + } + const auto module = CompileSuccessfully(MakeArrayLength( + big_num_ending_0(width), kUnsigned, width, max_int_width)); + EXPECT_EQ(SPV_SUCCESS, Val(module)); + } +} + +TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding1InVulkan) { + env_ = SPV_ENV_VULKAN_1_0; + const int width = GetParam(); + for (int max_int_width : {32, 64}) { + if (width > max_int_width) { + // Not valid to even make the OpConstant in this case. + continue; + } + const auto module = CompileSuccessfully(MakeArrayLength( + big_num_ending_1(width), kUnsigned, width, max_int_width)); + EXPECT_EQ(SPV_SUCCESS, Val(module)); + } +} + +TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding0InWebGPU) { + env_ = SPV_ENV_WEBGPU_0; + const int width = GetParam(); + // WebGPU only has 32 bit integers. + if (width != 32) return; + const int max_int_width = 32; + const auto module = CompileSuccessfully(MakeArrayLength( + big_num_ending_0(width), kUnsigned, width, max_int_width, true)); + EXPECT_EQ(SPV_SUCCESS, Val(module)); +} + +TEST_P(OpTypeArrayLengthTest, LengthPositiveHugeEnding1InWebGPU) { + env_ = SPV_ENV_WEBGPU_0; + const int width = GetParam(); + // WebGPU only has 32 bit integers. + if (width != 32) return; + const int max_int_width = 32; + const auto module = CompileSuccessfully(MakeArrayLength( + big_num_ending_1(width), kUnsigned, width, max_int_width, true)); + EXPECT_EQ( + SPV_ERROR_INVALID_ID, + Val(module, + "OpTypeArray Length '[0-9]\\[%.*\\]' size exceeds max value " + "2147483648 permitted by WebGPU: got 2147483649")); } // The only valid widths for integers are 8, 16, 32, and 64.