From a954742fba28f235e1a14a322c7e1d89372af8e6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 5 May 2023 12:06:46 -0500 Subject: [PATCH] [TIR][Driver] Move ShouldAnnotateEntryFunc logic into transform (#14562) * [TIR][Driver] Move ShouldAnnotateEntryFunc logic into transform Prior to this commit, the `MixedModulePassManager` determined whether the module should have its entry function annotated, while the `AnnotateEntryFunc` validates that this condition was upheld, duplicating the logic applied. This commit moves the logic to `AnnotateEntryFunc`, which is unconditionally called from the `MixedModulePassManager`, but which is a no-op for cases where no annotation is required. * Maintain previous behavior for modules with one function --- src/driver/driver_api.cc | 11 +------ src/tir/transforms/primfunc_utils.cc | 47 +++++++++++++++++++++++++--- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 486b40c9946a..91bc57ccbeb2 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -71,13 +71,6 @@ bool LLVMEnabled() { return pf != nullptr; } -bool ShouldAnnotateEntryFunc(const IRModule mod) { - Optional executor = mod->GetAttr("executor"); - const bool aot_executor = executor.defined() && executor.value()->name == "aot"; - const bool single_entry_func = (mod->functions.size() == 1); - return single_entry_func && !aot_executor; -} - /*! \return The default host target for a given device target */ Target DefaultTargetHost(Target target) { if (target.defined() && target->GetTargetDeviceType() == kDLCPU) { @@ -558,9 +551,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) mixed_pass_list.push_back(tir::transform::VerifyMemory()); - if (ShouldAnnotateEntryFunc(mixed_mod)) { - mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); - } + mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); bool detect_global_barrier = pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 208077b492da..257e3eacda90 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -23,6 +23,7 @@ */ #include +#include #include namespace tvm { @@ -40,11 +41,49 @@ transform::Pass BindTarget(Target target) { } transform::Pass AnnotateEntryFunc() { - auto fpass = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { - ICHECK(m->functions.size() == 1); - return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true)); + auto fpass = [](IRModule mod, transform::PassContext ctx) -> IRModule { + // AOT tracks the entry function, no annotation required + auto executor = mod->GetAttr("executor"); + const bool is_aot_executor = executor.defined() && executor.value()->name == "aot"; + if (is_aot_executor) { + return mod; + } + + // If only a single function exists, that function must be the entry + if (mod->functions.size() == 1) { + auto [gvar, base_func] = *mod->functions.begin(); + if (!base_func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (auto ptr = base_func.as()) { + mod->Update(gvar, WithAttr(GetRef(ptr), tir::attr::kIsEntryFunc, Bool(true))); + } + } + return mod; + } + + // If the module has multiple functions, but only one is exposed + // externally, that function must be the entry. + bool has_external_non_primfuncs = false; + IRModule with_annotations; + for (const auto& [gvar, base_func] : mod->functions) { + bool is_external = base_func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + if (is_external) { + if (auto ptr = base_func.as()) { + with_annotations->Add( + gvar, WithAttr(GetRef(ptr), tir::attr::kIsEntryFunc, Bool(true))); + } else { + has_external_non_primfuncs = true; + } + } + } + if (with_annotations->functions.size() == 1 && !has_external_non_primfuncs) { + mod->Update(with_annotations); + return mod; + } + + // Default fallback, no annotations may be inferred. + return mod; }; - return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.AnnotateEntryFunc", {}); + return tvm::transform::CreateModulePass(fpass, 0, "tir.AnnotateEntryFunc", {}); } transform::Pass Filter(runtime::TypedPackedFunc fcond) {