[TOSA] Add transposed conv support #4360
base: main
Lower aten.conv_transpose2d into tosa.transpose_conv2d. Refresh FX importer TOSA xfails to drop the transpose-conv cases that now pass, and document the weight layout mapping.

Change-Id: I23be2230a0948784402dca574597db1d979d5aee
Change looks good to me. Some minor comments and clarifying questions.
// Weight layout reference:
// Conv : PyTorch OIHW -> TOSA OHWI
// Depthwise : PyTorch OIHW* -> TOSA HWIM (*out = in * multiplier)
Thanks for adding this reference, though it's not clear what (*out = in * multiplier) is meant to imply here.
SmallVector<int32_t> iohwToOhwi({1, 2, 3, 0});
SmallVector<int64_t> ohwiWeightShape(
    {weightShape[1], weightShape[2], weightShape[3], weightShape[0]});
nit: using weightShape[iohwToOhwi[0]], weightShape[iohwToOhwi[1]], and so on would help readability, I think.
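To illustrate the nit, here is a minimal standalone sketch of indexing a shape through its permutation vector rather than through hardcoded positions. It uses `std::vector` instead of `SmallVector` so that it compiles without LLVM, and `permuteShape` is a hypothetical helper name, not an existing torch-mlir function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of torch-mlir): builds a shape whose
// i-th dimension is shape[perm[i]], so the permutation is stated once.
std::vector<int64_t> permuteShape(const std::vector<int64_t> &shape,
                                  const std::vector<int32_t> &perm) {
  assert(shape.size() == perm.size() && "rank mismatch");
  std::vector<int64_t> result;
  result.reserve(perm.size());
  for (int32_t axis : perm) {
    // Every entry of perm must be a valid axis of the input shape.
    assert(axis >= 0 && static_cast<std::size_t>(axis) < shape.size());
    result.push_back(shape[axis]);
  }
  return result;
}
```

With an IOHW weight shape of {4, 8, 3, 5} and the permutation {1, 2, 3, 0}, the helper yields the OHWI shape {8, 3, 5, 4}, and the index order lives in a single place.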
// Result type is NHWC (we'll transpose back).
auto outNCHW = makeShapeTorchCompatible(outputTy.getShape());
SmallVector<int64_t> outNHWC(
    {outNCHW[0], outNCHW[2], outNCHW[3], outNCHW[1]});
Similar to the comment above: use nchwToNhwcDims instead of hardcoding 0, 2, 3, 1.
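Since the result is produced in NHWC and then transposed back, the two permutations involved are inverses of each other. The sketch below is a standalone illustration under assumptions: `invertPermutation` and `applyPermutation` are hypothetical helper names, and the value {0, 2, 3, 1} for the NCHW-to-NHWC permutation is taken from the hardcoded indices in the snippet above rather than from the definition of `nchwToNhwcDims`, which is not shown here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: result[i] = shape[perm[i]].
std::vector<int64_t> applyPermutation(const std::vector<int64_t> &shape,
                                      const std::vector<int32_t> &perm) {
  assert(shape.size() == perm.size() && "rank mismatch");
  std::vector<int64_t> result(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    result[i] = shape[perm[i]];
  return result;
}

// Hypothetical helper: returns inv such that inv[perm[i]] == i, so
// applying perm and then inv restores the original dimension order.
std::vector<int32_t> invertPermutation(const std::vector<int32_t> &perm) {
  std::vector<int32_t> inverse(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i)
    inverse[perm[i]] = static_cast<int32_t>(i);
  return inverse;
}
```

For example, inverting {0, 2, 3, 1} gives {0, 3, 1, 2}, the NHWC-to-NCHW order, so an NCHW shape of {1, 16, 7, 9} maps to NHWC {1, 7, 9, 16} and back again.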
@@ -1,8 +1,23 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics
// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s
Is -verify-diagnostics still needed, unless you plan to add tests that verify the failure cases?
// Quantized rescale.
Value rescaledResult = transposedOutput;
if (isa<quant::QuantizedType>(inputElemTy)) {
  rescaledResult = tosa::buildRescaleOpConvOutput(
The ConvTranspose2DQInt8_basic e2e test likely triggers this code path. Any idea why it is still failing? If that's not the right test, are there any e2e tests that trigger this code?