Skip to content

Commit

Permalink
Add support for SPV_KHR_float_controls2 (KhronosGroup#5543)
Browse files Browse the repository at this point in the history
* Test asm/dis for SPV_KHR_float_controls2
* SPV_KHR_float_controls2 validation

---------

Co-authored-by: David Neto <[email protected]>
  • Loading branch information
alan-baker and dneto0 authored Jan 25, 2024
1 parent de3d5ac commit ef2f432
Show file tree
Hide file tree
Showing 9 changed files with 961 additions and 23 deletions.
4 changes: 4 additions & 0 deletions source/val/validate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ spv_result_t ValidateEntryPoints(ValidationState_t& _) {
}
}

if (auto error = ValidateFloatControls2(_)) {
return error;
}

return SPV_SUCCESS;
}

Expand Down
12 changes: 12 additions & 0 deletions source/val/validate.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ spv_result_t ValidateAdjacency(ValidationState_t& _);
/// @return SPV_SUCCESS if no errors are found.
spv_result_t ValidateInterfaces(ValidationState_t& _);

/// @brief Validates entry point call tree requirements of
/// SPV_KHR_float_controls2
///
/// Checks that no entry point using FPFastMathDefault uses:
/// * FPFastMathMode Fast
/// * NoContraction
///
/// @param[in] _ the validation state of the module
///
/// @return SPV_SUCCESS if no errors are found.
spv_result_t ValidateFloatControls2(ValidationState_t& _);

/// @brief Validates memory instructions
///
/// @param[in] _ the validation state of the module
Expand Down
28 changes: 28 additions & 0 deletions source/val/validate_annotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,34 @@ spv_result_t ValidateDecorate(ValidationState_t& _, const Instruction* inst) {
}
}

if (decoration == spv::Decoration::FPFastMathMode) {
if (_.HasDecoration(target_id, spv::Decoration::NoContraction)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "FPFastMathMode and NoContraction cannot decorate the same "
"target";
}
auto mask = inst->GetOperandAs<spv::FPFastMathModeMask>(2);
if ((mask & spv::FPFastMathModeMask::AllowTransform) !=
spv::FPFastMathModeMask::MaskNone &&
((mask & (spv::FPFastMathModeMask::AllowContract |
spv::FPFastMathModeMask::AllowReassoc)) !=
(spv::FPFastMathModeMask::AllowContract |
spv::FPFastMathModeMask::AllowReassoc))) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "AllowReassoc and AllowContract must be specified when "
"AllowTransform is specified";
}
}

// This is checked from both sides since we register decorations as we go.
if (decoration == spv::Decoration::NoContraction) {
if (_.HasDecoration(target_id, spv::Decoration::FPFastMathMode)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "FPFastMathMode and NoContraction cannot decorate the same "
"target";
}
}

