Skip to content

Commit 365655c

Browse files
authored
[Torch Dialect] add canonicalize pattern for aten.floor with integer … (#2534)
…type
1 parent a2e694d commit 365655c

File tree

5 files changed

+82
-46
lines changed

5 files changed

+82
-46
lines changed

e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,7 @@
957957
"ElementwiseEluModule_basic",
958958
"ElementwiseEluNonDefaultModule_basic",
959959
"ElementwiseFloorModule_basic",
960+
"ElementwiseFloorIntModule_basic",
960961
"ElementwiseLogModule_basic",
961962
"ElementwiseBinaryStaticShapeModule_basic",
962963
"ElementwiseMinimumModule_basic",

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,51 +1023,6 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [
10231023
}];
10241024
}
10251025

1026-
def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
1027-
AllowsTypeRefinement,
1028-
HasValueSemantics,
1029-
ReadOnly
1030-
]> {
1031-
let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`";
1032-
let arguments = (ins
1033-
AnyTorchTensorType:$self
1034-
);
1035-
let results = (outs
1036-
AnyTorchTensorType:$result
1037-
);
1038-
let hasCustomAssemblyFormat = 1;
1039-
let extraClassDefinition = [{
1040-
ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) {
1041-
return parseDefaultTorchOp(parser, result, 1, 1);
1042-
}
1043-
void AtenFloorOp::print(OpAsmPrinter &printer) {
1044-
printDefaultTorchOp(printer, *this, 1, 1);
1045-
}
1046-
}];
1047-
}
1048-
1049-
def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
1050-
IsTrailingUnderscoreInplaceVariant,
1051-
AllowsTypeRefinement
1052-
]> {
1053-
let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`";
1054-
let arguments = (ins
1055-
Torch_NonValueTensorType:$self
1056-
);
1057-
let results = (outs
1058-
Torch_NonValueTensorType:$result
1059-
);
1060-
let hasCustomAssemblyFormat = 1;
1061-
let extraClassDefinition = [{
1062-
ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) {
1063-
return parseDefaultTorchOp(parser, result, 1, 1);
1064-
}
1065-
void AtenFloor_Op::print(OpAsmPrinter &printer) {
1066-
printDefaultTorchOp(printer, *this, 1, 1);
1067-
}
1068-
}];
1069-
}
1070-
10711026
def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [
10721027
AllowsTypeRefinement,
10731028
HasValueSemantics,
@@ -3657,6 +3612,52 @@ def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [
36573612
}];
36583613
}
36593614

3615+
def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
3616+
AllowsTypeRefinement,
3617+
HasValueSemantics,
3618+
ReadOnly
3619+
]> {
3620+
let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`";
3621+
let arguments = (ins
3622+
AnyTorchTensorType:$self
3623+
);
3624+
let results = (outs
3625+
AnyTorchTensorType:$result
3626+
);
3627+
let hasCustomAssemblyFormat = 1;
3628+
let extraClassDefinition = [{
3629+
ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) {
3630+
return parseDefaultTorchOp(parser, result, 1, 1);
3631+
}
3632+
void AtenFloorOp::print(OpAsmPrinter &printer) {
3633+
printDefaultTorchOp(printer, *this, 1, 1);
3634+
}
3635+
}];
3636+
let hasCanonicalizer = 1;
3637+
}
3638+
3639+
def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
3640+
IsTrailingUnderscoreInplaceVariant,
3641+
AllowsTypeRefinement
3642+
]> {
3643+
let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`";
3644+
let arguments = (ins
3645+
Torch_NonValueTensorType:$self
3646+
);
3647+
let results = (outs
3648+
Torch_NonValueTensorType:$result
3649+
);
3650+
let hasCustomAssemblyFormat = 1;
3651+
let extraClassDefinition = [{
3652+
ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) {
3653+
return parseDefaultTorchOp(parser, result, 1, 1);
3654+
}
3655+
void AtenFloor_Op::print(OpAsmPrinter &printer) {
3656+
printDefaultTorchOp(printer, *this, 1, 1);
3657+
}
3658+
}];
3659+
}
3660+
36603661
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
36613662
AllowsTypeRefinement,
36623663
HasValueSemantics,

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,22 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
11171117
});
11181118
}
11191119

1120+
//===----------------------------------------------------------------------===//
1121+
// AtenFloorOp
1122+
//===----------------------------------------------------------------------===//
1123+
void AtenFloorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1124+
MLIRContext *context) {
1125+
patterns.add(+[](AtenFloorOp op, PatternRewriter &rewriter) {
1126+
auto outputTy = op.getType().dyn_cast<ValueTensorType>();
1127+
if (outputTy && outputTy.hasDtype() &&
1128+
outputTy.getDtype().isa<mlir::IntegerType>()) {
1129+
rewriter.replaceOp(op, op.getSelf());
1130+
return success();
1131+
}
1132+
return failure();
1133+
});
1134+
}
1135+
11201136
//===----------------------------------------------------------------------===//
11211137
// AtenMulScalarOp
11221138
//===----------------------------------------------------------------------===//

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,6 @@ def emit_with_mutating_variants(key, **kwargs):
273273
"aten::atan : (Tensor) -> (Tensor)",
274274
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
275275
"aten::neg : (Tensor) -> (Tensor)",
276-
"aten::floor : (Tensor) -> (Tensor)",
277276
"aten::ceil : (Tensor) -> (Tensor)",
278277
"aten::bitwise_not : (Tensor) -> (Tensor)",
279278
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
@@ -333,6 +332,7 @@ def emit_with_mutating_variants(key, **kwargs):
333332
emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
334333
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
335334
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
335+
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True)
336336

337337
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
338338
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")

python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,24 @@ def forward(self, a):
14201420
def ElementwiseFloorModule_basic(module, tu: TestUtils):
14211421
module.forward(tu.rand(3, 4))
14221422

1423+
class ElementwiseFloorIntModule(torch.nn.Module):
1424+
1425+
def __init__(self):
1426+
super().__init__()
1427+
1428+
@export
1429+
@annotate_args([
1430+
None,
1431+
([-1, -1], torch.int32, True),
1432+
])
1433+
def forward(self, a):
1434+
return torch.floor(a)
1435+
1436+
1437+
@register_test_case(module_factory=lambda: ElementwiseFloorIntModule())
1438+
def ElementwiseFloorIntModule_basic(module, tu: TestUtils):
1439+
module.forward(tu.randint(3, 4, low=-10, high=10).to(torch.int32))
1440+
14231441

14241442
# ==============================================================================
14251443

0 commit comments

Comments
 (0)