[Dialect/TensorRT] Update tensorrt dialect to use non-DPS calling convention (#287)

Add a non-DPS variant of `tensorrt::CallOp`, i.e. `tensorrt::CallAllocOp`.
Add a conversion from `tensorrt.call_alloc` to the runtime dialect's `enqueue_alloc` op.
jhalakpatel authored Oct 26, 2024
1 parent 790bf78 commit 20bcb7c
Showing 4 changed files with 184 additions and 41 deletions.
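At a glance, the new non-DPS calling convention and its lowering look like the sketch below. The IR is adapted from the lit test added in this commit; the SSA value names are illustrative only, not what the conversion pass emits verbatim.

// Non-DPS call: results are allocated by the callee rather than passed in as
// destination operands.
%0 = tensorrt.call_alloc @trt_engines::@trt_func(%arg0 : tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32>

// After the TensorRT-to-runtime conversion pass (sketch):
%ctx = trtrt.compile @trt_engines::@trt_func : !trtrt.context
%stream = cuda.get_global_stream 0
%1 = trtrt.enqueue_alloc %ctx stream(%stream) (%arg0) : (tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32>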
@@ -57,6 +57,78 @@ struct RewriteConstants : public OpRewritePattern<tensorrt::ConstantOp> {
}
};

// Helper function that converts both CallOp and CallAllocOp to the
// corresponding runtime-dialect enqueue ops.
static LogicalResult
convertCallOp(Operation *op, IRRewriter &rewriter,
SymbolTableCollection &symbolTable, DataFlowSolver &solver,
ModuleOp module,
SmallVectorImpl<tensorrt::TensorRTModuleOp> &trtModules) {
MLIRContext *ctx = rewriter.getContext();
Location loc = op->getLoc();
ValueRange inputs;
ValueRange outputs;
func::FuncOp trtFunc;

if (auto callOp = dyn_cast<tensorrt::CallOp>(op)) {
trtFunc = callOp.getFuncCallee(symbolTable);
inputs = callOp.getInputs();
outputs = callOp.getOutputs();
} else if (auto callAllocOp = dyn_cast<tensorrt::CallAllocOp>(op)) {
trtFunc = callAllocOp.getFuncCallee(symbolTable);
inputs = callAllocOp.getInputs();
// CallAllocOp doesn't have outputs as operands
} else {
llvm_unreachable("unexpected type of operation. Only callOp and "
"callAllocOp are supported.");
return failure();
}

solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::SparseConstantPropagation>();
solver.load<TensorKindAnalysis>(symbolTable);
if (failed(solver.initializeAndRun(trtFunc)))
return trtFunc.emitError() << "failed to run TensorKindAnalysis";

// Check which tensors should be host tensors.
SmallVector<int64_t> hostTensorArgs;
for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
const TensorKindLattice *kind = solver.lookupState<TensorKindLattice>(arg);
RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
// To be conservative, we only do this if type is i32 and num elements
// <= 8.
if (kind && !kind->getValue().isUninitialized() &&
kind->getValue().isHostVisible() &&
rtt.getElementType().isInteger(32) && rtt.getNumElements() <= 8)
hostTensorArgs.push_back(idx);
}

rewriter.setInsertionPoint(op);

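// Materialize an execution context for the callee engine and fetch the global
// CUDA stream.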
Value executionContext = rewriter.create<trtrt::CompileOp>(
loc,
SymbolRefAttr::get(rewriter.getStringAttr(*trtModules.front().getName()),
{FlatSymbolRefAttr::get(trtFunc)}));
Value stream = rewriter.create<cuda::GetGlobalStreamOp>(loc, 0);

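// DPS calls lower to `trtrt.enqueue`, which writes into the pre-allocated
// outputs; non-DPS calls lower to `trtrt.enqueue_alloc`, which allocates its
// own results.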
Operation *enqueueOp;
if (isa<tensorrt::CallOp>(op)) {
enqueueOp = rewriter.create<trtrt::EnqueueOp>(
loc, executionContext, stream, inputs, outputs,
/*host_tensors_args=*/hostTensorArgs.empty()
? DenseI64ArrayAttr{}
: DenseI64ArrayAttr::get(ctx, hostTensorArgs));
} else {
enqueueOp = rewriter.create<trtrt::EnqueueAllocOp>(
loc, op->getResultTypes(), executionContext, stream, inputs,
/*host_tensors_args=*/hostTensorArgs.empty()
? DenseI64ArrayAttr{}
: DenseI64ArrayAttr::get(ctx, hostTensorArgs));
}
rewriter.replaceOp(op, enqueueOp->getResults());

return success();
}

class ConvertTensorRTToRuntimePass
: public mlir::impl::ConvertTensorRTToTensorRTRuntimePassBase<
ConvertTensorRTToRuntimePass> {
@@ -83,53 +155,19 @@ class ConvertTensorRTToRuntimePass
return signalPassFailure();
}

SmallVector<tensorrt::CallOp> callOps;
module.walk(
[&](tensorrt::CallOp compileOp) { callOps.push_back(compileOp); });
SmallVector<Operation *> callOps;
module.walk([&](Operation *op) {
if (isa<tensorrt::CallOp, tensorrt::CallAllocOp>(op))
callOps.push_back(op);
});

for (auto callOp : llvm::make_early_inc_range(callOps)) {
Location loc = callOp.getLoc();
func::FuncOp trtFunc = dyn_cast_or_null<func::FuncOp>(
module.lookupSymbol(callOp.getCallee()));

SymbolTableCollection symbolTable;
DataFlowSolver solver;
solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::SparseConstantPropagation>();
solver.load<TensorKindAnalysis>(symbolTable);
if (failed(solver.initializeAndRun(trtFunc))) {
trtFunc.emitError() << "failed to run TensorKindAnalysis";
if (failed(convertCallOp(callOp, rewriter, symbolTable, solver, module,
trtModules))) {
return signalPassFailure();
}

// Check which tensors should be host tensors.
SmallVector<int64_t> hostTensorArgs;
for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
const TensorKindLattice *kind =
solver.lookupState<TensorKindLattice>(arg);
RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
// To be conservative, we only do this if type is i32 and num elements
// <= 8.
if (kind && !kind->getValue().isUninitialized() &&
kind->getValue().isHostVisible() &&
rtt.getElementType().isInteger(32) && rtt.getNumElements() <= 8)
hostTensorArgs.push_back(idx);
}

rewriter.setInsertionPoint(callOp);
Value executionContext = rewriter.create<trtrt::CompileOp>(
loc, SymbolRefAttr::get(
rewriter.getStringAttr(trtModules.front().getSymName()),
{FlatSymbolRefAttr::get(trtFunc)}));
Value stream = rewriter.create<cuda::GetGlobalStreamOp>(loc, 0);
auto enqueueOp = rewriter.create<trtrt::EnqueueOp>(
loc, executionContext, stream, callOp.getInputs(),
callOp.getOutputs(),
/*host_tensors_args=*/hostTensorArgs.empty()
? DenseI64ArrayAttr{}
: DenseI64ArrayAttr::get(ctx, hostTensorArgs));
rewriter.setInsertionPointAfter(enqueueOp);
rewriter.replaceOp(callOp, enqueueOp->getResults());
}
}
};
@@ -129,6 +129,58 @@ def TensorRT_CallOp : TensorRT_Op<"call", [
}];
}

//===----------------------------------------------------------------------===//
// CallAllocOp
//===----------------------------------------------------------------------===//

def TensorRT_CallAllocOp : TensorRT_Op<"call_alloc", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
CallOpInterface
]> {
let summary = "calls a TensorRT engine defined in a `tensorrt.module` and allocates output tensors";

let description = [{
This operation calls a TensorRT engine and allocates its output tensors, i.e. it
uses a non-DPS calling convention in which results are returned by the op rather
than written into destination operands. It is converted to a `trtrt.enqueue_alloc`
operation by a later pass.
}];

let arguments = (ins
Variadic<AnyTypeOf<[AnyShaped, AnySignlessIntegerOrIndex]>>:$inputs,
SymbolRefAttr:$callee
);

let results = (outs Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$results);

let assemblyFormat = [{
$callee `(` ($inputs^ `:` type($inputs))? `)`
attr-dict (`->` type($results)^)?
}];

let extraClassDeclaration = [{
/// Return the function representing the TRT engine that is being called.
func::FuncOp getFuncCallee(SymbolTableCollection &symbolTable);

//===------------------------------------------------------------------===//
// CallOpInterface
//===------------------------------------------------------------------===//
operand_range getArgOperands() {
return getInputs();
}

CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

void setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}

MutableOperandRange getArgOperandsMutable() {
return getInputsMutable();
}
}];
}
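For reference, the `assemblyFormat` declared above corresponds to textual IR of the following shape; the symbols and values here are hypothetical:

%0 = tensorrt.call_alloc @trt_engines::@trt_func(%arg0 : tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32>

// With no inputs, the optional `$inputs` group is elided but the literal
// parentheses remain:
tensorrt.call_alloc @trt_engines::@producer_func() -> tensor<1xf32>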

//===----------------------------------------------------------------------===//
// ActivationOp
//===----------------------------------------------------------------------===//
29 changes: 29 additions & 0 deletions mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp
@@ -192,6 +192,35 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}

//===----------------------------------------------------------------------===//
// CallAllocOp
//===----------------------------------------------------------------------===//

func::FuncOp CallAllocOp::getFuncCallee(SymbolTableCollection &symbolTable) {
Operation *module = (*this)->getParentWithTrait<OpTrait::SymbolTable>();
assert(module && "expected call to be nested within symbol table");
return dyn_cast_or_null<func::FuncOp>(
symbolTable.lookupNearestSymbolFrom(module, getCallee()));
}

LogicalResult
CallAllocOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
func::FuncOp kernel = getFuncCallee(symbolTable);
if (!kernel)
return emitOpError() << "no valid kernel found with symbol name "
<< getCallee();
FunctionType funcType = kernel.getFunctionType();

if (funcType.getNumInputs() != getInputs().size() ||
funcType.getNumResults() != getResultTypes().size() ||
!areTensorTypesCompatible(TypeRange(getInputs()), funcType.getInputs()))
return emitOpError()
<< "callee has function type " << funcType
<< " which is not compatible with input/result types of call";

return success();
}
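As a hypothetical illustration of this verifier, a call whose operand or result types do not match the callee's function type is rejected with the diagnostic built above (exact type printing may differ):

// Suppose callee @trt_engines::@trt_func has function type
// (tensor<4xf32>) -> tensor<4xf32>. Calling it with a tensor<8xf32> would
// produce roughly:
//   error: 'tensorrt.call_alloc' op callee has function type
//   (tensor<4xf32>) -> tensor<4xf32> which is not compatible with
//   input/result types of call
%0 = tensorrt.call_alloc @trt_engines::@trt_func(%arg0 : tensor<8xf32>) -> tensor<8xf32>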

//===----------------------------------------------------------------------===//
// ElementwiseOp
//===----------------------------------------------------------------------===//
@@ -26,6 +26,30 @@ func.func @main(%arg0: tensor<1x3x256x256xf32>, %arg1: tensor<1x3x256x256xf32>)

// -----

tensorrt.module @trt_engines {
func.func @trt_func(%arg0: tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32> attributes {
"tensorrt.engine" = dense<0> : vector<8xi8>
} {
%cst_f32 = tensorrt.constant dense<0.00392156886> : tensor<1xf32>
%0 = tensorrt.shuffle {first_transpose = array<i64: 0>, reshape = array<i64: 1, 1, 1, 1>, second_transpose = array<i64: 0, 1, 2, 3>, zero_is_placeholder = false} ins(%cst_f32 : tensor<1xf32>) -> tensor<1x1x1x1xf32>
%1 = tensorrt.element_wise <kPROD>(%arg0, %0 : tensor<1x3x256x256xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x256x256xf32>
return %1 : tensor<1x3x256x256xf32>
}
}
func.func @main(%arg0: tensor<1x3x256x256xf32>, %arg1: tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32> {
%1 = tensorrt.call_alloc @trt_engines::@trt_func(%arg0 : tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32>
return %1 : tensor<1x3x256x256xf32>
}

// CHECK-LABEL: @main
// CHECK-SAME: (%[[arg0:.+]]: tensor<1x3x256x256xf32>, %[[arg1:.+]]: tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32> {
// CHECK: %[[v1:.+]] = trtrt.compile @trt_engines::@trt_func : !trtrt.context
// CHECK: %[[v2:.+]] = cuda.get_global_stream 0
// CHECK: %[[v3:.+]] = trtrt.enqueue_alloc %[[v1]] stream(%[[v2]]) (%[[arg0]]) : (tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32>
// CHECK: return %[[v3]] : tensor<1x3x256x256xf32>

// -----

// CHECK-LABEL: @convert_tensorrt_const
// CHECK-NEXT: arith.constant
// CHECK-NEXT: return