Skip to content

Commit

Permalink
Validate nested constructs (KhronosGroup#3068)
Browse files Browse the repository at this point in the history
* Validate that if a construct contains a header and it's merge is
reachable, the construct also contains the merge
* updated block merging to not merge into the continue
* update inlining to mark the original block of a single block loop as
the continue
* updated some tests
* remove dead code
* rename kBlockTypeHeader to kBlockTypeSelection for clarity
  • Loading branch information
alan-baker authored and dneto0 committed Nov 27, 2019
1 parent 52e9cc9 commit b334829
Show file tree
Hide file tree
Showing 18 changed files with 291 additions and 256 deletions.
17 changes: 17 additions & 0 deletions source/opt/block_merge_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ bool IsMerge(IRContext* context, BasicBlock* block) {
return IsMerge(context, block->id());
}

// Returns true if |id| is the continue target of a merge instruction.
bool IsContinue(IRContext* context, uint32_t id) {
return !context->get_def_use_mgr()->WhileEachUse(
id, [](Instruction* user, uint32_t index) {
SpvOp op = user->opcode();
if (op == SpvOpLoopMerge && index == 1u) {
return false;
}
return true;
});
}

// Removes any OpPhi instructions in |block|, which should have exactly one
// predecessor, replacing uses of OpPhi ids with the ids associated with the
// predecessor.
Expand Down Expand Up @@ -86,6 +98,11 @@ bool CanMergeWithSuccessor(IRContext* context, BasicBlock* block) {
return false;
}

if (pred_is_merge && IsContinue(context, lab_id)) {
// Cannot merge a continue target with a merge block.
return false;
}

// Don't bother trying to merge unreachable blocks.
if (auto dominators = context->GetDominatorAnalysis(block->GetParent())) {
if (!dominators->IsReachable(block)) return false;
Expand Down
22 changes: 6 additions & 16 deletions source/opt/inline_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
static const int kSpvFunctionCallFunctionId = 2;
static const int kSpvFunctionCallArgumentId = 3;
static const int kSpvReturnValueId = 0;
static const int kSpvLoopMergeContinueTargetIdInIdx = 1;

namespace spvtools {
namespace opt {
Expand Down Expand Up @@ -285,19 +284,14 @@ bool InlinePass::GenInlineCode(
if (rid != 0) callee_result_ids.insert(rid);
});

// If the caller is in a single-block loop, and the callee has multiple
// blocks, then the normal inlining logic will place the OpLoopMerge in
// the last of several blocks in the loop. Instead, it should be placed
// at the end of the first block. First determine if the caller is in a
// single block loop. We'll wait to move the OpLoopMerge until the end
// of the regular inlining logic, and only if necessary.
bool caller_is_single_block_loop = false;
// If the caller is a loop header and the callee has multiple blocks, then the
// normal inlining logic will place the OpLoopMerge in the last of several
// blocks in the loop. Instead, it should be placed at the end of the first
// block. We'll wait to move the OpLoopMerge until the end of the regular
// inlining logic, and only if necessary.
bool caller_is_loop_header = false;
if (auto* loop_merge = call_block_itr->GetLoopMergeInst()) {
if (call_block_itr->GetLoopMergeInst()) {
caller_is_loop_header = true;
caller_is_single_block_loop =
call_block_itr->id() ==
loop_merge->GetSingleWordInOperand(kSpvLoopMergeContinueTargetIdInIdx);
}

bool callee_begins_with_structured_header =
Expand Down Expand Up @@ -611,10 +605,6 @@ bool InlinePass::GenInlineCode(
--loop_merge_itr;
assert(loop_merge_itr->opcode() == SpvOpLoopMerge);
std::unique_ptr<Instruction> cp_inst(loop_merge_itr->Clone(context()));
if (caller_is_single_block_loop) {
// Also, update its continue target to point to the last block.
cp_inst->SetInOperand(kSpvLoopMergeContinueTargetIdInIdx, {last->id()});
}
first->tail().InsertBefore(std::move(cp_inst));

// Remove the loop merge from the last block.
Expand Down
4 changes: 2 additions & 2 deletions source/val/basic_block.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
#ifndef SOURCE_VAL_BASIC_BLOCK_H_
#define SOURCE_VAL_BASIC_BLOCK_H_

#include <cstdint>
#include <bitset>
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>
Expand All @@ -28,7 +28,7 @@ namespace val {

enum BlockType : uint32_t {
kBlockTypeUndefined,
kBlockTypeHeader,
kBlockTypeSelection,
kBlockTypeLoop,
kBlockTypeMerge,
kBlockTypeBreak,
Expand Down
7 changes: 3 additions & 4 deletions source/val/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

#include "source/val/function.h"

#include <cassert>

#include <algorithm>
#include <cassert>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -99,7 +98,7 @@ spv_result_t Function::RegisterLoopMerge(uint32_t merge_id,
spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) {
RegisterBlock(merge_id, false);
BasicBlock& merge_block = blocks_.at(merge_id);
current_block_->set_type(kBlockTypeHeader);
current_block_->set_type(kBlockTypeSelection);
merge_block.set_type(kBlockTypeMerge);
merge_block_header_[&merge_block] = current_block_;

Expand Down Expand Up @@ -344,7 +343,7 @@ int Function::GetBlockDepth(BasicBlock* bb) {
BasicBlock* header = merge_block_header_[bb];
assert(header);
block_depth_[bb] = GetBlockDepth(header);
} else if (bb_dom->is_type(kBlockTypeHeader) ||
} else if (bb_dom->is_type(kBlockTypeSelection) ||
bb_dom->is_type(kBlockTypeLoop)) {
// The dominator of the given block is a header block. So, the nesting
// depth of this block is: 1 + nesting depth of the header.
Expand Down
23 changes: 21 additions & 2 deletions source/val/validate_cfg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/val/validate.h"

#include <algorithm>
#include <cassert>
#include <functional>
Expand All @@ -34,6 +32,7 @@
#include "source/val/basic_block.h"
#include "source/val/construct.h"
#include "source/val/function.h"
#include "source/val/validate.h"
#include "source/val/validation_state.h"

namespace spvtools {
Expand Down Expand Up @@ -755,6 +754,26 @@ spv_result_t StructuredControlFlowChecks(
<< header_name << " <ID> " << header->id();
}
}

if (block->is_type(BlockType::kBlockTypeSelection) ||
block->is_type(BlockType::kBlockTypeLoop)) {
size_t index = (block->terminator() - &_.ordered_instructions()[0]) - 1;
const auto& merge_inst = _.ordered_instructions()[index];
if (merge_inst.opcode() == SpvOpSelectionMerge ||
merge_inst.opcode() == SpvOpLoopMerge) {
uint32_t merge_id = merge_inst.GetOperandAs<uint32_t>(0);
auto merge_block = function->GetBlock(merge_id).first;
if (merge_block->reachable() &&
!construct_blocks.count(merge_block)) {
return _.diag(SPV_ERROR_INVALID_CFG, _.FindDef(block->id()))
<< "Header block " << _.getIdName(block->id())
<< " is contained in the " << construct_name
<< " construct headed by " << _.getIdName(header->id())
<< ", but it's merge block " << _.getIdName(merge_id)
<< " is not";
}
}
}
}

// Checks rules for case constructs.
Expand Down
112 changes: 0 additions & 112 deletions test/fuzz/transformation_add_dead_continue_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,117 +209,6 @@ TEST(TransformationAddDeadContinueTest, SimpleExample) {
ASSERT_TRUE(IsEqual(env, after_transformation, context.get()));
}

TEST(TransformationAddDeadContinueTest,
DoNotAllowContinueToMergeBlockOfAnotherLoop) {
// A loop header must dominate its merge block if that merge block is
// reachable. We are thus not allowed to add a dead continue that would result
// in violation of this property. This test checks for such a scenario.

std::string shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %4 "main" %16 %139
OpExecutionMode %4 OriginUpperLeft
OpSource ESSL 310
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeFloat 32
%7 = OpTypePointer Function %6
%8 = OpTypeBool
%14 = OpTypeVector %6 4
%15 = OpTypePointer Input %14
%16 = OpVariable %15 Input
%138 = OpTypePointer Output %14
%139 = OpVariable %138 Output
%400 = OpConstantTrue %8
%4 = OpFunction %2 None %3
%5 = OpLabel
OpBranch %500
%500 = OpLabel
OpLoopMerge %501 %502 None
OpBranch %503 ; We are not allowed to change this to OpBranchConditional %400 %503 %502
%503 = OpLabel
OpLoopMerge %502 %504 None
OpBranchConditional %400 %505 %504
%505 = OpLabel
OpBranch %502
%504 = OpLabel
OpBranch %503
%502 = OpLabel
OpBranchConditional %400 %501 %500
%501 = OpLabel
OpReturn
OpFunctionEnd
)";

const auto env = SPV_ENV_UNIVERSAL_1_3;
const auto consumer = nullptr;
const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
ASSERT_TRUE(IsValid(env, context.get()));
FactManager fact_manager;

ASSERT_FALSE(TransformationAddDeadContinue(500, true, {})
.IsApplicable(context.get(), fact_manager));
ASSERT_FALSE(TransformationAddDeadContinue(500, false, {})
.IsApplicable(context.get(), fact_manager));
}

TEST(TransformationAddDeadContinueTest, DoNotAllowContinueToSelectionMerge) {
// A selection header must dominate its merge block if that merge block is
// reachable. We are thus not allowed to add a dead continue that would result
// in violation of this property. This test checks for such a scenario.

std::string shader = R"(
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %4 "main" %16 %139
OpExecutionMode %4 OriginUpperLeft
OpSource ESSL 310
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeFloat 32
%7 = OpTypePointer Function %6
%8 = OpTypeBool
%14 = OpTypeVector %6 4
%15 = OpTypePointer Input %14
%16 = OpVariable %15 Input
%138 = OpTypePointer Output %14
%139 = OpVariable %138 Output
%400 = OpConstantTrue %8
%4 = OpFunction %2 None %3
%5 = OpLabel
OpBranch %500
%500 = OpLabel
OpLoopMerge %501 %502 None
OpBranch %503 ; We are not allowed to change this to OpBranchConditional %400 %503 %502
%503 = OpLabel
OpSelectionMerge %502 None
OpBranchConditional %400 %505 %504
%505 = OpLabel
OpBranch %502
%504 = OpLabel
OpBranch %502
%502 = OpLabel
OpBranchConditional %400 %501 %500
%501 = OpLabel
OpReturn
OpFunctionEnd
)";

const auto env = SPV_ENV_UNIVERSAL_1_3;
const auto consumer = nullptr;
const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
ASSERT_TRUE(IsValid(env, context.get()));
FactManager fact_manager;

ASSERT_FALSE(TransformationAddDeadContinue(500, true, {})
.IsApplicable(context.get(), fact_manager));
ASSERT_FALSE(TransformationAddDeadContinue(500, false, {})
.IsApplicable(context.get(), fact_manager));
}

TEST(TransformationAddDeadContinueTest, LoopNest) {
// Checks some allowed and disallowed scenarios for a nest of loops, including
// continuing a loop from an if or switch.
Expand Down Expand Up @@ -1420,7 +1309,6 @@ TEST(TransformationAddDeadContinueTest, Miscellaneous2) {
OpLoopMerge %1557 %1570 None
OpBranchConditional %395 %1562 %1557
%1562 = OpLabel
OpSelectionMerge %1570 None
OpBranchConditional %395 %1571 %1570
%1571 = OpLabel
OpBranch %1557
Expand Down
2 changes: 1 addition & 1 deletion test/fuzz/transformation_copy_object_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ TEST(TransformationCopyObjectTest, CheckIllegalCases) {
%31 = OpLabel
%42 = OpAccessChain %36 %18 %41
%43 = OpLoad %11 %42
OpSelectionMerge %47 None
OpSelectionMerge %45 None
OpSwitch %43 %46 0 %44 1 %45
%46 = OpLabel
%69 = OpIAdd %11 %96 %27
Expand Down
20 changes: 12 additions & 8 deletions test/fuzz/transformation_replace_id_with_synonym_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,20 @@ const std::string kComplexShader = R"(
%65 = OpAccessChain %13 %11 %64
%66 = OpLoad %6 %65
%67 = OpSGreaterThan %29 %84 %66
OpSelectionMerge %69 None
OpSelectionMerge %1000 None
OpBranchConditional %67 %68 %72
%68 = OpLabel
%71 = OpIAdd %6 %84 %26
OpBranch %69
OpBranch %1000
%72 = OpLabel
%74 = OpIAdd %6 %84 %64
%205 = OpCopyObject %6 %74
OpBranch %69
%69 = OpLabel
OpBranch %1000
%1000 = OpLabel
%86 = OpPhi %6 %71 %68 %74 %72
%301 = OpPhi %6 %71 %68 %15 %72
OpBranch %69
%69 = OpLabel
OpBranch %20
%22 = OpLabel
%75 = OpAccessChain %46 %42 %50
Expand Down Expand Up @@ -421,18 +423,20 @@ TEST(TransformationReplaceIdWithSynonymTest, LegalTransformations) {
%65 = OpAccessChain %13 %11 %64
%66 = OpLoad %6 %65
%67 = OpSGreaterThan %29 %84 %66
OpSelectionMerge %69 None
OpSelectionMerge %1000 None
OpBranchConditional %67 %68 %72
%68 = OpLabel
%71 = OpIAdd %6 %84 %26
OpBranch %69
OpBranch %1000
%72 = OpLabel
%74 = OpIAdd %6 %84 %64
%205 = OpCopyObject %6 %74
OpBranch %69
%69 = OpLabel
OpBranch %1000
%1000 = OpLabel
%86 = OpPhi %6 %71 %68 %205 %72
%301 = OpPhi %6 %71 %68 %15 %72
OpBranch %69
%69 = OpLabel
OpBranch %20
%22 = OpLabel
%75 = OpAccessChain %46 %42 %50
Expand Down
1 change: 0 additions & 1 deletion test/opt/aggressive_dead_code_elim_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5932,7 +5932,6 @@ OpBranch %42
%42 = OpLabel
%43 = OpLoad %int %i
%44 = OpSLessThan %bool %43 %int_1
OpSelectionMerge %45 None
OpBranchConditional %44 %46 %40
%46 = OpLabel
%47 = OpLoad %int %i
Expand Down
Loading

0 comments on commit b334829

Please sign in to comment.