Commit 73c56dd
[Torch] Add support for aten.round.decimals op

* Added decomposition to aten.round
* Added test to projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

1 parent c675b2f commit 73c56dd

8 files changed, +150 −0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 48 additions & 0 deletions

@@ -4562,6 +4562,54 @@ def Torch_AtenRound_Op : Torch_Op<"aten.round_", [
   }];
 }
 
+def Torch_AtenRoundDecimalsOp : Torch_Op<"aten.round.decimals", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::round.decimals : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$decimals
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRoundDecimalsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenRoundDecimalsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
+def Torch_AtenRound_DecimalsOp : Torch_Op<"aten.round_.decimals", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::round_.decimals : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_IntType:$decimals
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRound_DecimalsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenRound_DecimalsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenTruncOp : Torch_Op<"aten.trunc", [
     AllowsTypeRefinement,
     HasValueSemantics,
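The generator emits a functional/in-place pair: the value-semantics op (which also gets a folder) models the functional overload, while the trailing-underscore op models the mutating overload. A minimal PyTorch-level sketch of the two overloads these ops model, with illustrative values:

import torch

x = torch.tensor([1.2345, -0.6789])

# Functional overload, modeled by torch.aten.round.decimals (value semantics).
y = torch.round(x, decimals=2)

# In-place overload, modeled by torch.aten.round_.decimals (mutates x).
x.round_(decimals=2)

assert torch.equal(x, y)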

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 13 additions & 0 deletions

@@ -1992,6 +1992,19 @@ OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// AtenRoundDecimalsOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenRoundDecimalsOp::fold(FoldAdaptor adaptor) {
+  auto resultType = dyn_cast<ValueTensorType>(getType());
+  if (resultType && resultType.hasDtype() &&
+      isa<mlir::IntegerType>(resultType.getDtype())) {
+    return getSelf();
+  }
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // AtenRoundOp
 //===----------------------------------------------------------------------===//
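The folder turns aten.round.decimals into a no-op whenever the result is an integer-typed value tensor, since an integer tensor has no fractional digits to round. The same identity in eager PyTorch, sketched here with the base round overload (which is safe on integer tensors):

import torch

x = torch.randint(-10, 10, (5, 5))

# Rounding an integer tensor changes nothing; the fold above encodes this
# by returning the op's input instead of leaving a runtime round behind.
assert torch.equal(torch.round(x), x)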

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 8 additions & 0 deletions

@@ -6750,6 +6750,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.round.decimals\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.glu\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: glu's dim size must be multiply of 2\"\n"
@@ -12963,6 +12967,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.round.decimals\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.glu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 49 additions & 0 deletions

@@ -11885,6 +11885,54 @@ class DecomposeAten_AssertScalarOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenRoundDecimalsOp
+    : public OpRewritePattern<AtenRoundDecimalsOp> {
+public:
+  using OpRewritePattern<AtenRoundDecimalsOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRoundDecimalsOp op,
+                                PatternRewriter &rewriter) const override {
+    // AtenRoundDecimalsOp is decomposed as follows if the decimals value is
+    // non-zero:
+    //   scale = 10 ** decimals
+    //   return round(x * scale) / scale
+    // otherwise:
+    //   return round(x)
+
+    auto loc = op.getLoc();
+    auto input = op.getSelf();
+    auto inputType = cast<BaseTensorType>(input.getType());
+
+    if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype())) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: non-floating point dtype");
+    }
+
+    int64_t decimals;
+    if (!matchPattern(op.getDecimals(), m_TorchConstantInt(&decimals))) {
+      return rewriter.notifyMatchFailure(
+          op, "non-constant decimal point is not supported.");
+    }
+
+    Value newOp = op->getOperand(0);
+    Value scale;
+    if (decimals) {
+      auto scaleVal = pow(10, decimals);
+      scale = rewriter.create<ConstantFloatOp>(
+          loc, rewriter.getF64FloatAttr(scaleVal));
+      newOp = rewriter.create<AtenMulScalarOp>(loc, op.getType(), input, scale);
+    }
+
+    newOp = rewriter.create<AtenRoundOp>(loc, op.getType(), newOp);
+
+    if (decimals) {
+      newOp = rewriter.create<AtenDivScalarOp>(loc, op.getType(), newOp, scale);
+    }
+
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12197,6 +12245,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenConstrainRangeForSizeOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
 
     GreedyRewriteConfig config;
     config.setUseTopDownTraversal(true);
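The pattern reduces the new op to ops backends already support: scale by 10**decimals, round to the nearest integer, then divide the scale back out (skipping the scaling entirely when decimals is zero). A small eager-mode sketch of the same rewrite; round_decimals_decomposed is a hypothetical helper for illustration only:

import torch

def round_decimals_decomposed(x: torch.Tensor, decimals: int) -> torch.Tensor:
    # Mirrors DecomposeAtenRoundDecimalsOp: scale, round, unscale.
    if decimals == 0:
        return torch.round(x)
    scale = 10.0 ** decimals
    return torch.round(x * scale) / scale

x = torch.rand(5, 5) * 6.0 - 3.0  # values in [-3, 3), like the e2e test below
assert torch.allclose(round_decimals_decomposed(x, 2), torch.round(x, decimals=2))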

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions

@@ -384,6 +384,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenHstackOp>();
   target.addIllegalOp<AtenColumnStackOp>();
   target.addIllegalOp<AtenRollOp>();
+  target.addIllegalOp<AtenRoundDecimalsOp>();
   target.addIllegalOp<AtenRepeatOp>();
   target.addIllegalOp<AtenRepeatInterleaveSelfIntOp>();
   target.addIllegalOp<AtenExpandOp>();

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 8 additions & 0 deletions

@@ -349,6 +349,9 @@ def aten〇relu6〡shape(self: List[int]) -> List[int]:
 def aten〇round〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇round〇decimals〡shape(self: List[int], decimals: int) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
     if dim < 0:
         dim += len(self)
@@ -3648,6 +3651,11 @@ def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, decimals=0))
+def aten〇round〇decimals〡dtype(self_rank_dtype: Tuple[int, int], decimals: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(100,)], dim=0))
 def aten〇glu〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1) -> int:
     self_rank, self_dtype = self_rank_dtype
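Both functions declare round.decimals to be elementwise: the shape function reuses the generic unary rule, and the dtype function passes the input dtype through unchanged. This matches eager PyTorch:

import torch

x = torch.rand(3, 4, dtype=torch.float32)
y = torch.round(x, decimals=1)

# Elementwise op: shape and dtype are preserved.
assert y.shape == x.shape and y.dtype == x.dtype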

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions

@@ -451,6 +451,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True)
     emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
+    emit_with_mutating_variants(
+        "aten::round.decimals : (Tensor, int) -> (Tensor)", has_folder=True
+    )
     emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::special_expm1 : (Tensor) -> (Tensor)")
     emit_with_mutating_variants(

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 20 additions & 0 deletions

@@ -6527,6 +6527,26 @@ def AtenRoundIntModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(5, 5, low=-10))
 
 
+class AtenRoundFloatDecimalsModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.round(x, decimals=2)
+
+
+@register_test_case(module_factory=lambda: AtenRoundFloatDecimalsModule())
+def AtenRoundFloatDecimalsModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 5, low=-3.0, high=3.0))
+
+
 # ==============================================================================
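For a concrete sense of what the test exercises, the same overload in eager mode keeps two fractional digits (illustrative values):

import torch

x = torch.tensor([1.2345, -0.6789])
y = torch.ops.aten.round(x, decimals=2)  # the overload the test targets
assert torch.allclose(y, torch.tensor([1.23, -0.68]))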