diff --git a/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp b/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp
index cc4cc2c14..f8382596f 100644
--- a/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp
+++ b/mlir-tensorrt/compiler/lib/Conversion/TensorRTToTensorRTRuntime/TensorRTToTensorRTRuntime.cpp
@@ -57,6 +57,78 @@ struct RewriteConstants : public OpRewritePattern<tensorrt::ConstantOp> {
   }
 };
 
+// Helper function to convert both CallOp and CallAllocOp
+static LogicalResult
+convertCallOp(Operation *op, IRRewriter &rewriter,
+              SymbolTableCollection &symbolTable, DataFlowSolver &solver,
+              ModuleOp module,
+              SmallVectorImpl<tensorrt::TensorRTModuleOp> &trtModules) {
+  MLIRContext *ctx = rewriter.getContext();
+  Location loc = op->getLoc();
+  ValueRange inputs;
+  ValueRange outputs;
+  func::FuncOp trtFunc;
+
+  if (auto callOp = dyn_cast<tensorrt::CallOp>(op)) {
+    trtFunc = callOp.getFuncCallee(symbolTable);
+    inputs = callOp.getInputs();
+    outputs = callOp.getOutputs();
+  } else if (auto callAllocOp = dyn_cast<tensorrt::CallAllocOp>(op)) {
+    trtFunc = callAllocOp.getFuncCallee(symbolTable);
+    inputs = callAllocOp.getInputs();
+    // CallAllocOp doesn't have outputs as operands
+  } else {
+    llvm_unreachable("unexpected type of operation. Only callOp and "
+                     "callAllocOp are supported.");
+    return failure();
+  }
+
+  solver.load<dataflow::DeadCodeAnalysis>();
+  solver.load<dataflow::SparseConstantPropagation>();
+  solver.load<TensorKindAnalysis>(symbolTable);
+  if (failed(solver.initializeAndRun(trtFunc)))
+    return trtFunc.emitError() << "failed to run TensorKindAnalysis";
+
+  // Check which tensors should be host tensors.
+  SmallVector<int64_t> hostTensorArgs;
+  for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
+    const TensorKindLattice *kind = solver.lookupState<TensorKindLattice>(arg);
+    RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
+    // To be conservative, we only do this if type is i32 and num elements
+    // <= 8. 
+    if (kind && !kind->getValue().isUninitialized() &&
+        kind->getValue().isHostVisible() &&
+        rtt.getElementType().isInteger(32) && rtt.getNumElements() <= 8)
+      hostTensorArgs.push_back(idx);
+  }
+
+  rewriter.setInsertionPoint(op);
+
+  Value executionContext = rewriter.create<trtrt::CompileOp>(
+      loc,
+      SymbolRefAttr::get(rewriter.getStringAttr(*trtModules.front().getName()),
+                         {FlatSymbolRefAttr::get(trtFunc)}));
+  Value stream = rewriter.create<cuda::GetGlobalStreamOp>(loc, 0);
+
+  Operation *enqueueOp;
+  if (isa<tensorrt::CallOp>(op)) {
+    enqueueOp = rewriter.create<trtrt::EnqueueOp>(
+        loc, executionContext, stream, inputs, outputs,
+        /*host_tensors_args=*/hostTensorArgs.empty()
+            ? DenseI64ArrayAttr{}
+            : DenseI64ArrayAttr::get(ctx, hostTensorArgs));
+  } else {
+    enqueueOp = rewriter.create<trtrt::EnqueueAllocOp>(
+        loc, op->getResultTypes(), executionContext, stream, inputs,
+        /*host_tensors_args=*/hostTensorArgs.empty()
+            ? DenseI64ArrayAttr{}
+            : DenseI64ArrayAttr::get(ctx, hostTensorArgs));
+  }
+  rewriter.replaceOp(op, enqueueOp->getResults());
+
+  return success();
+}
+
 class ConvertTensorRTToRuntimePass
     : public mlir::impl::ConvertTensorRTToTensorRTRuntimePassBase<
           ConvertTensorRTToRuntimePass> {
@@ -83,53 +155,19 @@ class ConvertTensorRTToRuntimePass
       return signalPassFailure();
     }
 
