122 changes: 112 additions & 10 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2304,9 +2304,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
return rewriter.notifyMatchFailure(
op, "Unimplemented: non-constant value for transposed not supported");
if (transposed)
return rewriter.notifyMatchFailure(
op, "Unimplemented: transposed convolution not supported");

auto input = adaptor.getInput();
auto weight = adaptor.getWeight();
@@ -2338,12 +2335,17 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
auto bias = adaptor.getBias();

if (isa<Torch::NoneType>(bias.getType())) {
auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy,
outputElemTy, weightShape);
if (failed(bias_result))
SmallVector<int64_t, 4> biasWeightShape =
transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0],
weightShape[2], weightShape[3]}
: weightShape;

auto biasResult = tosa::getConvBiasForNoneType(
op, rewriter, inputElemTy, outputElemTy, biasWeightShape);
if (failed(biasResult))
return rewriter.notifyMatchFailure(
op, "Failed to create bias tensor for none type.");
bias = bias_result.value();
bias = biasResult.value();
} else {
if (!isa<RankedTensorType>(bias.getType()))
return rewriter.notifyMatchFailure(
@@ -2370,8 +2372,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
m_TorchListOfConstantInts(padding_2d)))
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");
// TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D
// padding {height, width}. The Torch OFM computation uses 2*pad in each
// TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D
// padding {height, width}. The PyTorch OFM computation uses 2*pad in each
// spatial direction, implying the same top=bottom=height and left=right=width
// values for TOSA.
SmallVector<int64_t> padding(
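
As an aside, the padding expansion described in the comment above can be sketched as standalone C++ (a hypothetical helper, not part of this patch, assuming symmetric PyTorch padding):

    #include <array>
    #include <cstdint>

    // PyTorch convolutions take 2D padding {padH, padW}; TOSA conv ops take a
    // 4D {top, bottom, left, right} array. PyTorch pads symmetrically (its
    // output-shape formula uses 2*pad per spatial dim), so top == bottom ==
    // padH and left == right == padW.
    std::array<int64_t, 4> expandTorchPadding(int64_t padH, int64_t padW) {
      return {padH, padH, padW, padW};
    }

    // e.g. PyTorch padding {1, 2} -> TOSA pad {1, 1, 2, 2}
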
@@ -2388,9 +2390,15 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(
op, "failed to get accumulator type for convolution ops");

// Weight layout reference:
// Conv : PyTorch OIHW -> TOSA OHWI
// Depthwise : PyTorch OIHW* -> TOSA HWIM (*out = in * multiplier)

Member: Thanks for adding this reference, though it's not clear what you are trying to imply with (*out = in * multiplier) here.

// Grouped : PyTorch O(I/G)HW -> N/A
// Transposed : PyTorch IOHW -> TOSA OHWI
// TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
// Perform the necessary transformations.
SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1});
SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
SmallVector<int64_t> transposedInputShape(
{inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
auto transposedInputType = RankedTensorType::get(
@@ -2403,6 +2411,101 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
rewriter.getDenseI32ArrayAttr(nchwToNhwcDims))
.getResult();

if (transposed) {
if (groups != 1)
return rewriter.notifyMatchFailure(
op, "Unimplemented: grouped transposed convolution not supported by "
"TOSA");
if (dilation[0] != 1 || dilation[1] != 1)
return rewriter.notifyMatchFailure(
op, "Unimplemented: dilated transposed convolution not supported by "
"TOSA");

SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
SmallVector<int64_t> ohwiWeightShape(
{weightShape[1], weightShape[2], weightShape[3], weightShape[0]});

Member: #nit: weightShape[iohwToOhwi[0]], weightShape[iohwToOhwi[1]], ... and so on will help readability I think

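Spelled out, that suggestion would read (sketch only; same result as the hardcoded indices):

    SmallVector<int64_t> ohwiWeightShape(
        {weightShape[iohwToOhwi[0]], weightShape[iohwToOhwi[1]],
         weightShape[iohwToOhwi[2]], weightShape[iohwToOhwi[3]]});
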

auto ohwiWeightType = RankedTensorType::get(
makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy);
Value transformedWeight =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(), getTypeConverter()->convertType(ohwiWeightType),
weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi))
.getResult();

// TOSA 'out_pad' is a 4D array {top,bottom,left,right}.
// Map from PyTorch's (padding, output_padding):
// out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W)
// Negative values are allowed and will be handled by the TOSA
// decomposition.
SmallVector<int64_t, 2> outPadding2D;
if (!matchPattern(adaptor.getOutputPadding(),
m_TorchListOfConstantInts(outPadding2D)))
return rewriter.notifyMatchFailure(
op, "non-const output_padding list unsupported for transposed conv");

int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0];
int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1];
int64_t outPadTop = outPadH / 2;
int64_t outPadBottom = outPadH - outPadTop;
int64_t outPadLeft = outPadW / 2;
int64_t outPadRight = outPadW - outPadLeft;
SmallVector<int64_t, 4> outPad(
{outPadTop, outPadBottom, outPadLeft, outPadRight});
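
A worked example using the values from the updated test at the end of this diff (padding = {1, 1}, output_padding = {1, 1}):

    outPadH      = 1 - 2*1 = -1
    outPadTop    = -1 / 2  =  0    // C++ integer division truncates toward zero
    outPadBottom = -1 - 0  = -1

and likewise for W, which yields the out_pad = array<i64: 0, -1, 0, -1> checked in the FileCheck lines below.
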

// Result type is NHWC (we'll transpose back).
auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
SmallVector<int64_t> outNHWC(
{outNCHW[0], outNCHW[2], outNCHW[3], outNCHW[1]});

Member: similar comment to above to use the nchwToNhwcDims instead of hardcoding 0,2,3,1

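That is, roughly:

    SmallVector<int64_t> outNHWC(
        {outNCHW[nchwToNhwcDims[0]], outNCHW[nchwToNhwcDims[1]],
         outNCHW[nchwToNhwcDims[2]], outNCHW[nchwToNhwcDims[3]]});
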

auto transConvOpTy =
RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy);

// Zero-points.
auto zps = tosa::createZPsAsConst(rewriter, input, weight);
Value inputZp = zps.first ? zps.first
: tosa::createZeroPointTensor(
rewriter, op->getLoc(), inputElemTy, 0)
.value();
Value weightZp = zps.second ? zps.second
: tosa::createZeroPointTensor(
rewriter, op->getLoc(), weightElemTy, 0)
.value();

Value convTOut =
rewriter
.create<tosa::TransposeConv2DOp>(
op->getLoc(), getTypeConverter()->convertType(transConvOpTy),
transposedInput, transformedWeight, bias, inputZp, weightZp,
rewriter.getDenseI64ArrayAttr(outPad),
rewriter.getDenseI64ArrayAttr(stride), accType)
.getResult();

SmallVector<int64_t> transposedOutputShape(
{outNHWC[0], outNHWC[3], outNHWC[1], outNHWC[2]});
auto transposedOutputType = RankedTensorType::get(
makeShapeLLVMCompatible(transposedOutputShape), biasElemTy);
Value transposedOutput =
rewriter
.create<tosa::TransposeOp>(
op->getLoc(),
getTypeConverter()->convertType(transposedOutputType), convTOut,
rewriter.getDenseI32ArrayAttr(nhwcToNchwDims))
.getResult();

// Quantized rescale.
Value rescaledResult = transposedOutput;
if (isa<quant::QuantizedType>(inputElemTy)) {
rescaledResult = tosa::buildRescaleOpConvOutput(

Member: ConvTranspose2DQInt8_basic e2e test likely triggers this code path -- any idea why that is still failing? If that's not the correct test, are there any e2e tests that triggers this code?

rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
}

// Final cast to requested output type.
rewriter.replaceOp(
op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy)
.value()});
return success();
}

SmallVector<int64_t> transformedWeightShape;
RankedTensorType transformedWeightType;
Value transformedWeight;
Expand Down Expand Up @@ -2583,7 +2686,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
llvm_unreachable("Unhandled convolution type");
}

SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2});
SmallVector<int64_t> transposedOutputShape(
{outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
auto transposedOutputType = RankedTensorType::get(
6 changes: 0 additions & 6 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -3588,7 +3588,6 @@
"AvgPool3dCountIncludePadFalseWithoutPadding_basic",
"Conv_Transpose1dModule_basic",
"Conv_Transpose1dStaticModule_basic",
"Conv_Transpose2dStaticModule_basic",
"Conv_Transpose3dModule_basic",
"Conv_Transpose3dStaticModule_basic",
"IndexPutWithNoneAndBroadcastModule_basic",
@@ -3713,16 +3712,11 @@
"Conv3dWithValidPaddingModule_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStatic_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
"ConvolutionModule2DGroups_basic",
"ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
"ConvolutionModule2DTransposeStridedStatic_basic",
"ConvolutionModule2DTransposeStrided_basic",
"ConvolutionModule2DTranspose_basic",
"ConvolutionModule2DGroupedTranspose_basic",
"ConvolutionModule3DGroups_basic",
"ConvolutionModule3DGroupsStrided_basic",
@@ -29,6 +29,8 @@
# that depend on TOSA as well as TOSA-to-Standard.
"tosa-to-arith",
"tosa-to-scf",
# Required for transposed convolution support (decomposes to conv ops).
"tosa-optional-decompositions",
# Named ops must be legalized prior to general tosa-to-linalg
"tosa-to-linalg-named",
# TOSA-to-LinAlg may generate tosa.const() ops, so we want to lower them
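Background: tosa-optional-decompositions is an upstream MLIR pass that, among other rewrites, expands tosa.transpose_conv2d into equivalent tosa.conv2d-based sequences (roughly, by zero-padding the input and reversing the kernel's spatial dims), so the later tosa-to-linalg* passes only encounter ops they can lower.
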
22 changes: 18 additions & 4 deletions test/Conversion/TorchToTosa/conv2d_transpose.mlir
@@ -1,8 +1,23 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s

Member: Is verify-diagnostics needed anymore unless you plan to add tests to verify the failure cases?



// The following test ensures that a tranposed convolution op is not
// lowered in the torch-to-tosa conversion pass.
// The lowering now legalizes transpose convolutions into the TOSA dialect.
// Verify that we emit tosa.transpose_conv2d with the expected reshapes/
// permutations.

// CHECK-LABEL: func.func @forward
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
// CHECK: %[[IN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[1,64,1,100],f32> -> tensor<1x64x1x100xf32>
// CHECK: %[[WEIGHT:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64x64x3x3xf32>}> : () -> tensor<64x64x3x3xf32>
// CHECK: %[[BIAS:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<64xf32>}> : () -> tensor<64xf32>
// CHECK: %[[TRANS_IN:.*]] = tosa.transpose %[[IN_TENSOR]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x64x1x100xf32>) -> tensor<1x1x100x64xf32>
// CHECK: %[[W_OHWI:.*]] = tosa.transpose %[[WEIGHT]] {perms = array<i32: 1, 2, 3, 0>} : (tensor<64x64x3x3xf32>) -> tensor<64x3x3x64xf32>
// CHECK: %[[ZP0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[ZP1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[TCONV:.*]] = tosa.transpose_conv2d %[[TRANS_IN]], %[[W_OHWI]], %[[BIAS]], %[[ZP0]], %[[ZP1]] {acc_type = f32, out_pad = array<i64: 0, -1, 0, -1>, stride = array<i64: 2, 2>} : (tensor<1x1x100x64xf32>, tensor<64x3x3x64xf32>, tensor<64xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x2x200x64xf32>
// CHECK: %[[TRANS_OUT:.*]] = tosa.transpose %[[TCONV]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x2x200x64xf32>) -> tensor<1x64x2x200xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[TRANS_OUT]] : tensor<1x64x2x200xf32> -> !torch.vtensor<[1,64,2,200],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,64,2,200],f32>
// CHECK: }
func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> {
%true = torch.constant.bool true
%int1 = torch.constant.int 1
@@ -11,7 +26,6 @@ func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[
%bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32>
%stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
%int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}}
%output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,64,2,200],f32>
return %output : !torch.vtensor<[1,64,2,200],f32>
}