Skip to content

Commit

Permalink
[CodeGen] avoid crash if an exception is raised during llvm cpu codeg…
Browse files Browse the repository at this point in the history
…en (apache#9786)

* avoid crash if an exception is raised during llvm cpu codegen

* use pytest.raises
  • Loading branch information
wrongtest-intellif authored Dec 22, 2021
1 parent 4e0bf23 commit d026d06
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 11 deletions.
37 changes: 26 additions & 11 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,28 @@ llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) {
}

void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
/*! \brief maintain states that should be guarded when step into compute scope */
struct ComputeScopeStates {
explicit ComputeScopeStates(CodeGenCPU* parent) : parent_(parent) {}

void EnterWithScope() {
std::swap(function_, parent_->function_);
std::swap(analyzer_, parent_->analyzer_);
std::swap(var_map_, parent_->var_map_);
}

void ExitWithScope() {
std::swap(function_, parent_->function_);
std::swap(analyzer_, parent_->analyzer_);
std::swap(var_map_, parent_->var_map_);
}

llvm::Function* function_{nullptr};
std::unordered_map<const VarNode*, llvm::Value*> var_map_;
std::unique_ptr<arith::Analyzer> analyzer_{std::make_unique<arith::Analyzer>()};
CodeGenCPU* parent_;
};

// There are two reasons why we create another function for compute_scope
// - Make sure the generated compute function is clearly separated (though it can get inlined)
// - Set noalias on all the pointer arguments, some of them are loaded from TVMArgs.
Expand All @@ -515,13 +537,13 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
llvm::Function::Create(ftype, llvm::Function::InternalLinkage,
value->value.operator llvm::StringRef(), module_.get());
BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values));
// setup compute function.
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
// enter compute scope and setup compute function.
With<ComputeScopeStates> scope_states_guard(this);
size_t idx = 0;
for (auto it = fcompute->arg_begin(); it != fcompute->arg_end(); ++it, ++idx) {
llvm::Argument* v = &(*it);
const Var& var = vargs[idx];
new_vmap[var.get()] = v;
var_map_[var.get()] = v;
if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
// set non alias.
#if TVM_LLVM_VERSION >= 50
Expand All @@ -544,18 +566,11 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
}
#endif
}
auto new_analyzer = std::make_unique<arith::Analyzer>();
std::swap(function_, fcompute);
std::swap(analyzer_, new_analyzer);
std::swap(var_map_, new_vmap);
function_ = fcompute;
BasicBlock* compute_entry = BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(compute_entry);
this->VisitStmt(op->body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(var_map_, new_vmap);
std::swap(analyzer_, new_analyzer);
std::swap(function_, fcompute);
builder_->SetInsertPoint(compute_call_end);
}

Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm import te
from tvm.relay.backend import Runtime
from tvm.contrib import utils, clang
import tvm.script.tir as T
import numpy as np

import math
Expand Down Expand Up @@ -906,5 +907,22 @@ def test_llvm_scalar_concat():
m = tvm.build(mod, [x, y, z], target="llvm")


@tvm.testing.requires_llvm
def test_raise_exception_during_codegen():
    """Building a PrimFunc with nested parallel loops for LLVM must raise TVMError.

    The test expects ``tvm.build`` to fail with a ``tvm.TVMError`` whose message
    mentions that nested parallel loops are not supported.
    """

    @T.prim_func
    def threadpool_nested_parallel_loop(
        A: T.Buffer[(4, 4), "float32"], B: T.Buffer[(4, 4), "float32"]
    ) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in T.parallel(4):
            for j in T.parallel(4):
                T.store(B.data, i * 4 + j, T.load("float32", A.data, i * 4 + j) * 2.0)

    with pytest.raises(tvm.TVMError) as e:
        tvm.build({"llvm": tvm.IRModule.from_expr(threadpool_nested_parallel_loop)})
    # Inspect the raised exception itself: str() of the ExceptionInfo wrapper is its
    # repr in pytest >= 5, so matching against it depends on repr formatting.
    msg = str(e.value)
    assert "Nested parallel loop is not supported" in msg


# When executed directly, run pytest on this file, forward any extra command-line
# arguments to it, and exit with pytest's return code.
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit d026d06

Please sign in to comment.