if (DecorationTakesIdParameters(decoration)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Decorations taking ID parameters may not be used with "
Expand Down
3 changes: 2 additions & 1 deletion source/val/validate_instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,8 @@ spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) {
}
_.set_addressing_model(inst->GetOperandAs<spv::AddressingModel>(0));
_.set_memory_model(inst->GetOperandAs<spv::MemoryModel>(1));
} else if (opcode == spv::Op::OpExecutionMode) {
} else if (opcode == spv::Op::OpExecutionMode ||
opcode == spv::Op::OpExecutionModeId) {
const uint32_t entry_point = inst->word(1);
_.RegisterExecutionModeForEntryPoint(entry_point,
spv::ExecutionMode(inst->word(2)));
Expand Down
173 changes: 157 additions & 16 deletions source/val/validate_mode_setting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,29 +340,92 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,

const auto mode = inst->GetOperandAs<spv::ExecutionMode>(1);
if (inst->opcode() == spv::Op::OpExecutionModeId) {
bool valid_mode = false;
switch (mode) {
case spv::ExecutionMode::SubgroupsPerWorkgroupId:
case spv::ExecutionMode::LocalSizeHintId:
case spv::ExecutionMode::LocalSizeId:
case spv::ExecutionMode::FPFastMathDefault:
valid_mode = true;
break;
default:
valid_mode = false;
break;
}
if (!valid_mode) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpExecutionModeId is only valid when the Mode operand is an "
"execution mode that takes Extra Operands that are id "
"operands.";
}

size_t operand_count = inst->operands().size();
for (size_t i = 2; i < operand_count; ++i) {
const auto operand_id = inst->GetOperandAs<uint32_t>(2);
const auto operand_id = inst->GetOperandAs<uint32_t>(i);
const auto* operand_inst = _.FindDef(operand_id);
if (mode == spv::ExecutionMode::SubgroupsPerWorkgroupId ||
mode == spv::ExecutionMode::LocalSizeHintId ||
mode == spv::ExecutionMode::LocalSizeId) {
if (!spvOpcodeIsConstant(operand_inst->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "For OpExecutionModeId all Extra Operand ids must be "
"constant "
"instructions.";
}
} else {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "OpExecutionModeId is only valid when the Mode operand is an "
"execution mode that takes Extra Operands that are id "
"operands.";
switch (mode) {
case spv::ExecutionMode::SubgroupsPerWorkgroupId:
case spv::ExecutionMode::LocalSizeHintId:
case spv::ExecutionMode::LocalSizeId:
if (!spvOpcodeIsConstant(operand_inst->opcode())) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "For OpExecutionModeId all Extra Operand ids must be "
"constant instructions.";
}
break;
case spv::ExecutionMode::FPFastMathDefault:
if (i == 2) {
if (!_.IsFloatScalarType(operand_id)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The Target Type operand must be a floating-point "
"scalar type";
}
} else {
bool is_int32 = false;
bool is_const = false;
uint32_t value = 0;
std::tie(is_int32, is_const, value) =
_.EvalInt32IfConst(operand_id);
if (is_int32 && is_const) {
// Valid values include up to 0x00040000 (AllowTransform).
uint32_t invalid_mask = 0xfff80000;
if ((invalid_mask & value) != 0) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The Fast Math Default operand is an invalid bitmask "
"value";
}
if (value &
static_cast<uint32_t>(spv::FPFastMathModeMask::Fast)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The Fast Math Default operand must not include Fast";
}
const auto reassoc_contract =
spv::FPFastMathModeMask::AllowContract |
spv::FPFastMathModeMask::AllowReassoc;
if ((value & static_cast<uint32_t>(
spv::FPFastMathModeMask::AllowTransform)) != 0 &&
((value & static_cast<uint32_t>(reassoc_contract)) !=
static_cast<uint32_t>(reassoc_contract))) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The Fast Math Default operand must include "
"AllowContract and AllowReassoc when AllowTransform "
"is specified";
}
} else {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "The Fast Math Default operand must be a "
"non-specialization constant";
}
}
break;
default:
break;
}
}
} else if (mode == spv::ExecutionMode::SubgroupsPerWorkgroupId ||
mode == spv::ExecutionMode::LocalSizeHintId ||
mode == spv::ExecutionMode::LocalSizeId) {
mode == spv::ExecutionMode::LocalSizeId ||
mode == spv::ExecutionMode::FPFastMathDefault) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "OpExecutionMode is only valid when the Mode operand is an "
"execution mode that takes no Extra Operands, or takes Extra "
Expand Down Expand Up @@ -579,6 +642,20 @@ spv_result_t ValidateExecutionMode(ValidationState_t& _,
break;
}

if (mode == spv::ExecutionMode::FPFastMathDefault) {
const auto* modes = _.GetExecutionModes(entry_point_id);
if (modes && modes->count(spv::ExecutionMode::ContractionOff)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "FPFastMathDefault and ContractionOff execution modes cannot "
"be applied to the same entry point";
}
if (modes && modes->count(spv::ExecutionMode::SignedZeroInfNanPreserve)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "FPFastMathDefault and SignedZeroInfNanPreserve execution "
"modes cannot be applied to the same entry point";
}
}

