Skip to content

Commit 4e1dd3b

Browse files
authored
add e2e support for torch.log10 (#2479)
1 parent 8abfa5b commit 4e1dd3b

File tree

6 files changed

+111
-2
lines changed

6 files changed

+111
-2
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2527,6 +2527,51 @@ def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [
25272527
}];
25282528
}
25292529

2530+
def Torch_AtenLog10Op : Torch_Op<"aten.log10", [
2531+
AllowsTypeRefinement,
2532+
HasValueSemantics,
2533+
ReadOnly
2534+
]> {
2535+
let summary = "Generated op for `aten::log10 : (Tensor) -> (Tensor)`";
2536+
let arguments = (ins
2537+
AnyTorchTensorType:$self
2538+
);
2539+
let results = (outs
2540+
AnyTorchTensorType:$result
2541+
);
2542+
let hasCustomAssemblyFormat = 1;
2543+
let extraClassDefinition = [{
2544+
ParseResult AtenLog10Op::parse(OpAsmParser &parser, OperationState &result) {
2545+
return parseDefaultTorchOp(parser, result, 1, 1);
2546+
}
2547+
void AtenLog10Op::print(OpAsmPrinter &printer) {
2548+
printDefaultTorchOp(printer, *this, 1, 1);
2549+
}
2550+
}];
2551+
}
2552+
2553+
def Torch_AtenLog10_Op : Torch_Op<"aten.log10_", [
2554+
IsTrailingUnderscoreInplaceVariant,
2555+
AllowsTypeRefinement
2556+
]> {
2557+
let summary = "Generated op for `aten::log10_ : (Tensor) -> (Tensor)`";
2558+
let arguments = (ins
2559+
AnyTorchTensorType:$self
2560+
);
2561+
let results = (outs
2562+
AnyTorchTensorType:$result
2563+
);
2564+
let hasCustomAssemblyFormat = 1;
2565+
let extraClassDefinition = [{
2566+
ParseResult AtenLog10_Op::parse(OpAsmParser &parser, OperationState &result) {
2567+
return parseDefaultTorchOp(parser, result, 1, 1);
2568+
}
2569+
void AtenLog10_Op::print(OpAsmPrinter &printer) {
2570+
printDefaultTorchOp(printer, *this, 1, 1);
2571+
}
2572+
}];
2573+
}
2574+
25302575
def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [
25312576
AllowsTypeRefinement,
25322577
HasValueSemantics,

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
235235
return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
236236
b, converter, payloadArgs[0], op);
237237
}
238+
if (isa<AtenLog10Op>(op)) {
239+
return createCalculationForMathOpWithDtypeConversion<math::Log10Op>(
240+
b, converter, payloadArgs[0], op);
241+
}
238242
if (isa<AtenLog1pOp>(op)) {
239243
return createCalculationForMathOpWithDtypeConversion<math::Log1pOp>(
240244
b, converter, payloadArgs[0], op);
@@ -1177,7 +1181,7 @@ class ConvertElementwiseOp : public ConversionPattern {
11771181
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
11781182
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
11791183
AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp,
1180-
AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
1184+
AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp,
11811185
AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
11821186
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
11831187
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
@@ -1712,7 +1716,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
17121716
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
17131717
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp,
17141718
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
1715-
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
1719+
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
17161720
AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
17171721
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
17181722
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6322,6 +6322,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
63226322
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
63236323
" return %0 : !torch.list<int>\n"
63246324
" }\n"
6325+
" func.func @\"__torch_mlir_shape_fn.aten.log10\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
6326+
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
6327+
" return %0 : !torch.list<int>\n"
6328+
" }\n"
63256329
" func.func @\"__torch_mlir_shape_fn.aten.log1p\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
63266330
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
63276331
" return %0 : !torch.list<int>\n"
@@ -8291,6 +8295,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
82918295
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
82928296
" return %1 : !torch.int\n"
82938297
" }\n"
8298+
" func.func @\"__torch_mlir_dtype_fn.aten.log10\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
8299+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
8300+
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
8301+
" return %1 : !torch.int\n"
8302+
" }\n"
82948303
" func.func @\"__torch_mlir_dtype_fn.aten.log1p\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
82958304
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
82968305
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ def aten〇detach〡shape(self: List[int]) -> List[int]:
122122
def aten〇log2〡shape(self: List[int]) -> List[int]:
123123
return upstream_shape_functions.unary(self)
124124

125+
def aten〇log10〡shape(self: List[int]) -> List[int]:
126+
return upstream_shape_functions.unary(self)
127+
125128
def aten〇log1p〡shape(self: List[int]) -> List[int]:
126129
return upstream_shape_functions.unary(self)
127130

@@ -1438,6 +1441,11 @@ def aten〇log2〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
14381441
self_rank, self_dtype = self_rank_dtype
14391442
return _get_dtype_of_floating_point_op(self_dtype)
14401443

1444+
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
1445+
def aten〇log10〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
1446+
self_rank, self_dtype = self_rank_dtype
1447+
return _get_dtype_of_floating_point_op(self_dtype)
1448+
14411449
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
14421450
def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
14431451
self_rank, self_dtype = self_rank_dtype

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def emit_with_mutating_variants(key, **kwargs):
294294
"aten::clamp_max : (Tensor, Scalar) -> (Tensor)",
295295
"aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)",
296296
"aten::log2 : (Tensor) -> (Tensor)",
297+
"aten::log10 : (Tensor) -> (Tensor)",
297298
"aten::sqrt : (Tensor) -> (Tensor)",
298299
"aten::log1p : (Tensor) -> (Tensor)",
299300
"aten::rsqrt : (Tensor) -> (Tensor)",

python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,6 +1683,48 @@ def ElementwiseLog2IntModule_basic(module, tu: TestUtils):
16831683
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
16841684

16851685

1686+
# ==============================================================================
1687+
1688+
class ElementwiseLog10Module(torch.nn.Module):
1689+
1690+
def __init__(self):
1691+
super().__init__()
1692+
1693+
@export
1694+
@annotate_args([
1695+
None,
1696+
([-1, -1], torch.float32, True),
1697+
])
1698+
def forward(self, a):
1699+
return torch.log10(a)
1700+
1701+
1702+
@register_test_case(module_factory=lambda: ElementwiseLog10Module())
1703+
def ElementwiseLog10Module_basic(module, tu: TestUtils):
1704+
module.forward(tu.rand(3, 4))
1705+
1706+
1707+
# ==============================================================================
1708+
1709+
class ElementwiseLog10IntModule(torch.nn.Module):
1710+
1711+
def __init__(self):
1712+
super().__init__()
1713+
1714+
@export
1715+
@annotate_args([
1716+
None,
1717+
([-1, -1], torch.int32, True),
1718+
])
1719+
def forward(self, a):
1720+
return torch.log10(a)
1721+
1722+
1723+
@register_test_case(module_factory=lambda: ElementwiseLog10IntModule())
1724+
def ElementwiseLog10IntModule_basic(module, tu: TestUtils):
1725+
module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32))
1726+
1727+
16861728
# ==============================================================================
16871729

16881730

0 commit comments

Comments (0)