Commit fb21a85
authored
[TorchToLinalg] Lower grouped conv2d to linalg Op with correct dimension ordering (#2623)
The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where
1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W
Now this has been fixed in
llvm/llvm-project#73855 which broke the
torch-mlir lowering to that Op.
This patch switches lowering in torch-mlir to the newly introduced
`linalg.conv_2d_ngchw_gfchw` op which accesses weights in an order that
is compatible with PyTorch's memory layout.
Fix #26221 parent 8252656 commit fb21a85
File tree
2 files changed
+7
-23
lines changed- lib/Conversion/TorchToLinalg
- projects/pt1/e2e_testing
2 files changed
+7
-23
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
848 | 848 | | |
849 | 849 | | |
850 | 850 | | |
| 851 | + | |
851 | 852 | | |
852 | 853 | | |
853 | 854 | | |
| |||
868 | 869 | | |
869 | 870 | | |
870 | 871 | | |
871 | | - | |
| 872 | + | |
872 | 873 | | |
873 | 874 | | |
874 | 875 | | |
875 | | - | |
876 | | - | |
| 876 | + | |
| 877 | + | |
877 | 878 | | |
878 | | - | |
| 879 | + | |
879 | 880 | | |
880 | 881 | | |
881 | | - | |
882 | | - | |
883 | | - | |
884 | 882 | | |
885 | | - | |
| 883 | + | |
| 884 | + | |
886 | 885 | | |
887 | 886 | | |
888 | 887 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
23 | 23 | | |
24 | 24 | | |
25 | 25 | | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | 26 | | |
35 | 27 | | |
36 | 28 | | |
| |||
316 | 308 | | |
317 | 309 | | |
318 | 310 | | |
319 | | - | |
320 | | - | |
321 | | - | |
322 | | - | |
323 | | - | |
324 | | - | |
325 | | - | |
326 | 311 | | |
327 | 312 | | |
328 | 313 | | |
| |||
0 commit comments