[TOSA] Add transposed conv support #4360
```diff
@@ -2304,9 +2304,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
     return rewriter.notifyMatchFailure(
         op, "Unimplemented: non-constant value for transposed not supported");
-  if (transposed)
-    return rewriter.notifyMatchFailure(
-        op, "Unimplemented: transposed convolution not supported");
 
   auto input = adaptor.getInput();
   auto weight = adaptor.getWeight();
```
```diff
@@ -2338,12 +2335,17 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
   auto bias = adaptor.getBias();
 
   if (isa<Torch::NoneType>(bias.getType())) {
-    auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
-                                                    outputElemTy, weightShape);
-    if (failed(bias_result))
+    SmallVector<int64_t, 4> biasWeightShape =
+        transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
+                                             weightShape[2], weightShape[3]}
+                   : weightShape;
+
+    auto biasResult = tosa::getConvBiasForNoneType(
+        op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
+    if (failed(biasResult))
       return rewriter.notifyMatchFailure(
           op, "Failed to create bias tensor for none type.");
-    bias = bias_result.value();
+    bias = biasResult.value();
   } else {
     if (!isa<RankedTensorType>(bias.getType()))
       return rewriter.notifyMatchFailure(
```
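For context on the `biasWeightShape` swap above: PyTorch stores regular conv weights as OIHW but transposed-conv weights as IOHW, so the output-channel count that sizes the bias lives in dim 1 rather than dim 0. A minimal standalone sketch of that rule (plain C++, not the MLIR helper; `convBiasLength` is a hypothetical name for illustration):

```cpp
#include <array>
#include <cstdint>

// Returns the bias length (number of output channels) for a 4-D conv weight.
// Regular conv: OIHW -> O is dim 0. Transposed conv: IOHW -> O is dim 1.
int64_t convBiasLength(const std::array<int64_t, 4> &weightShape,
                       bool transposed) {
  return transposed ? weightShape[1] : weightShape[0];
}

// e.g. convBiasLength({64, 64, 3, 3}, /*transposed=*/true) == 64
```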
```diff
@@ -2370,8 +2372,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
                        m_TorchListOfConstantInts(padding_2d)))
     return rewriter.notifyMatchFailure(op,
                                        "non-const padding list unsupported");
-  // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
-  // padding {height, width}. The Torch OFM computation uses 2*pad in each
+  // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
+  // padding {height, width}. The PyTorch OFM computation uses 2*pad in each
   // spatial direction, implying the same top=bottom=height and left=right=width
   // values for TOSA.
   SmallVector<int64_t> padding(
```
```diff
@@ -2388,9 +2390,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "failed to get accumulator type for convolution ops");
 
+  // Weight layout reference:
+  //   Conv       : PyTorch OIHW     -> TOSA OHWI
+  //   Depthwise  : PyTorch OIHW*    -> TOSA HWIM (*out = in * multiplier)
+  //   Grouped    : PyTorch O(I/G)HW -> N/A
+  //   Transposed : PyTorch IOHW     -> TOSA OHWI
   // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
   // Perform the necessary transformations.
   SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
+  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
   SmallVector<int64_t> transposedInputShape(
       {inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
   auto transposedInputType = RankedTensorType::get(
```
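An illustrative sketch of the layout permutations named in the reference table above (plain C++, not the MLIR op; applying `perms` to a shape mirrors what `tosa.transpose` does to the tensor dimensions):

```cpp
#include <array>
#include <cstdint>

using Shape4 = std::array<int64_t, 4>;

// out[i] = shape[perms[i]], the same convention tosa.transpose uses.
Shape4 permute(const Shape4 &shape, const std::array<int32_t, 4> &perms) {
  Shape4 out;
  for (int i = 0; i < 4; ++i)
    out[i] = shape[perms[i]];
  return out;
}

// NCHW input -> NHWC:  permute({1, 64, 1, 100}, {0, 2, 3, 1}) == {1, 1, 100, 64}
// IOHW weight -> OHWI: permute({64, 64, 3, 3},  {1, 2, 3, 0}) == {64, 3, 3, 64}
// NHWC result -> NCHW: permute({1, 2, 200, 64}, {0, 3, 1, 2}) == {1, 64, 2, 200}
```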
```diff
@@ -2403,6 +2411,101 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
               rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
           .getResult();
 
+  if (transposed) {
+    if (groups != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: grouped transposed convolution not supported by "
+              "TOSA");
+    if (dilation[0] != 1 || dilation[1] != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: dilated transposed convolution not supported by "
+              "TOSA");
+
+    SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
+    SmallVector<int64_t> ohwiWeightShape(
+        {weightShape[1], weightShape[2], weightShape[3], weightShape[0]});
```
> **Review comment:** #nit: `weightShape[iohwToOhwi[0]]`, `weightShape[iohwToOhwi[1]]`, and so on will help readability I think.
```diff
+    auto ohwiWeightType = RankedTensorType::get(
+        makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
+    Value transformedWeight =
+        rewriter
+            .create<tosa::TransposeOp>(
+                op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
+                weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
+            .getResult();
+
+    // TOSA 'out_pad' is a 4D array {top, bottom, left, right}.
+    // Map from PyTorch's (padding, output_padding):
+    //   out_pad_total(H/W) = output_padding(H/W) - 2 * padding(H/W)
+    // Negative values are allowed and will be handled by the TOSA
+    // decomposition.
+    SmallVector<int64_t, 2> outPadding2D;
+    if (!matchPattern(adaptor.getOutputPadding(),
+                      m_TorchListOfConstantInts(outPadding2D)))
+      return rewriter.notifyMatchFailure(
+          op, "non-const output_padding list unsupported for transposed conv");
+
+    int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0];
+    int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1];
+    int64_t outPadTop = outPadH / 2;
+    int64_t outPadBottom = outPadH - outPadTop;
+    int64_t outPadLeft = outPadW / 2;
+    int64_t outPadRight = outPadW - outPadLeft;
+    SmallVector<int64_t, 4> outPad(
+        {outPadTop, outPadBottom, outPadLeft, outPadRight});
```
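A worked sketch of the `out_pad` arithmetic above, checked against the lit test at the bottom of this diff. It assumes the PyTorch transposed-conv output relation `out = (in - 1) * stride - 2 * pad + (k - 1) + output_padding + 1` (dilation == 1, which this path enforces) and, per my reading of the TOSA spec, `OW = (IW - 1) * stride + out_pad_left + out_pad_right + KW`:

```cpp
#include <cassert>
#include <cstdint>

struct OutPad {
  int64_t before, after;
};

// Split the total out_pad across the two sides of a spatial dim.
OutPad splitOutPad(int64_t outputPadding, int64_t padding) {
  int64_t total = outputPadding - 2 * padding; // may be negative
  int64_t before = total / 2;                  // C++ division truncates toward zero
  return {before, total - before};
}

int main() {
  // Width dim of the test below: in=100, k=3, stride=2, pad=1, output_padding=1.
  OutPad w = splitOutPad(/*outputPadding=*/1, /*padding=*/1);
  assert(w.before == 0 && w.after == -1); // matches out_pad = {0, -1, 0, -1}
  // TOSA output width: (100 - 1) * 2 + 0 + (-1) + 3 == 200,
  // the same as PyTorch's (100 - 1) * 2 - 2 * 1 + (3 - 1) + 1 + 1.
  int64_t ow = (100 - 1) * 2 + w.before + w.after + 3;
  assert(ow == 200);
  return 0;
}
```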
```diff
+    // Result type is NHWC (we'll transpose back).
+    auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
+    SmallVector<int64_t> outNHWC(
+        {outNCHW[0], outNCHW[2], outNCHW[3], outNCHW[1]});
```
> **Review comment:** similar comment to above to use the `nchwToNhwcDims` instead of hardcoding 0,2,3,1
```diff
+    auto transConvOpTy =
+        RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy);
+
+    // Zero-points.
+    auto zps = tosa::createZPsAsConst(rewriter, input, weight);
+    Value inputZp = zps.first ? zps.first
+                              : tosa::createZeroPointTensor(
+                                    rewriter, op->getLoc(), inputElemTy, 0)
+                                    .value();
+    Value weightZp = zps.second ? zps.second
+                                : tosa::createZeroPointTensor(
+                                      rewriter, op->getLoc(), weightElemTy, 0)
+                                      .value();
+
+    Value convTOut =
+        rewriter
+            .create<tosa::TransposeConv2DOp>(
+                op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
+                transposedInput, transformedWeight, bias, inputZp, weightZp,
+                rewriter.getDenseI64ArrayAttr(outPad),
+                rewriter.getDenseI64ArrayAttr(stride), accType)
+            .getResult();
+
+    SmallVector<int64_t> transposedOutputShape(
+        {outNHWC[0], outNHWC[3], outNHWC[1], outNHWC[2]});
+    auto transposedOutputType = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
+    Value transposedOutput =
+        rewriter
+            .create<tosa::TransposeOp>(
+                op->getLoc(),
+                getTypeConverter()->convertType(transposedOutputType), convTOut,
+                rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
+            .getResult();
+
+    // Quantized rescale.
+    Value rescaledResult = transposedOutput;
+    if (isa<quant::QuantizedType>(inputElemTy)) {
+      rescaledResult = tosa::buildRescaleOpConvOutput(
+          rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
+    }
```
> **Review comment:** `ConvTranspose2DQInt8_basic` e2e test likely triggers this code path -- any idea why that is still failing? If that's not the correct test, are there any e2e tests that trigger this code?
```diff
+    // Final cast to requested output type.
+    rewriter.replaceOp(
+        op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
+                 .value()});
+    return success();
+  }
+
   SmallVector<int64_t> transformedWeightShape;
   RankedTensorType transformedWeightType;
   Value transformedWeight;
@@ -2583,7 +2686,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
     llvm_unreachable("Unhandled convolution type");
   }
 
-  SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
   SmallVector<int64_t> transposedOutputShape(
       {outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
   auto transposedOutputType = RankedTensorType::get(
```
```diff
@@ -1,8 +1,23 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics
+// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s
```
> **Review comment:** Is `verify-diagnostics` needed anymore unless you plan to add tests to verify the failure cases?
```diff
 
-// The following test ensures that a tranposed convolution op is not
-// lowered in the torch-to-tosa conversion pass.
+// The lowering now legalizes transposed convolutions into the TOSA dialect.
+// Verify that we emit tosa.transpose_conv2d with the expected reshapes/
+// permutations.
 
+// CHECK-LABEL: func.func @forward
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
+// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[1,64,1,100],f32> -> tensor<1x64x1x100xf32>
+// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64x64x3x3xf32>}> : () -> tensor<64x64x3x3xf32>
+// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64xf32>}> : () -> tensor<64xf32>
+// CHECK: %[[TRANS_IN:.*]] = tosa.transpose %[[IN_TENSOR]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x64x1x100xf32>) -> tensor<1x1x100x64xf32>
+// CHECK: %[[W_OHWI:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 1, 2, 3, 0>} : (tensor<64x64x3x3xf32>) -> tensor<64x3x3x64xf32>
+// CHECK: %[[ZP0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[ZP1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[TCONV:.*]] = tosa.transpose_conv2d %[[TRANS_IN]], %[[W_OHWI]], %[[BIAS]], %[[ZP0]], %[[ZP1]] {acc_type = f32, out_pad = array<i64: 0, -1, 0, -1>, stride = array<i64: 2, 2>} : (tensor<1x1x100x64xf32>, tensor<64x3x3x64xf32>, tensor<64xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x200x64xf32>
+// CHECK: %[[TRANS_OUT:.*]] = tosa.transpose %[[TCONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x2x200x64xf32>) -> tensor<1x64x2x200xf32>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[TRANS_OUT]] : tensor<1x64x2x200xf32> -> !torch.vtensor<[1,64,2,200],f32>
+// CHECK: return %[[RESULT]] : !torch.vtensor<[1,64,2,200],f32>
+// CHECK: }
 func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
   %true = torch.constant.bool true
   %int1 = torch.constant.int 1
```
```diff
@@ -11,7 +26,6 @@ func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[
   %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
   %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
   %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}}
   %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
   return %output : !torch.vtensor<[1,64,2,200],f32>
 }
```
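For reference, a hedged LibTorch sketch of the same transposed convolution the test exercises (shapes and attributes taken from the MLIR above; this uses the standard `torch::conv_transpose2d` functional, not anything from this PR):

```cpp
#include <torch/torch.h>
#include <cassert>

int main() {
  // Mirrors the test: input <1,64,1,100>, weight <64,64,3,3> (IOHW),
  // stride 2, padding 1, output_padding 1, groups 1, dilation 1.
  auto input = torch::zeros({1, 64, 1, 100});
  auto weight = torch::zeros({64, 64, 3, 3});
  auto bias = torch::zeros({64});
  auto out = torch::conv_transpose2d(input, weight, bias,
                                     /*stride=*/{2, 2}, /*padding=*/{1, 1},
                                     /*output_padding=*/{1, 1});
  // out = (in - 1) * stride - 2 * pad + (k - 1) + output_padding + 1
  // H: (1 - 1) * 2 - 2 + 2 + 1 + 1 = 2;  W: (100 - 1) * 2 - 2 + 2 + 1 + 1 = 200
  assert((out.sizes() == torch::IntArrayRef{1, 64, 2, 200}));
  return 0;
}
```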
> **Review comment:** Thanks for adding this reference, though it's not clear what you are trying to imply with `(*out = in * multiplier)` here.