Skip to content

Commit

Permalink
Add an option to dump textual pipeline during compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalakpatel committed Oct 24, 2024
1 parent 8923f16 commit d17d489
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 74 deletions.
13 changes: 11 additions & 2 deletions mlir-tensorrt/compiler/include/mlir-tensorrt-c/Compiler/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define MLIR_TENSORRT_C_COMPILER_COMPILER

#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
#include "mlir-executor-c/Common/Common.h"
#include "mlir-executor-c/Support/Status.h"
Expand All @@ -47,8 +48,8 @@ mtrtCompilerClientCreate(MlirContext context, MTRT_CompilerClient *client);
MLIR_CAPI_EXPORTED MTRT_Status
mtrtCompilerClientDestroy(MTRT_CompilerClient client);

static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient options) {
return !options.ptr;
static inline bool mtrtCompilerClientIsNull(MTRT_CompilerClient client) {
return !client.ptr;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -108,6 +109,14 @@ MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerStableHLOToExecutable(
MTRT_CompilerClient client, MlirOperation module,
MTRT_StableHLOToExecutableOptions options, MTRT_Executable *result);

MLIR_CAPI_EXPORTED MTRT_Status mtrtCompilerPopulatePassManager(
MTRT_CompilerClient compilerClient,
MTRT_StableHLOToExecutableOptions stableHloToExecutableOptions,
MlirPassManager *passManager);

MLIR_CAPI_EXPORTED MTRT_Status mtrtTranslateRuntimeToExecutable(
MlirOperation moduleOp, MTRT_Executable *result);

//===----------------------------------------------------------------------===//
// MTRT_StableHLOProgramSignatureRefinementOptions
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,6 @@ class StableHloToExecutableTask
static void populatePassManager(mlir::PassManager &pm,
const StableHLOToExecutableOptions &options);

/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
static mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
compileStableHLOToExecutable(mlir::ModuleOp module,
const StableHLOToExecutableOptions &options);

/// Compile a StableHLO module into a MLIR-TensorRT Runtime executable.
/// This is the "functional" entrypoint that will allocate a new PassManager
/// for a single run.
Expand Down
72 changes: 71 additions & 1 deletion mlir-tensorrt/compiler/lib/CAPI/Compiler/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,19 @@
//===----------------------------------------------------------------------===//
#include "mlir-tensorrt-c/Compiler/Compiler.h"
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
#include "mlir-executor-c/Support/Status.h"
#include "mlir-tensorrt-dialect/Target/TranslateToTensorRT.h"
#include "mlir-tensorrt/Compiler/Extension.h"
#include "mlir-tensorrt/Compiler/StableHloToExecutable.h"
#include "mlir-tensorrt/Compiler/TensorRTExtension/TensorRTExtension.h"
#include "mlir-executor/Target/Lua/TranslateToRuntimeExecutable.h"
#include "mlir-tensorrt/Dialect/Plan/IR/Plan.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"

using namespace mlirtrt;
using namespace mlirtrt::compiler;
Expand All @@ -46,10 +50,14 @@ DEFINE_C_API_PTR_METHODS(MTRT_StableHLOToExecutableOptions,
StableHLOToExecutableOptions)
DEFINE_C_API_PTR_METHODS(MTRT_StableHLOProgramSignatureRefinementOptions,
StableHLOProgramSignatureRefinementOptions)

#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif

#define DEBUG_TYPE "compiler-api"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]")

/// Return the MTRT_StatusCode. These are auto-generated from the same schema as
/// the `mlirtrt::StatusCode`.
static MTRT_StatusCode
Expand Down Expand Up @@ -97,7 +105,7 @@ MTRT_Status mtrtCompilerClientCreate(MlirContext context,
}

MTRT_Status mtrtCompilerClientDestroy(MTRT_CompilerClient client) {
delete reinterpret_cast<MTRT_CompilerClient *>(client.ptr);
delete reinterpret_cast<CompilerClient *>(client.ptr);
return mtrtStatusGetOk();
}

Expand Down Expand Up @@ -256,6 +264,68 @@ MTRT_Status mtrtCompilerStableHLOToExecutable(
return mtrtStatusGetOk();
}

MTRT_Status mtrtCompilerPopulatePassManager(
MTRT_CompilerClient compilerClient,
MTRT_StableHLOToExecutableOptions stableHloToExecutableOptions,
MlirPassManager *pm) {

PassManager *passManager = llvm::dyn_cast<PassManager>(unwrap(*pm));

std::unique_ptr<StableHloToExecutableTask> runner{};

CompilerClient &client = *unwrap(compilerClient);
const StableHLOToExecutableOptions &options =
*unwrap(stableHloToExecutableOptions);

LLVM_DEBUG({
DBGS() << "compiling with options:\n";
options.print(llvm::dbgs());
llvm::dbgs() << "\n";
});

#ifndef NDEBUG
if (options.debugOptions.enableLLVMDebugFlag) {
SmallVector<const char *> debugTypeLiterals =
llvm::map_to_vector(options.debugOptions.llvmDebugTypes,
[](const std::string &x) { return x.c_str(); });
llvm::setCurrentDebugTypes(debugTypeLiterals.data(),
debugTypeLiterals.size());
llvm::DebugFlag = true;
}
#endif

if (options.getHash())
passManager =
&client.getOrCreatePassManager<StableHloToExecutableTask>(options);
else {
runner.reset(new StableHloToExecutableTask(client.getContext(), options));
CompilerClient::setupPassManagerLogging(*passManager, options.debugOptions);
passManager = runner.get();
}

return mtrtStatusGetOk();
}

MTRT_Status mtrtTranslateRuntimeToExecutable(MlirOperation moduleOp,
MTRT_Executable *result) {
ModuleOp module = llvm::dyn_cast<ModuleOp>(unwrap(moduleOp));

// Translate to Runtime Executable
FailureOr<std::unique_ptr<runtime::ExecutableStorage>> exeStorage =
mlir::translateToRuntimeExecutable(module);

if (failed(exeStorage))
return mtrtStatusCreate(MTRT_StatusCode::MTRT_StatusCode_InternalError,
"failed to translate compiled MLIR module to a "
"MLIR-TensorRT runtime Executable");

auto exe = std::make_unique<runtime::Executable>(std::move(*exeStorage));

result->ptr = exe.release();

return mtrtStatusGetOk();
}

//===----------------------------------------------------------------------===//
// Main StableHLO Program Signature Refinement Functions
//===----------------------------------------------------------------------===//
Expand Down
60 changes: 0 additions & 60 deletions mlir-tensorrt/compiler/lib/Compiler/StableHloToExecutable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,66 +388,6 @@ void StableHloToExecutableTask::populatePassManager(
mlir::executor::buildExecutorLoweringPipeline(pm, stdToExecOpts);
}

StatusOr<std::unique_ptr<runtime::Executable>>
StableHloToExecutableTask::compileStableHLOToExecutable(
mlir::ModuleOp module, const StableHLOToExecutableOptions &options) {
LLVM_DEBUG({
DBGS() << "compiling with options:\n";
options.print(llvm::dbgs());
llvm::dbgs() << "\n";
});

#ifndef NDEBUG
//===----------------------------------------------------------------------===//
// Set debug options.
//===----------------------------------------------------------------------===//
if (options.debugOptions.enableLLVMDebugFlag) {
SmallVector<const char *> debugTypeLiterals =
llvm::map_to_vector(options.debugOptions.llvmDebugTypes,
[](const std::string &x) { return x.c_str(); });
llvm::setCurrentDebugTypes(debugTypeLiterals.data(),
debugTypeLiterals.size());
llvm::DebugFlag = true;
}
#endif

//===----------------------------------------------------------------------===//
// Setup pass manager
//===----------------------------------------------------------------------===//

StableHloToExecutableTask runner(module->getContext(), options);
if (failed(setupPassManager(runner, options.debugOptions))) {
/// TODO: Ignored. This can fail if pass manager static CL options were not
/// registered/initialized. This happens through invocation of e.g. this
/// function in e.g. Python bindings or standalone calls to C++ or C API
/// without doing all the typical static CL setup. We should instead be
/// accepting a PassManager here that has already been setup to the caller's
/// specifications.
}
if (failed(runner.run(module)))
return getInternalErrorStatus(
"failed to run compilation on module with symbol name: {0}",
module.getName() ? *module.getName() : "no-symbol-name");

//===----------------------------------------------------------------------===//
// Translate to Runtime Executable
//===----------------------------------------------------------------------===//

FailureOr<std::unique_ptr<runtime::ExecutableStorage>> exeStorage =
mlir::translateToRuntimeExecutable(module);
if (failed(exeStorage))
return getStatusWithMsg(StatusCode::InternalError,
"failed to translate compiled MLIR module to a "
"MLIR-TensorRT runtime Executable");

#ifndef NDEBUG
// Turn debugging back off if we turned it on.
if (options.debugOptions.enableLLVMDebugFlag)
llvm::DebugFlag = false;
#endif

return std::make_unique<runtime::Executable>(std::move(*exeStorage));
}

mlirtrt::StatusOr<std::unique_ptr<runtime::Executable>>
StableHloToExecutableTask::compileStableHLOToExecutable(
Expand Down
42 changes: 42 additions & 0 deletions mlir-tensorrt/python/bindings/Compiler/CompilerPyBind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "../Utils.h"
#include "NvInferRuntime.h"
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
#include "mlir-executor-c/Common/Common.h"
#include "mlir-executor-c/Support/Status.h"
Expand All @@ -37,6 +38,26 @@ MTRT_DEFINE_COMPILER_INLINE_PY_CAPSULE_CASTER_FUNCS(

namespace {

// Define a type caster for MlirPassManager
namespace pybind11 { namespace detail {
template <> struct type_caster<MlirPassManager> {
PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager"));

// Conversion from Python to C++
bool load(py::handle src, bool) {
py::object capsule = mlirApiObjectToCapsule(src);
value = mlirPythonCapsuleToPassManager(capsule.ptr());
return !mlirPassManagerIsNull(value);
}

// Conversion from C++ to Python
static py::handle cast(MlirPassManager pm, py::return_value_policy, py::handle) {
if (mlirPassManagerIsNull(pm)) return py::none();
return py::reinterpret_steal<py::object>(mlirPythonPassManagerToCapsule(pm));
}
};
}}

//===----------------------------------------------------------------------===//
// Python Wrapper Classes
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -325,6 +346,27 @@ PYBIND11_MODULE(_api, m) {
},
py::arg("client"), py::arg("module"), py::arg("options"));

m.def(
"compiler_populate_pass_manager",
[](PyCompilerClient &client, PyStableHLOToExecutableOptions &options) {
MlirPassManager pm{nullptr};
MTRT_Status status =
mtrtCompilerPopulatePassManager(client, options, &pm);
THROW_IF_MTRT_ERROR(status);
return py::reinterpret_steal<py::object>(mlirPythonPassManagerToCapsule(pm));
},
py::arg("client"), py::arg("options"));

m.def(
"compiler_translate_to_executable",
[](MlirOperation module) {
MTRT_Executable exe{nullptr};
MTRT_Status status = mtrtTranslateRuntimeToExecutable(module, &exe);
THROW_IF_MTRT_ERROR(status);
return new PyExecutable(exe);
},
py::arg("module"));

m.def(
"get_stablehlo_program_refined_signature",
[](PyCompilerClient &client, MlirOperation module, std::string funcName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def flush():
sys.stderr.flush()


def compile_asm(ASM):
def compile_asm(ASM, use_pass_manager_api=False):
with Context() as context:
m = Module.parse(ASM)
client = api.CompilerClient(context)
Expand All @@ -93,7 +93,18 @@ def compile_asm(ASM):

print("running compilation (1)")
flush()
exe = api.compiler_stablehlo_to_executable(client, m.operation.clone(), opts)
if use_pass_manager_api:
pm = api.compiler_populate_pass_manager(client, opts)
import pdb

pdb.set_trace()
compiled_module = pm.run(m.operation.clone())
exe = api.compiler_translate_to_executable(compiled_module)
else:
exe = api.compiler_stablehlo_to_executable(
client, m.operation.clone(), opts
)

# Options don't change, so the cached pipeline should be re-used.
print("running compilation (2)")
flush()
Expand Down Expand Up @@ -126,7 +137,7 @@ def compile_asm(ASM):


print("Compiling static asm")
compile_asm(STATIC_ASM)
compile_asm(STATIC_ASM, use_pass_manager_api=True)
# CHECK-LABEL: Compiling static asm
# CHECK-LABEL: running compilation (1)
# CHECK: [translate-to-tensorrt] TranslateToTensorRTEnginePass is generating a new TensorRT builder
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# RUN: %PYTHON %s 2>&1
# RUN: %PYTHON %s 2>&1 | FileCheck %s
# REQUIRES: host-has-at-least-1-gpus
import os
import tempfile
Expand Down Expand Up @@ -50,3 +50,10 @@ def compile_asm(ASM):


compile_asm(ASM)

# CHECK: builtin.module
# CHECK: [translate-to-tensorrt] TranslateToTensorRTEnginePass is generating a new TensorRT builder
# CHECK: [translate-to-tensorrt] timing cache path was not specified, creating a fresh timing cache
# CHECK: [translate-to-tensorrt] deserializing TensorRT builder timing cache (0 bytes)
# CHECK: [translate-to-tensorrt] Setting builder optimization level to 3
# CHECK: [translate-to-tensorrt] replacing cache with updated data

0 comments on commit d17d489

Please sign in to comment.