Skip to content

Commit 11cc92d

Browse files
authored
[onnx] Lowerings from onnx.tan (#2642)
Started work on the `tan` lowerings for ONNX to Torch. Uses `sin` and `cos` to represent a `tan`.
1 parent a24aadb commit 11cc92d

File tree

8 files changed

+147
-11
lines changed

8 files changed

+147
-11
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,51 @@ def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [
10661066
}];
10671067
}
10681068

1069+
def Torch_AtenTanOp : Torch_Op<"aten.tan", [
1070+
AllowsTypeRefinement,
1071+
HasValueSemantics,
1072+
ReadOnly
1073+
]> {
1074+
let summary = "Generated op for `aten::tan : (Tensor) -> (Tensor)`";
1075+
let arguments = (ins
1076+
AnyTorchTensorType:$self
1077+
);
1078+
let results = (outs
1079+
AnyTorchTensorType:$result
1080+
);
1081+
let hasCustomAssemblyFormat = 1;
1082+
let extraClassDefinition = [{
1083+
ParseResult AtenTanOp::parse(OpAsmParser &parser, OperationState &result) {
1084+
return parseDefaultTorchOp(parser, result, 1, 1);
1085+
}
1086+
void AtenTanOp::print(OpAsmPrinter &printer) {
1087+
printDefaultTorchOp(printer, *this, 1, 1);
1088+
}
1089+
}];
1090+
}
1091+
1092+
def Torch_AtenTan_Op : Torch_Op<"aten.tan_", [
1093+
IsTrailingUnderscoreInplaceVariant,
1094+
AllowsTypeRefinement
1095+
]> {
1096+
let summary = "Generated op for `aten::tan_ : (Tensor) -> (Tensor)`";
1097+
let arguments = (ins
1098+
Torch_NonValueTensorType:$self
1099+
);
1100+
let results = (outs
1101+
Torch_NonValueTensorType:$result
1102+
);
1103+
let hasCustomAssemblyFormat = 1;
1104+
let extraClassDefinition = [{
1105+
ParseResult AtenTan_Op::parse(OpAsmParser &parser, OperationState &result) {
1106+
return parseDefaultTorchOp(parser, result, 1, 1);
1107+
}
1108+
void AtenTan_Op::print(OpAsmPrinter &printer) {
1109+
printDefaultTorchOp(printer, *this, 1, 1);
1110+
}
1111+
}];
1112+
}
1113+
10691114
def Torch_AtenAtanOp : Torch_Op<"aten.atan", [
10701115
AllowsTypeRefinement,
10711116
HasValueSemantics,

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
794794
binder.op, resultType, operand);
795795
return success();
796796
});
797-
797+
798+
patterns.onOp("Tan", 7,
799+
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
800+
Torch::ValueTensorType resultType;
801+
Value operand;
802+
if (binder.tensorOperand(operand) ||
803+
binder.tensorResultType(resultType))
804+
return failure();
805+
rewriter.replaceOpWithNewOp<Torch::AtenTanOp>(
806+
binder.op, resultType, operand);
807+
return success();
808+
});
809+
798810
patterns.onOp(
799811
"Transpose", 13,
800812
[](OpBinder binder, ConversionPatternRewriter &rewriter) {

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
216216
return b.create<math::FloorOp>(loc, payloadArgs[0]);
217217
if (isa<AtenCeilOp>(op))
218218
return b.create<math::CeilOp>(loc, payloadArgs[0]);
219+
if (isa<AtenTanOp>(op)) {
220+
return createCalculationForMathOpWithDtypeConversion<math::TanOp>(
221+
b, converter, payloadArgs[0], op);
222+
}
219223
if (isa<AtenTanhOp>(op)) {
220224
return createCalculationForMathOpWithDtypeConversion<math::TanhOp>(
221225
b, converter, payloadArgs[0], op);
@@ -1319,15 +1323,15 @@ class ConvertElementwiseOp : public ConversionPattern {
13191323
LogicalResult
13201324
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
13211325
ConversionPatternRewriter &rewriter) const override {
1322-
if (!isa<AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, AtenPreluOp,
1323-
AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
1324-
AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp, AtenAtan2Op,
1325-
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
1326-
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
1327-
AtenClampTensorOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
1328-
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
1329-
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
1330-
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
1326+
if (!isa<AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp,
1327+
AtenPreluOp, AtenGeluOp, AtenGeluBackwardOp, AtenAddTensorOp,
1328+
AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
1329+
AtenSubTensorOp, AtenAtan2Op, AtenLerpTensorOp, AtenSigmoidOp,
1330+
AtenExpOp, AtenExpm1Op, AtenMinimumOp, AtenMaximumOp,
1331+
AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, AtenRsubScalarOp,
1332+
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
1333+
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp,
1334+
AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
13311335
AtenRemainderScalarOp, AtenAbsOp, AtenReciprocalOp,
13321336
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
13331337
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
@@ -1972,7 +1976,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
19721976
ConversionTarget &target) {
19731977
MLIRContext *context = patterns.getContext();
19741978
target.addIllegalOp<
1975-
AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, AtenGeluOp,
1979+
AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, AtenGeluOp,
19761980
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp,
19771981
AtenDivTensorModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
19781982
AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6238,6 +6238,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
62386238
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
62396239
" return %0 : !torch.list<int>\n"
62406240
" }\n"
6241+
" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
6242+
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
6243+
" return %0 : !torch.list<int>\n"
6244+
" }\n"
62416245
" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
62426246
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
62436247
" return %0 : !torch.list<int>\n"
@@ -11396,6 +11400,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
1139611400
" }\n"
1139711401
" return %4 : !torch.tuple<int, int>\n"
1139811402
" }\n"
11403+
" func.func @\"__torch_mlir_dtype_fn.aten.tan\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
11404+
" %int6 = torch.constant.int 6\n"
11405+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
11406+
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
11407+
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
11408+
" torch.prim.If.yield %int6 : !torch.int\n"
11409+
" } else {\n"
11410+
" torch.prim.If.yield %0#1 : !torch.int\n"
11411+
" }\n"
11412+
" return %2 : !torch.int\n"
11413+
" }\n"
1139911414
" func.func @\"__torch_mlir_dtype_fn.aten.atan2\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
1140011415
" %int6 = torch.constant.int 6\n"
1140111416
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
5959
def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]:
6060
return upstream_shape_functions.unary(self)
6161

62+
def aten〇tan〡shape(self: List[int]) -> List[int]:
63+
return upstream_shape_functions.unary(self)
64+
6265
def aten〇atan〡shape(self: List[int]) -> List[int]:
6366
return upstream_shape_functions.unary(self)
6467

@@ -3721,6 +3724,13 @@ def aten〇var_mean〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = T
37213724
return torch.float64, self_dtype
37223725
return self_dtype, self_dtype
37233726

3727+
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
3728+
def aten〇tan〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
3729+
self_rank, self_dtype = self_rank_dtype
3730+
if is_integer_dtype(self_dtype):
3731+
return torch.float32
3732+
return self_dtype
3733+
37243734
@check_dtype_function(_check_two_tensor_op())
37253735
def aten〇atan2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
37263736
self_rank, self_dtype = self_rank_dtype

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def emit_with_mutating_variants(key, **kwargs):
278278
"aten::expm1 : (Tensor) -> (Tensor)",
279279
"aten::cos : (Tensor) -> (Tensor)",
280280
"aten::acos : (Tensor) -> (Tensor)",
281+
"aten::tan : (Tensor) -> (Tensor)",
281282
"aten::atan : (Tensor) -> (Tensor)",
282283
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
283284
"aten::neg : (Tensor) -> (Tensor)",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,6 +3009,46 @@ def ElementwiseAcosIntModule_basic(module, tu: TestUtils):
30093009

30103010
# ==============================================================================
30113011

3012+
class ElementwiseTanModule(torch.nn.Module):
3013+
3014+
def __init__(self):
3015+
super().__init__()
3016+
3017+
@export
3018+
@annotate_args([
3019+
None,
3020+
([-1, -1], torch.float32, True),
3021+
])
3022+
def forward(self, a):
3023+
return torch.tan(a)
3024+
3025+
3026+
@register_test_case(module_factory=lambda: ElementwiseTanModule())
3027+
def ElementwiseTanModule_basic(module, tu: TestUtils):
3028+
module.forward(tu.rand(3, 4))
3029+
3030+
# ==============================================================================
3031+
3032+
class ElementwiseTanIntModule(torch.nn.Module):
3033+
3034+
def __init__(self):
3035+
super().__init__()
3036+
3037+
@export
3038+
@annotate_args([
3039+
None,
3040+
([-1, -1], torch.int32, True),
3041+
])
3042+
def forward(self, a):
3043+
return torch.tan(a)
3044+
3045+
3046+
@register_test_case(module_factory=lambda: ElementwiseTanIntModule())
3047+
def ElementwiseTanIntModule_basic(module, tu: TestUtils):
3048+
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
3049+
3050+
# ==============================================================================
3051+
30123052
class ElementwiseNegModule(torch.nn.Module):
30133053

30143054
def __init__(self):

test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,15 @@ func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[
795795

796796
// -----
797797

798+
// CHECK-LABEL: func.func @test_tan
799+
func.func @test_tan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
800+
// CHECK: %[[TAN:.+]] = torch.aten.tan %arg0
801+
%0 = torch.operator "onnx.Tan"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
802+
return %0 : !torch.vtensor<[3,4,5],f32>
803+
}
804+
805+
// -----
806+
798807
// CHECK-LABEL: func.func @test_transpose_default
799808
func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} {
800809
// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0

0 commit comments

Comments
 (0)