diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 5bac25faa5fb..4ed164e5ad45 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -348,12 +349,13 @@ TVM_DLL Pass VerifyGPUCode(Map constraints); /*! * \brief Pass to checks if the size of the allocated vtcm memory satisfies the limit * - * \param limit The limit to check. + * \param target The target whose VTCM limit should be used for any + * functions not already annotated with `tvm::attr::kTarget`. * * \returns The pass. * \sa tvm::tir::CalculateAllocatedBytes */ -TVM_DLL Pass VerifyVTCMLimit(const Integer& limit); +TVM_DLL Pass VerifyVTCMLimit(Optional target = NullOpt); /*! * \brief Statically check TIR code for out of bounds array access. diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index 884215c24a13..65cc13eb61fc 100644 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1408,9 +1408,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i } if (IsHexagonTask(task)) { Target target = task->target; - const auto vtcm_capacity = target->GetAttr("vtcm-capacity").value().IntValue(); - const auto& optimize = - tir::transform::Sequential({tir::transform::VerifyVTCMLimit(vtcm_capacity)}); + const auto& optimize = tir::transform::Sequential({tir::transform::VerifyVTCMLimit(target)}); optimize(mod); } const auto& optimize = diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 1962b9ab3b2d..486b40c9946a 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -544,22 +544,13 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, return TIRToRuntime(inputs, target_host); } -int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { - if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true); - if (target.defined() && target->kind->name == "hexagon") { - auto value = Downcast(target->attrs.at("vtcm-capacity"))->value; - if (value > 0) return value; - } - return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; -} - transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); Array mixed_pass_list; // VerifyVTCMLimit must occur before LowerVtcmAlloc - mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(GetVTCMCapacity(target, pass_ctx))); + mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target)); // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc()); diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index 95fd7f134ed2..ffdfc1f80162 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -96,20 +96,39 @@ bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) { return true; } +int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { + if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true); + if (target.defined() && target->kind->name == "hexagon") { + auto value = Downcast(target->attrs.at("vtcm-capacity"))->value; + if (value > 0) return value; + } + return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; +} + namespace transform { -Pass VerifyVTCMLimit(const Integer& limit) { +Pass VerifyVTCMLimit(Optional default_target) { auto pass_func = [=](IRModule mod, PassContext ctx) { for (auto kv : mod->functions) { - if (auto func = kv.second.as()) { - auto sizes = CalculateAllocatedBytes(func.value()); - const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0); - if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) { - LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been " - "exceeded(allocated: " - << vtcm_allocated << ", limit: " << limit << ").\n" - << "In function\n" - << func; + if (auto opt = kv.second.as()) { + auto func = opt.value(); + + std::optional limit = std::nullopt; + if (auto func_target = func->GetAttr(tvm::attr::kTarget)) { + limit = GetVTCMCapacity(func_target.value(), ctx); + } else if (default_target) { + limit = GetVTCMCapacity(default_target.value(), ctx); + } + + if (limit.has_value() && limit.value() > 0) { + auto sizes = CalculateAllocatedBytes(func); + const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0); + if (vtcm_allocated.IntValue() > limit.value()) { + LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been exceeded " + << "(allocated: " << vtcm_allocated << ", limit: " << limit.value() << ").\n" + << "In function\n" + << func; + } } } }