
Commit 88d4c47

[Torch] Fix mixP case for non value semantic ops (#2540)
Non-value-semantic ops like `add_`, `div_`, etc. expect the result dtype to be the same as the first input's. However, the current implementation produces the wrong result type for a case like:

```python
a = torch.randn(3, 3).half()  # float16
b = torch.randn(3, 3)         # float32
a += b                        # i.e. torch.ops.aten.add_(a, b)
```

torch expects `a` to remain float16, but dtype refinement infers float32, since the op is replaced by `aten.add`.
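For reference, the eager-mode behavior the fix matches can be reproduced in plain PyTorch (a runnable sketch; nothing here is torch-mlir specific):

```python
import torch

a = torch.randn(3, 3).half()  # float16
b = torch.randn(3, 3)         # float32

a += b          # in-place: torch.ops.aten.add_(a, b) keeps self's dtype
print(a.dtype)  # torch.float16

c = a + b       # value-semantic: torch.ops.aten.add(a, b) promotes
print(c.dtype)  # torch.float32
```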
1 parent 4901773 commit 88d4c47

File tree

3 files changed: +43 -3 lines changed


lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp

Lines changed: 14 additions & 2 deletions
```diff
@@ -243,8 +243,20 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern {
            "Torch JIT operators shouldn't have regions or successors");
 
     Operation *newOp = rewriter.create(state);
-    auto tensor =
-        rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
+    // Note: need to convert the result to the first input's dtype because
+    // mixed-precision compute would result in different behaviors.
+    // For example:
+    // a = torch.randn(3, 3).half() # float16
+    // b = torch.randn(3, 3) # float32
+    // a += b # i.e. torch.ops.aten.add_(a, b), result is float16
+    // c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
+    Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
+    Value cstFalse = rewriter.create<ConstantBoolOp>(op->getLoc(), false);
+    auto aDtype = rewriter.create<PrimDtypeOp>(op->getLoc(), op->getOperand(0));
+    auto toDtype = rewriter.create<AtenToDtypeOp>(
+        op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0),
+        aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
+    auto tensor = rewriter.create<CopyToValueTensorOp>(op->getLoc(), toDtype);
     createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
                                   op->getOperand(0));
     rewriter.replaceOp(op, op->getOperand(0));
```
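In eager terms, the rewritten IR computes the value-semantic op and casts back to the first operand's dtype before overwriting it. A minimal Python sketch of that decomposition (the helper name is hypothetical; `.to(...)` stands in for the inserted `aten.to.dtype`):

```python
import torch

def reduced_inplace_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Value-semantic op replacing aten.add_: promotes to float32 here.
    result = torch.ops.aten.add(a, b)
    # The inserted aten.to.dtype: cast back to the first operand's dtype.
    result = result.to(a.dtype)
    # Models torch.overwrite.tensor.contents writing back into `a`.
    a.copy_(result)
    return a

a = torch.randn(3, 3).half()  # float16
b = torch.randn(3, 3)         # float32
assert reduced_inplace_add(a, b).dtype == torch.float16
```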

python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -1012,6 +1012,30 @@ def AddSizeIntNegDimModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class Add_MixPModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        a += b
+        return a
+
+
+@register_test_case(module_factory=lambda: Add_MixPModule())
+def Add_MixPModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3), tu.rand(3, 3).double())
+
+
+# ==============================================================================
+
+
 class EmbeddingModuleI64(torch.nn.Module):
 
     def __init__(self):
```
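The test pairs float32 with float64 inputs; the property it locks down can be checked directly in eager PyTorch (an illustrative sketch):

```python
import torch

a = torch.rand(3, 3)           # float32
b = torch.rand(3, 3).double()  # float64

# Value-semantic promotion for this pair is float64 ...
assert torch.result_type(a, b) == torch.float64

# ... but the in-place variant must keep the first operand's dtype.
a += b
assert a.dtype == torch.float32
```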

test/Dialect/Torch/reduce-op-variants.mlir

Lines changed: 5 additions & 1 deletion
```diff
@@ -94,7 +94,11 @@ func.func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
 // (which is cleaned up by canonicalization) is an artifact of two patterns
 // being applied in sequence.
 // CHECK: %[[ARRAY_RESULT:.*]] = torch.copy.to_tensor %[[TENSOR_RESULT]] : !torch.tensor<[2,2],f32>
-// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[ARRAY_RESULT]] : !torch.vtensor<[2,2],f32>
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[DTYPE:.*]] = torch.constant.int 6
+// CHECK: %[[DTYPE_RESULT:.*]] = torch.aten.to.dtype %[[ARRAY_RESULT]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[2,2],f32>
+// CHECK: %[[TENSOR_AGAIN:.*]] = torch.copy.to_vtensor %[[DTYPE_RESULT]] : !torch.vtensor<[2,2],f32>
 // CHECK: torch.overwrite.tensor.contents %[[TENSOR_AGAIN]] overwrites %[[ARG0]] : !torch.vtensor<[2,2],f32>, !torch.tensor<[2,2],f32>
 // CHECK: return %[[ARG0]], %[[ARG0]] : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
 func.func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>, %arg1: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) {
```
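A note on the constants in the CHECK lines: both inputs here are `f32`, so the `prim.dtype` of the first operand is presumably folded to `torch.constant.int 6`, which is PyTorch's ScalarType code for float32 (from c10's ScalarType enum ordering). The matched `torch.aten.to.dtype` corresponds to the eager `aten::to.dtype` overload, which can be exercised with public PyTorch ops (a sketch):

```python
import torch

t = torch.rand(2, 2)  # float32, i.e. ScalarType code 6 in the IR
# aten.to.dtype(self, dtype, non_blocking, copy, memory_format), the same
# operand order the CHECK line matches against.
u = torch.ops.aten.to(t, torch.float32, False, False, None)
assert u.dtype == torch.float32
```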
