Commit f460ca2

[Torch] Add support for aten.round.decimals op
* Added decomposition to aten.round
* Added test to projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Parent: 60379d7

8 files changed: +150 −0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 48 additions & 0 deletions
@@ -4562,6 +4562,54 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [
   }];
 }
 
+def Torch_AtenRoundDecimalsOp : Torch_Op<"aten.round.decimals", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::round.decimals : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$decimals
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRoundDecimalsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenRoundDecimalsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
+def Torch_AtenRound_DecimalsOp : Torch_Op<"aten.round_.decimals", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::round_.decimals : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_IntType:$decimals
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRound_DecimalsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenRound_DecimalsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenTruncOp : Torch_Op<"aten.trunc", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 13 additions & 0 deletions
@@ -1992,6 +1992,19 @@ OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// AtenRoundDecimalsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenRoundDecimalsOp::fold(FoldAdaptor adaptor) {
+  auto resultType = dyn_cast<ValueTensorType>(getType());
+  if (resultType && resultType.hasDtype() &&
+      isa<mlir::IntegerType>(resultType.getDtype())) {
+    return getSelf();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // AtenRoundOp
 //===----------------------------------------------------------------------===//
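
The fold relies on a simple fact: an integer-valued tensor is unchanged by rounding, whatever the decimals argument, so the op can fold to its input whenever the result type has an integer dtype. A quick sanity check of that premise in plain PyTorch (a sketch; torch-mlir is not required):

import torch

# Integer-valued data is a fixed point of rounding at any precision,
# which is the property the fold above exploits.
x = torch.tensor([3.0, -7.0, 12.0])
assert torch.equal(torch.round(x, decimals=2), x)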

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 8 additions & 0 deletions
@@ -6750,6 +6750,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.round.decimals\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.glu\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: glu's dim size must be multiply of 2\"\n"
@@ -12896,6 +12900,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.round.decimals\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.glu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 48 additions & 0 deletions
@@ -11804,6 +11804,53 @@ class DecomposeAten_AssertScalarOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenRoundDecimalsOp
+    : public OpRewritePattern<AtenRoundDecimalsOp> {
+public:
+  using OpRewritePattern<AtenRoundDecimalsOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRoundDecimalsOp op,
+                                PatternRewriter &rewriter) const override {
+    // If decimals is non-zero, aten.round.decimals decomposes as follows:
+    //   scale = 10 ** decimals
+    //   return round(x * scale) / scale
+
+    auto loc = op.getLoc();
+    auto input = op.getSelf();
+    auto inputType = cast<BaseTensorType>(input.getType());
+
+    if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: non-floating point dtype");
+    }
+
+    int64_t decimals;
+    if (!matchPattern(op->getOperand(1), m_TorchConstantInt(&decimals))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-constant decimals argument is not supported");
+    }
+
+    Value scale;
+    // Initialize to the input so that decimals == 0 lowers to a plain round.
+    Value newOp = input;
+    if (decimals) {
+      auto scaleVal = pow(10, decimals);
+      scale = rewriter.create<ConstantFloatOp>(
+          loc, rewriter.getF64FloatAttr(scaleVal));
+      newOp = rewriter.create<AtenMulScalarOp>(loc, op.getType(), input, scale);
+    }
+
+    newOp = rewriter.create<AtenRoundOp>(loc, op.getType(), newOp);
+
+    if (decimals) {
+      newOp = rewriter.create<AtenDivScalarOp>(loc, op.getType(), newOp, scale);
+    }
+
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12113,6 +12160,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenConstrainRangeForSizeOp>(
        patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
 
     GreedyRewriteConfig config;
     config.setUseTopDownTraversal(true);
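
The decomposition's arithmetic can be checked numerically against PyTorch's own round.decimals (a sketch using only the public torch API; torch-mlir is not required):

import torch

decimals = 2
scale = 10.0**decimals
x = torch.rand(5, 5) * 6.0 - 3.0  # same value range as the e2e test below

# round(x * scale) / scale with scale = 10 ** decimals -- exactly what the
# rewrite pattern emits as mul.Scalar -> round -> div.Scalar.
decomposed = torch.round(x * scale) / scale
assert torch.allclose(decomposed, torch.round(x, decimals=decimals))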

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
@@ -547,6 +547,7 @@
     "AtenKthvalueModule_basic",
     "AtenPolarDoubleModule_basic",
     "AtenPolarFloatModule_basic",
+    "AtenRoundFloatDecimalsModule_basic",
     "DiagonalWithStaticShapeModule_basic",
     "EinsumStaticDiagonalDimensionModule_basic",
     "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
@@ -3408,6 +3409,7 @@
     "AtenSymConstrainRange_basic",
     "Aten_AssertScalar_basic",
     "AvgPool2dSingleIntTupleParamsIncludePadModule_basic",
+    "AtenRoundFloatDecimalsModule_basic",
     "ScatterAddDynamicModule_basic",
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 8 additions & 0 deletions
@@ -349,6 +349,9 @@ def aten〇relu6〡shape(self: List[int]) -> List[int]:
 def aten〇round〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇round〇decimals〡shape(self: List[int], decimals: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
     if dim < 0:
         dim += len(self)
@@ -3616,6 +3619,11 @@ def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, decimals=0))
+def aten〇round〇decimals〡dtype(self_rank_dtype: Tuple[int, int], decimals: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(100,)], dim=0))
 def aten〇glu〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1) -> int:
     self_rank, self_dtype = self_rank_dtype
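
Both new functions encode the same contract: aten.round.decimals is elementwise, so it preserves the input's shape and dtype. Checked directly (a sketch in plain PyTorch):

import torch

x = torch.rand(3, 4, dtype=torch.float32)
y = torch.round(x, decimals=1)
# Shape and dtype pass through unchanged, matching the unary shape
# function and identity dtype function registered above.
assert y.shape == x.shape and y.dtype == x.dtype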

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
@@ -451,6 +451,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
+    emit_with_mutating_variants(
+        "aten::round.decimals : (Tensor, int) -> (Tensor)", has_folder=True
+    )
     emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::special_expm1 : (Tensor) -> (Tensor)")
     emit_with_mutating_variants(

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 20 additions & 0 deletions
@@ -6503,6 +6503,26 @@ def AtenRoundIntModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(5, 5, low=-10))
 
 
+class AtenRoundFloatDecimalsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.round(x, decimals=2)
+
+
+@register_test_case(module_factory=lambda: AtenRoundFloatDecimalsModule())
+def AtenRoundFloatDecimalsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 5, low=-3.0, high=3.0))
+
+
 # ==============================================================================
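
In isolation, the op the test invokes behaves as follows (a sketch; printed values are approximate float32 results):

import torch

x = torch.tensor([1.2345, -2.7182, 3.14159])
# torch.ops.aten.round with decimals=2 keeps two decimal places.
print(torch.ops.aten.round(x, decimals=2))
# tensor([ 1.2300, -2.7200,  3.1400])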
