-
Notifications
You must be signed in to change notification settings - Fork 611
[TOSA] Add transposed conv support #4360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2304,9 +2304,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite( | |
| if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Unimplemented: non-constant value for transposed not supported"); | ||
| if (transposed) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Unimplemented: transposed convolution not supported"); | ||
|
|
||
| auto input = adaptor.getInput(); | ||
| auto weight = adaptor.getWeight(); | ||
|
|
@@ -2338,12 +2335,17 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite( | |
| auto bias = adaptor.getBias(); | ||
|
|
||
| if (isa<Torch::NoneType>(bias.getType())) { | ||
| auto bias_result = tosa::getConvBiasForNoneType(op, rewriter, inputElemTy, | ||
| outputElemTy, weightShape); | ||
| if (failed(bias_result)) | ||
| SmallVector<int64_t, 4> biasWeightShape = | ||
| transposed ? SmallVector<int64_t, 4>{weightShape[1], weightShape[0], | ||
| weightShape[2], weightShape[3]} | ||
| : weightShape; | ||
|
|
||
| auto biasResult = tosa::getConvBiasForNoneType( | ||
| op, rewriter, inputElemTy, outputElemTy, biasWeightShape); | ||
| if (failed(biasResult)) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Failed to create bias tensor for none type."); | ||
| bias = bias_result.value(); | ||
| bias = biasResult.value(); | ||
| } else { | ||
| if (!isa<RankedTensorType>(bias.getType())) | ||
| return rewriter.notifyMatchFailure( | ||
|
|
@@ -2370,8 +2372,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite( | |
| m_TorchListOfConstantInts(padding_2d))) | ||
| return rewriter.notifyMatchFailure(op, | ||
| "non-const padding list unsupported"); | ||
| // TOSA uses 4D padding {top, bottom, left, right} while Torch defines 2D | ||
| // padding {height, width}. The Torch OFM computation uses 2*pad in each | ||
| // TOSA uses 4D padding {top, bottom, left, right} while PyTorch defines 2D | ||
| // padding {height, width}. The PyTorch OFM computation uses 2*pad in each | ||
| // spatial direction, implying the same top=bottom=height and left=right=width | ||
| // values for TOSA. | ||
| SmallVector<int64_t> padding( | ||
|
|
@@ -2388,11 +2390,19 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite( | |
| return rewriter.notifyMatchFailure( | ||
| op, "failed to get accumulator type for convolution ops"); | ||
|
|
||
| // Weight layout reference: | ||
| // Conv : PyTorch OIHW -> TOSA OHWI | ||
| // Depthwise : PyTorch OIHW* -> TOSA HWIM | ||
| // (PyTorch depthwise uses out_ch=in_ch*depth_multiplier) | ||
| // Grouped : PyTorch O(I/G)HW -> N/A | ||
| // Transposed : PyTorch IOHW -> TOSA OHWI | ||
| // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights. | ||
| // Perform the necessary transformations. | ||
| SmallVector<int32_t> nchwToNhwcDims({0, 2, 3, 1}); | ||
| SmallVector<int64_t> transposedInputShape( | ||
| {inputShape[0], inputShape[2], inputShape[3], inputShape[1]}); | ||
| SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2}); | ||
| SmallVector<int64_t, 4> transposedInputShape; | ||
| for (int32_t dim : nchwToNhwcDims) | ||
| transposedInputShape.push_back(inputShape[dim]); | ||
| auto transposedInputType = RankedTensorType::get( | ||
| makeShapeLLVMCompatible(transposedInputShape), inputElemTy); | ||
| auto transposedInput = | ||
|
|
@@ -2403,6 +2413,104 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite( | |
| rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) | ||
| .getResult(); | ||
|
|
||
| if (transposed) { | ||
| if (groups != 1) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Unimplemented: grouped transposed convolution not supported by " | ||
| "TOSA"); | ||
| if (dilation[0] != 1 || dilation[1] != 1) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "Unimplemented: dilated transposed convolution not supported by " | ||
| "TOSA"); | ||
|
|
||
|
Comment on lines +2418 to +2425
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think these notify failures need to happen before any IR rewrites take place, otherwise the pattern rewriter ends up in a recursive loop. For example, on line 2410 we'd have already introduced |
||
| SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0}); | ||
| SmallVector<int64_t, 4> ohwiWeightShape; | ||
| for (int32_t dim : iohwToOhwi) | ||
| ohwiWeightShape.push_back(weightShape[dim]); | ||
| auto ohwiWeightType = RankedTensorType::get( | ||
| makeShapeLLVMCompatible(ohwiWeightShape), weightElemTy); | ||
| Value transformedWeight = | ||
| rewriter | ||
| .create<tosa::TransposeOp>( | ||
| op->getLoc(), getTypeConverter()->convertType(ohwiWeightType), | ||
| weight, rewriter.getDenseI32ArrayAttr(iohwToOhwi)) | ||
| .getResult(); | ||
|
|
||
| // TOSA 'out_pad' is a 4D array {top,bottom,left,right}. | ||
| // Map from PyTorch's (padding, output_padding): | ||
| // out_pad_total(H/W) = output_padding(H/W) - 2*padding(H/W) | ||
| // Negative values are allowed and will be handled by the TOSA | ||
| // decomposition. | ||
| SmallVector<int64_t, 2> outPadding2D; | ||
| if (!matchPattern(adaptor.getOutputPadding(), | ||
| m_TorchListOfConstantInts(outPadding2D))) | ||
| return rewriter.notifyMatchFailure( | ||
| op, "non-const output_padding list unsupported for transposed conv"); | ||
|
|
||
| int64_t outPadH = outPadding2D[0] - 2 * padding_2d[0]; | ||
| int64_t outPadW = outPadding2D[1] - 2 * padding_2d[1]; | ||
| int64_t outPadTop = outPadH / 2; | ||
| int64_t outPadBottom = outPadH - outPadTop; | ||
| int64_t outPadLeft = outPadW / 2; | ||
| int64_t outPadRight = outPadW - outPadLeft; | ||
| SmallVector<int64_t, 4> outPad( | ||
| {outPadTop, outPadBottom, outPadLeft, outPadRight}); | ||
|
|
||
| // Result type is NHWC (we'll transpose back). | ||
| auto outNCHW = makeShapeTorchCompatible(outputTy.getShape()); | ||
| SmallVector<int64_t, 4> outNHWC; | ||
| for (int32_t dim : nchwToNhwcDims) | ||
| outNHWC.push_back(outNCHW[dim]); | ||
| auto transConvOpTy = | ||
| RankedTensorType::get(makeShapeLLVMCompatible(outNHWC), biasElemTy); | ||
|
|
||
| // Zero-points. | ||
| auto zps = tosa::createZPsAsConst(rewriter, input, weight); | ||
| Value inputZp = zps.first ? zps.first | ||
| : tosa::createZeroPointTensor( | ||
| rewriter, op->getLoc(), inputElemTy, 0) | ||
| .value(); | ||
| Value weightZp = zps.second ? zps.second | ||
| : tosa::createZeroPointTensor( | ||
| rewriter, op->getLoc(), weightElemTy, 0) | ||
| .value(); | ||
|
|
||
| Value convTOut = | ||
| rewriter | ||
| .create<tosa::TransposeConv2DOp>( | ||
| op->getLoc(), getTypeConverter()->convertType(transConvOpTy), | ||
| transposedInput, transformedWeight, bias, inputZp, weightZp, | ||
| rewriter.getDenseI64ArrayAttr(outPad), | ||
| rewriter.getDenseI64ArrayAttr(stride), accType) | ||
| .getResult(); | ||
|
|
||
| SmallVector<int64_t, 4> transposedOutputShape; | ||
| for (int32_t dim : nhwcToNchwDims) | ||
| transposedOutputShape.push_back(outNHWC[dim]); | ||
| auto transposedOutputType = RankedTensorType::get( | ||
| makeShapeLLVMCompatible(transposedOutputShape), biasElemTy); | ||
| Value transposedOutput = | ||
| rewriter | ||
| .create<tosa::TransposeOp>( | ||
| op->getLoc(), | ||
| getTypeConverter()->convertType(transposedOutputType), convTOut, | ||
| rewriter.getDenseI32ArrayAttr(nhwcToNchwDims)) | ||
| .getResult(); | ||
|
|
||
| // Quantized rescale. | ||
| Value rescaledResult = transposedOutput; | ||
| if (isa<quant::QuantizedType>(inputElemTy)) { | ||
| rescaledResult = tosa::buildRescaleOpConvOutput( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ConvTranspose2DQInt8_basic e2e test likely triggers this code path -- any idea why that is still failing? If that's not the correct test, are there any e2e tests that trigger this code? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The failing XFAIL stops at |
||
| rewriter, op, transposedOutput, inputTy, weightTy, outputTy); | ||
| } | ||
|
|
||
| // Final cast to requested output type. | ||
| rewriter.replaceOp( | ||
| op, {tosa::tosaCastTensorToType(rewriter, rescaledResult, outputTy) | ||
| .value()}); | ||
| return success(); | ||
| } | ||
|
|
||
| SmallVector<int64_t> transformedWeightShape; | ||
| RankedTensorType transformedWeightType; | ||
| Value transformedWeight; | ||
|
|
@@ -2487,7 +2595,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite( | |
| if (remainderHDim != 0) { | ||
| if (remainderHDim > padding[1]) { | ||
| SmallVector<int64_t> startHSlice(inputTy.getRank(), 0); | ||
| SmallVector<int64_t> sizeHSlice(transposedInputShape); | ||
| SmallVector<int64_t, 4> sizeHSlice(transposedInputShape); | ||
| // TOSA uses NHWC, so we will slice dim 1 for Height value | ||
| sizeHSlice[1] = inputHDim - (remainderHDim - padding[1]); | ||
| transposedInput = tosa::CreateOpAndInfer<tosa::SliceOp>( | ||
|
|
@@ -2583,7 +2691,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite( | |
| llvm_unreachable("Unhandled convolution type"); | ||
| } | ||
|
|
||
| SmallVector<int32_t> nhwcToNchwDims({0, 3, 1, 2}); | ||
| SmallVector<int64_t> transposedOutputShape( | ||
| {outputShape[0], outputShape[3], outputShape[1], outputShape[2]}); | ||
| auto transposedOutputType = RankedTensorType::get( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment about why the biasShape is different for the transposed case?