Skip to content

Commit

Permalink
[TIR][Driver] Move ShouldAnnotateEntryFunc logic into transform (apac…
Browse files Browse the repository at this point in the history
…he#14562)

* [TIR][Driver] Move ShouldAnnotateEntryFunc logic into transform

Prior to this commit, the `MixedModulePassManager` determined whether
the module should have its entry function annotated, while the
`AnnotateEntryFunc` validates that this condition was upheld,
duplicating the logic applied.  This commit moves the logic to
`AnnotateEntryFunc`, which is unconditionally called from the
`MixedModulePassManager`, but which is a no-op for cases where no
annotation is required.

* Maintain previous behavior for modules with one function
  • Loading branch information
Lunderberg authored May 5, 2023
1 parent ddd2e81 commit a954742
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 14 deletions.
11 changes: 1 addition & 10 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,6 @@ bool LLVMEnabled() {
return pf != nullptr;
}

bool ShouldAnnotateEntryFunc(const IRModule mod) {
Optional<tvm::relay::Executor> executor = mod->GetAttr<tvm::relay::Executor>("executor");
const bool aot_executor = executor.defined() && executor.value()->name == "aot";
const bool single_entry_func = (mod->functions.size() == 1);
return single_entry_func && !aot_executor;
}

/*! \return The default host target for a given device target */
Target DefaultTargetHost(Target target) {
if (target.defined() && target->GetTargetDeviceType() == kDLCPU) {
Expand Down Expand Up @@ -558,9 +551,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)

mixed_pass_list.push_back(tir::transform::VerifyMemory());

if (ShouldAnnotateEntryFunc(mixed_mod)) {
mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
}
mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());

bool detect_global_barrier =
pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value();
Expand Down
47 changes: 43 additions & 4 deletions src/tir/transforms/primfunc_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/

#include <tvm/driver/driver_api.h>
#include <tvm/relay/executor.h>
#include <tvm/tir/transform.h>

namespace tvm {
Expand All @@ -40,11 +41,49 @@ transform::Pass BindTarget(Target target) {
}

transform::Pass AnnotateEntryFunc() {
auto fpass = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
ICHECK(m->functions.size() == 1);
return WithAttr(std::move(f), tir::attr::kIsEntryFunc, Bool(true));
auto fpass = [](IRModule mod, transform::PassContext ctx) -> IRModule {
// AOT tracks the entry function, no annotation required
auto executor = mod->GetAttr<tvm::relay::Executor>("executor");
const bool is_aot_executor = executor.defined() && executor.value()->name == "aot";
if (is_aot_executor) {
return mod;
}

// If only a single function exists, that function must be the entry
if (mod->functions.size() == 1) {
auto [gvar, base_func] = *mod->functions.begin();
if (!base_func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
if (auto ptr = base_func.as<PrimFuncNode>()) {
mod->Update(gvar, WithAttr(GetRef<PrimFunc>(ptr), tir::attr::kIsEntryFunc, Bool(true)));
}
}
return mod;
}

// If the module has multiple functions, but only one is exposed
// externally, that function must be the entry.
bool has_external_non_primfuncs = false;
IRModule with_annotations;
for (const auto& [gvar, base_func] : mod->functions) {
bool is_external = base_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined();
if (is_external) {
if (auto ptr = base_func.as<PrimFuncNode>()) {
with_annotations->Add(
gvar, WithAttr(GetRef<PrimFunc>(ptr), tir::attr::kIsEntryFunc, Bool(true)));
} else {
has_external_non_primfuncs = true;
}
}
}
if (with_annotations->functions.size() == 1 && !has_external_non_primfuncs) {
mod->Update(with_annotations);
return mod;
}

// Default fallback, no annotations may be inferred.
return mod;
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.AnnotateEntryFunc", {});
return tvm::transform::CreateModulePass(fpass, 0, "tir.AnnotateEntryFunc", {});
}

transform::Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond) {
Expand Down

0 comments on commit a954742

Please sign in to comment.