-    SmallVector<tensorrt::CallOp> callOps;
-    module.walk(
-        [&](tensorrt::CallOp compileOp) { callOps.push_back(compileOp); });
+    SmallVector<Operation *> callOps;
+    module.walk([&](Operation *op) {
+      if (isa<tensorrt::CallOp, tensorrt::CallAllocOp>(op))
+        callOps.push_back(op);
+    });
 
     for (auto callOp : llvm::make_early_inc_range(callOps)) {
-      Location loc = callOp.getLoc();
-      func::FuncOp trtFunc = dyn_cast_or_null<func::FuncOp>(
-          module.lookupSymbol(callOp.getCallee()));
-
       SymbolTableCollection symbolTable;
       DataFlowSolver solver;
-      solver.load<dataflow::DeadCodeAnalysis>();
-      solver.load<dataflow::SparseConstantPropagation>();
-      solver.load<TensorKindAnalysis>(symbolTable);
-      if (failed(solver.initializeAndRun(trtFunc))) {
-        trtFunc.emitError() << "failed to run TensorKindAnalysis";
+      if (failed(convertCallOp(callOp, rewriter, symbolTable, solver, module,
+                               trtModules))) {
         return signalPassFailure();
       }
-
-      // Check which tensors should be host 
tensors.
-      SmallVector<int64_t> hostTensorArgs;
-      for (auto [idx, arg] : llvm::enumerate(trtFunc.getArguments())) {
-        const TensorKindLattice *kind =
-            solver.lookupState<TensorKindLattice>(arg);
-        RankedTensorType rtt = cast<RankedTensorType>(arg.getType());
-        // To be conservative, we only do this if type is i32 and num elements
-        // <= 8.
-        if (kind && !kind->getValue().isUninitialized() &&
-            kind->getValue().isHostVisible() &&
-            rtt.getElementType().isInteger(32) && rtt.getNumElements() <= 8)
-          hostTensorArgs.push_back(idx);
-      }
-
-      rewriter.setInsertionPoint(callOp);
-      Value executionContext = rewriter.create<trtrt::CompileOp>(
-          loc, SymbolRefAttr::get(
-                   rewriter.getStringAttr(trtModules.front().getSymName()),
-                   {FlatSymbolRefAttr::get(trtFunc)}));
-      Value stream = rewriter.create<cuda::GetGlobalStreamOp>(loc, 0);
-      auto enqueueOp = rewriter.create<trtrt::EnqueueOp>(
-          loc, executionContext, stream, callOp.getInputs(),
-          callOp.getOutputs(),
-          /*host_tensors_args=*/hostTensorArgs.empty()
-              ? DenseI64ArrayAttr{}
-              : DenseI64ArrayAttr::get(ctx, hostTensorArgs));
-      rewriter.setInsertionPointAfter(enqueueOp);
-      rewriter.replaceOp(callOp, enqueueOp->getResults());
     }
   }
 };
diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td
index 7be45fb25..dad8043ba 100644
--- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td
+++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td
@@ -129,6 +129,58 @@ def TensorRT_CallOp : TensorRT_Op<"call", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// CallAllocOp
+//===----------------------------------------------------------------------===//
+
+def TensorRT_CallAllocOp : TensorRT_Op<"call_alloc", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+    CallOpInterface
+]> {
+  let summary = "calls a TensorRT engine defined in a `tensorrt.module` and allocates output tensors";
+
+  let description = [{
+    This operation calls a TensorRT engine 
and allocates output tensors. It will be converted to an
+    `enqueue_alloc` operation in a later pass.
+  }];
+
+  let arguments = (ins
+    Variadic<TensorRT_RankedTensorOf<[AnyType]>>:$inputs,
+    SymbolRefAttr:$callee
+  );
+
+  let results = (outs Variadic<TensorRT_RankedTensorOf<[AnyType]>>:$results);
+
+  let assemblyFormat = [{
+    $callee `(` ($inputs^ `:` type($inputs))? `)`
+    attr-dict (`->` type($results)^)?
+  }];
+
+  let extraClassDeclaration = [{
+    /// Return the function representing the TRT engine that is being called.
+    func::FuncOp getFuncCallee(SymbolTableCollection &symbolTable);
+
+    //===------------------------------------------------------------------===//
+    // CallOpInterface
+    //===------------------------------------------------------------------===//
+    operand_range getArgOperands() {
+      return getInputs();
+    }
+
+    CallInterfaceCallable getCallableForCallee() {
+      return (*this)->getAttrOfType<SymbolRefAttr>("callee");
+    }
+
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
+
+    MutableOperandRange getArgOperandsMutable() {
+      return getInputsMutable();
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ActivationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp
index de35fbe88..8fcb876bf 100644
--- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp
+++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRT.cpp
@@ -192,6 +192,35 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// CallAllocOp
+//===----------------------------------------------------------------------===//
+
+func::FuncOp CallAllocOp::getFuncCallee(SymbolTableCollection &symbolTable) {
+  Operation *module = (*this)->getParentWithTrait<OpTrait::SymbolTable>();
+  assert(module && "expected call to be nested 
within symbol table");
+  return dyn_cast_or_null<func::FuncOp>(
+      symbolTable.lookupNearestSymbolFrom(module, getCallee()));
+}
+
+LogicalResult
+CallAllocOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  func::FuncOp kernel = getFuncCallee(symbolTable);
+  if (!kernel)
+    return emitOpError() << "no valid kernel found with symbol name "
+                         << getCallee();
+  FunctionType funcType = kernel.getFunctionType();
+
+  if (funcType.getNumInputs() != getInputs().size() ||
+      funcType.getNumResults() != getResultTypes().size() ||
+      !areTensorTypesCompatible(TypeRange(getInputs()), funcType.getInputs()))
+    return emitOpError()
+           << "callee has function type " << funcType
+           << " which is not compatible with input/result types of call";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ElementwiseOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir-tensorrt/test/Conversion/TensorRTToTensorRTRuntime/tensorrt-to-tensorrt-runtime.mlir b/mlir-tensorrt/test/Conversion/TensorRTToTensorRTRuntime/tensorrt-to-tensorrt-runtime.mlir
index 28058a1af..e24f5ac85 100644
--- a/mlir-tensorrt/test/Conversion/TensorRTToTensorRTRuntime/tensorrt-to-tensorrt-runtime.mlir
+++ b/mlir-tensorrt/test/Conversion/TensorRTToTensorRTRuntime/tensorrt-to-tensorrt-runtime.mlir
@@ -26,6 +26,30 @@ func.func @main(%arg0: tensor<1x3x256x256xf32>, %arg1: tensor<1x3x256x256xf32>)
 
 // -----
 
+tensorrt.module @trt_engines {
+  func.func @trt_func(%arg0: tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32> attributes {
+    "tensorrt.engine" = dense<0> : vector<8xi8>
+  } {
+    %cst_f32 = tensorrt.constant dense<0.00392156886> : tensor<1xf32>
+    %0 = tensorrt.shuffle {first_transpose = array<i64: 0>, reshape = array<i64: 1, 1, 1, 1>, second_transpose = array<i64: 0, 1, 2, 3>, zero_is_placeholder = false} ins(%cst_f32 : tensor<1xf32>) -> tensor<1x1x1x1xf32>
+    %1 = tensorrt.element_wise <kPROD>(%arg0, %0 : tensor<1x3x256x256xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x256x256xf32>
+    
return %1 : tensor<1x3x256x256xf32> + } +} +func.func @main(%arg0: tensor<1x3x256x256xf32>, %arg1: tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32> { + %1 = tensorrt.call_alloc @trt_engines::@trt_func(%arg0 : tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32> + return %1 : tensor<1x3x256x256xf32> +} + +// CHECK-LABEL: @main +// CHECK-SAME: (%[[arg0:.+]]: tensor<1x3x256x256xf32>, %[[arg1:.+]]: tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32> { +// CHECK: %[[v1:.+]] = trtrt.compile @trt_engines::@trt_func : !trtrt.context +// CHECK: %[[v2:.+]] = cuda.get_global_stream 0 +// CHECK: %[[v3:.+]] = trtrt.enqueue_alloc %[[v1]] stream(%[[v2]]) (%[[arg0]]) : (tensor<1x3x256x256xf32>) -> tensor<1x3x256x256xf32> +// CHECK: return %[[v3]] : tensor<1x3x256x256xf32> + +// ----- + // CHECK-LABEL: @convert_tensorrt_const // CHECK-NEXT: arith.constant // CHECK-NEXT: return