Skip to content

Commit

Permalink
[TIR][Hexagon] Use the "target" value in T.func_attr for VTCM limit (a…
Browse files Browse the repository at this point in the history
…pache#14567)

* [TIR][Hexagon] Use the "target" value in T.func_attr for VTCM limit

For the VerifyVTCMLimit, read directly from the function attribute, if
the function has already been annotated with the target.

* Retain passing of target to VerifyVTCMLimit
  • Loading branch information
Lunderberg authored Apr 13, 2023
1 parent 68ce1e8 commit b48fcab
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 25 deletions.
6 changes: 4 additions & 2 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op_attr_types.h>
Expand Down Expand Up @@ -348,12 +349,13 @@ TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);
/*!
* \brief Pass to checks if the size of the allocated vtcm memory satisfies the limit
*
* \param limit The limit to check.
* \param target The target whose VTCM limit should be used for any
* functions not already annotated with `tvm::attr::kTarget`.
*
* \returns The pass.
* \sa tvm::tir::CalculateAllocatedBytes
*/
TVM_DLL Pass VerifyVTCMLimit(const Integer& limit);
TVM_DLL Pass VerifyVTCMLimit(Optional<Target> target = NullOpt);

/*!
* \brief Statically check TIR code for out of bounds array access.
Expand Down
4 changes: 1 addition & 3 deletions src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1408,9 +1408,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
}
if (IsHexagonTask(task)) {
Target target = task->target;
const auto vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value().IntValue();
const auto& optimize =
tir::transform::Sequential({tir::transform::VerifyVTCMLimit(vtcm_capacity)});
const auto& optimize = tir::transform::Sequential({tir::transform::VerifyVTCMLimit(target)});
optimize(mod);
}
const auto& optimize =
Expand Down
11 changes: 1 addition & 10 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -544,22 +544,13 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
return TIRToRuntime(inputs, target_host);
}

int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) {
if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
if (target.defined() && target->kind->name == "hexagon") {
auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
if (value > 0) return value;
}
return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value;
}

transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
transform::PassContext pass_ctx = transform::PassContext::Current();

Array<Pass> mixed_pass_list;

// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(GetVTCMCapacity(target, pass_ctx)));
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());

Expand Down
39 changes: 29 additions & 10 deletions src/tir/analysis/calculate_allocated_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,39 @@ bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
return true;
}

int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) {
if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
if (target.defined() && target->kind->name == "hexagon") {
auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
if (value > 0) return value;
}
return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value;
}

namespace transform {

Pass VerifyVTCMLimit(const Integer& limit) {
Pass VerifyVTCMLimit(Optional<Target> default_target) {
auto pass_func = [=](IRModule mod, PassContext ctx) {
for (auto kv : mod->functions) {
if (auto func = kv.second.as<PrimFunc>()) {
auto sizes = CalculateAllocatedBytes(func.value());
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been "
"exceeded(allocated: "
<< vtcm_allocated << ", limit: " << limit << ").\n"
<< "In function\n"
<< func;
if (auto opt = kv.second.as<PrimFunc>()) {
auto func = opt.value();

std::optional<int64_t> limit = std::nullopt;
if (auto func_target = func->GetAttr<Target>(tvm::attr::kTarget)) {
limit = GetVTCMCapacity(func_target.value(), ctx);
} else if (default_target) {
limit = GetVTCMCapacity(default_target.value(), ctx);
}

if (limit.has_value() && limit.value() > 0) {
auto sizes = CalculateAllocatedBytes(func);
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (vtcm_allocated.IntValue() > limit.value()) {
LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been exceeded "
<< "(allocated: " << vtcm_allocated << ", limit: " << limit.value() << ").\n"
<< "In function\n"
<< func;
}
}
}
}
Expand Down

0 comments on commit b48fcab

Please sign in to comment.