Commit 88d4c47
authored
[Torch] Fix mixP case for non value semantic ops (#2540)
NonValueSemantic Ops like Add_, div_, etc. expect result DType to be the
same as the first input. However, current implementation would result in
wrong result type for case like:
```python
a = torch.randn(3, 3).half() # float16
b = torch.randn(3, 3) # float32
a += b # i.e. torch.ops.aten.add_(a, b)
```
torch expects `a` to be float16, but dtype refinement would infer
float32 type, since it's replaced by `aten.add`.1 parent 4901773 commit 88d4c47
File tree
3 files changed
+43
-3
lines changed- lib/Dialect/Torch/Transforms
- python/torch_mlir_e2e_test/test_suite
- test/Dialect/Torch
3 files changed
+43
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
243 | 243 | | |
244 | 244 | | |
245 | 245 | | |
246 | | - | |
247 | | - | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
248 | 260 | | |
249 | 261 | | |
250 | 262 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1012 | 1012 | | |
1013 | 1013 | | |
1014 | 1014 | | |
| 1015 | + | |
| 1016 | + | |
| 1017 | + | |
| 1018 | + | |
| 1019 | + | |
| 1020 | + | |
| 1021 | + | |
| 1022 | + | |
| 1023 | + | |
| 1024 | + | |
| 1025 | + | |
| 1026 | + | |
| 1027 | + | |
| 1028 | + | |
| 1029 | + | |
| 1030 | + | |
| 1031 | + | |
| 1032 | + | |
| 1033 | + | |
| 1034 | + | |
| 1035 | + | |
| 1036 | + | |
| 1037 | + | |
| 1038 | + | |
1015 | 1039 | | |
1016 | 1040 | | |
1017 | 1041 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
94 | 94 | | |
95 | 95 | | |
96 | 96 | | |
97 | | - | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
98 | 102 | | |
99 | 103 | | |
100 | 104 | | |
| |||
0 commit comments