if (spvIsVulkanEnv(_.context()->target_env)) {
if (mode == spv::ExecutionMode::OriginLowerLeft) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
Expand Down Expand Up @@ -636,6 +713,70 @@ spv_result_t ValidateMemoryModel(ValidationState_t& _,

} // namespace

spv_result_t ValidateFloatControls2(ValidationState_t& _) {
std::unordered_set<uint32_t> fp_fast_math_default_entry_points;
for (auto entry_point : _.entry_points()) {
const auto* exec_modes = _.GetExecutionModes(entry_point);
if (exec_modes &&
exec_modes->count(spv::ExecutionMode::FPFastMathDefault)) {
fp_fast_math_default_entry_points.insert(entry_point);
}
}

std::vector<std::pair<const Instruction*, spv::Decoration>> worklist;
for (const auto& inst : _.ordered_instructions()) {
if (inst.opcode() != spv::Op::OpDecorate) {
continue;
}

const auto decoration = inst.GetOperandAs<spv::Decoration>(1);
const auto target_id = inst.GetOperandAs<uint32_t>(0);
const auto target = _.FindDef(target_id);
if (decoration == spv::Decoration::NoContraction) {
worklist.push_back(std::make_pair(target, decoration));
} else if (decoration == spv::Decoration::FPFastMathMode) {
auto mask = inst.GetOperandAs<spv::FPFastMathModeMask>(2);
if ((mask & spv::FPFastMathModeMask::Fast) !=
spv::FPFastMathModeMask::MaskNone) {
worklist.push_back(std::make_pair(target, decoration));
}
}
}

std::unordered_set<const Instruction*> visited;
while (!worklist.empty()) {
const auto inst = worklist.back().first;
const auto decoration = worklist.back().second;
worklist.pop_back();

if (!visited.insert(inst).second) {
continue;
}

const auto function = inst->function();
if (function) {
const auto& entry_points = _.FunctionEntryPoints(function->id());
for (auto entry_point : entry_points) {
if (fp_fast_math_default_entry_points.count(entry_point)) {
const std::string dec = decoration == spv::Decoration::NoContraction
? "NoContraction"
: "FPFastMathMode Fast";
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< dec
<< " cannot be used by an entry point with the "
"FPFastMathDefault execution mode";
}
}
} else {
for (const auto& pair : inst->uses()) {
worklist.push_back(std::make_pair(pair.first, decoration));
}
}
}

return SPV_SUCCESS;
}

spv_result_t ModeSettingPass(ValidationState_t& _, const Instruction* inst) {
switch (inst->opcode()) {
case spv::Op::OpEntryPoint:
Expand Down
17 changes: 11 additions & 6 deletions test/operand_capabilities_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
#include "source/assembly_grammar.h"
#include "source/enum_set.h"
#include "source/operand.h"
#include "source/spirv_target_env.h"
#include "source/table.h"
#include "spirv-tools/libspirv.h"
#include "test/unit_spirv.h"

namespace spvtools {
Expand Down Expand Up @@ -58,15 +61,17 @@ struct EnumCapabilityCase {
uint32_t value;
CapabilitySet expected_capabilities;
};
// Emits an EnumCapabilityCase to the ostream, returning the ostream.
inline std::ostream& operator<<(std::ostream& out,
const EnumCapabilityCase& ecc) {
out << "EnumCapabilityCase{ " << spvOperandTypeStr(ecc.type) << "("
<< unsigned(ecc.type) << "), " << ecc.value << ", "
<< ecc.expected_capabilities << "}";

// Emits an EnumCapabilityCase to the given output stream. This is used
// to emit failure cases when they occur, which helps debug tests.
inline std::ostream& operator<<(std::ostream& out, EnumCapabilityCase e) {
out << "{" << spvOperandTypeStr(e.type) << " " << e.value << " "
<< e.expected_capabilities << " }";
return out;
}

using EnvEnumCapabilityCase = std::tuple<spv_target_env, EnumCapabilityCase>;

// Test fixture for testing EnumCapabilityCases.
using EnumCapabilityTest =
TestWithParam<std::tuple<spv_target_env, EnumCapabilityCase>>;
Expand Down
36 changes: 36 additions & 0 deletions test/text_to_binary.extension_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1264,5 +1264,41 @@ INSTANTIATE_TEST_SUITE_P(
{1, (uint32_t)spv::ExecutionMode::MaximallyReconvergesKHR})},
})));

// SPV_KHR_float_controls2

INSTANTIATE_TEST_SUITE_P(
SPV_KHR_float_controls2, ExtensionRoundTripTest,
Combine(
Values(SPV_ENV_UNIVERSAL_1_0, SPV_ENV_UNIVERSAL_1_5, SPV_ENV_VULKAN_1_0,
SPV_ENV_VULKAN_1_1, SPV_ENV_VULKAN_1_2, SPV_ENV_VULKAN_1_3),
ValuesIn(std::vector<AssemblyCase>{
{"OpExtension \"SPV_KHR_float_controls2\"\n",
MakeInstruction(spv::Op::OpExtension,
MakeVector("SPV_KHR_float_controls2"))},
{"OpCapability FloatControls2\n",
MakeInstruction(spv::Op::OpCapability,
{(uint32_t)spv::Capability::FloatControls2})},
{"OpExecutionMode %1 FPFastMathDefault %2 %3\n",
// The operands are: target type, flags constant
MakeInstruction(
spv::Op::OpExecutionMode,
{1, (uint32_t)spv::ExecutionMode::FPFastMathDefault, 2, 3})},
{"OpDecorate %1 FPFastMathMode AllowContract\n",
MakeInstruction(
spv::Op::OpDecorate,
{1, (uint32_t)spv::Decoration::FPFastMathMode,
(uint32_t)spv::FPFastMathModeMask::AllowContract})},
{"OpDecorate %1 FPFastMathMode AllowReassoc\n",
MakeInstruction(
spv::Op::OpDecorate,
{1, (uint32_t)spv::Decoration::FPFastMathMode,
(uint32_t)spv::FPFastMathModeMask::AllowReassoc})},
{"OpDecorate %1 FPFastMathMode AllowTransform\n",
MakeInstruction(
spv::Op::OpDecorate,
{1, (uint32_t)spv::Decoration::FPFastMathMode,
(uint32_t)spv::FPFastMathModeMask::AllowTransform})},
})));

} // namespace
} // namespace spvtools
Loading

0 comments on commit ef2f432

Please sign in to comment.