
Commit 521b58d

Revert "[TorchToLinalg] Implement lowering of torch.aten.rrelu_with_noise and torch.aten.rrelu_with_noise_backward ops (fix) (#3748)"
This reverts commit 54d9e24.
1 parent c1892de · commit 521b58d
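
For context, `aten.rrelu_with_noise` is PyTorch's randomized leaky ReLU, where the extra `noise` tensor records the per-element negative slope that was actually applied so the backward op can reuse it. Per the decomposition removed below, the forward computation is

$$
\operatorname{rrelu}(x) \;=\; \max(0, x) + \min(0, \alpha x),
\qquad
\alpha \sim \mathcal{U}(\mathit{lower}, \mathit{upper}) \ \text{(training)},
\qquad
\alpha = \tfrac{\mathit{lower} + \mathit{upper}}{2} \ \text{(inference)},
$$

with `noise` set to $\alpha$ where $x < 0$ and to $1$ elsewhere during training.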

9 files changed: 0 additions, 568 deletions (four of the nine file diffs are shown below).

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 0 additions & 84 deletions
@@ -309,61 +309,6 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
   }];
 }
 
-def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$noise,
-    AnyTorchScalarType:$lower,
-    AnyTorchScalarType:$upper,
-    Torch_BoolType:$training,
-    AnyTorchOptionalGeneratorType:$generator
-  );
-  let results = (outs
-    AnyTorchOptionalTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 6, 1);
-    }
-    void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 6, 1);
-    }
-  }];
-}
-
-def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [
-    IsTrailingUnderscoreInplaceVariant,
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
-  let arguments = (ins
-    Torch_NonValueTensorType:$self,
-    Torch_NonValueTensorType:$noise,
-    AnyTorchScalarType:$lower,
-    AnyTorchScalarType:$upper,
-    Torch_BoolType:$training,
-    AnyTorchOptionalGeneratorType:$generator
-  );
-  let results = (outs
-    AnyTorchOptionalNonValueTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 6, 1);
-    }
-    void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 6, 1);
-    }
-  }];
-}
-
 def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -17467,35 +17412,6 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
   }];
 }
 
-def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$grad_output,
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$noise,
-    AnyTorchScalarType:$lower,
-    AnyTorchScalarType:$upper,
-    Torch_BoolType:$training,
-    Torch_BoolType:$self_is_result
-  );
-  let results = (outs
-    AnyTorchOptionalTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 7, 1);
-    }
-    void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 7, 1);
-    }
-  }];
-}
-
 def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
     AllowsTypeRefinement,
     HasValueSemantics,
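
All three removed definitions rely on the default parser and printer: the trailing `6, 1` (or `7, 1` for the backward op) passed to `parseDefaultTorchOp` and `printDefaultTorchOp` are the operand and result counts, matching the six-input/one-output (or seven-input/one-output) signature quoted in each op's `summary`.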

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 0 additions & 57 deletions
@@ -6690,10 +6690,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
-"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -7296,10 +7292,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list<int> {\n"
-"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
-"    return %0 : !torch.list<int>\n"
-"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12409,14 +12401,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n"
-"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
-"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
-"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
-"    return %4 : !torch.int\n"
-"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
@@ -12609,47 +12593,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
-"    %none = torch.constant.none\n"
-"    %str = torch.constant.str \"AssertionError: \"\n"
-"    %true = torch.constant.bool true\n"
-"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
-"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
-"      torch.prim.If.yield %true : !torch.bool\n"
-"    } else {\n"
-"      %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
-"      torch.prim.If.yield %7 : !torch.bool\n"
-"    }\n"
-"    torch.prim.If %3 -> () {\n"
-"      torch.prim.If.yield\n"
-"    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-"      torch.prim.If.yield\n"
-"    }\n"
-"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
-"    %5 = torch.prim.If %4 -> (!torch.bool) {\n"
-"      torch.prim.If.yield %true : !torch.bool\n"
-"    } else {\n"
-"      %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
-"      torch.prim.If.yield %7 : !torch.bool\n"
-"    }\n"
-"    torch.prim.If %5 -> () {\n"
-"      torch.prim.If.yield\n"
-"    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-"      torch.prim.If.yield\n"
-"    }\n"
-"    %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
-"    torch.prim.If %6 -> () {\n"
-"      torch.prim.If.yield\n"
-"    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-"      torch.prim.If.yield\n"
-"    }\n"
-"    return %0#1 : !torch.int\n"
-"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 0 additions & 132 deletions
@@ -3675,59 +3675,6 @@ class DecomposeAtenLeakyReluBackwardOp
 };
 } // namespace
 
-namespace {
-class DecomposeAtenRreluWithNoiseBackwardOp
-    : public OpRewritePattern<AtenRreluWithNoiseBackwardOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value gradOutput = op.getGradOutput();
-    Value self = op.getSelf();
-    Value noise = op.getNoise();
-    auto resType = cast<BaseTensorType>(op.getType());
-    if (!resType.hasDtype()) {
-      return rewriter.notifyMatchFailure(op, "result should have dtype");
-    }
-
-    bool training;
-    if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
-      return rewriter.notifyMatchFailure(op,
-                                         "training should be a bool constant");
-    }
-
-    bool selfIsResult = false;
-    if (!matchPattern(op.getSelfIsResult(),
-                      m_TorchConstantBool(&selfIsResult)) ||
-        selfIsResult)
-      return rewriter.notifyMatchFailure(
-          op, "unimplemented: self_is_result should be false");
-
-    double lower, upper;
-    if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) ||
-        !matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) {
-      return rewriter.notifyMatchFailure(
-          op, "lower and upper should be float constants");
-    }
-
-    if (training && (upper - lower > 0.000001)) {
-      Value rreluWithNoiseBackwardOutput =
-          rewriter.create<AtenMulTensorOp>(loc, resType, gradOutput, noise);
-      rewriter.replaceOp(op, rreluWithNoiseBackwardOutput);
-    } else {
-      double negative_slope = (upper + lower) / 2;
-      Value cstNegativeSlope = rewriter.create<ConstantFloatOp>(
-          loc, rewriter.getF64FloatAttr(negative_slope));
-      rewriter.replaceOpWithNewOp<AtenLeakyReluBackwardOp>(
-          op, resType, gradOutput, self, cstNegativeSlope,
-          op.getSelfIsResult());
-    }
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
 public:
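
In words, the pattern deleted above decomposes the backward op into `grad_output * noise` when training over a non-degenerate `[lower, upper]` range, and otherwise into `aten.leaky_relu_backward` with the fixed midpoint slope. A scalar sketch of the same arithmetic (illustrative only; the real pattern emits element-wise tensor ops):

    // Scalar model of the removed decomposition, assuming
    // self_is_result == false, as the pattern requires.
    double rreluWithNoiseBackward(double gradOutput, double self, double noise,
                                  double lower, double upper, bool training) {
      if (training && (upper - lower > 1e-6)) {
        // Training: the forward pass stashed the sampled slopes in `noise`,
        // so the chain rule reduces to grad_output * noise.
        return gradOutput * noise;
      }
      // Eval (or a degenerate range): leaky_relu_backward with the
      // deterministic midpoint slope.
      double negativeSlope = (upper + lower) / 2.0;
      return self > 0 ? gradOutput : gradOutput * negativeSlope;
    }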
@@ -3827,82 +3774,6 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
 };
 } // namespace
 
-namespace {
-class DecomposeAtenRreluWithNoiseOp
-    : public OpRewritePattern<AtenRreluWithNoiseOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value self = op.getSelf();
-    Value noise = op.getNoise();
-    Value lower = op.getLower();
-    Value upper = op.getUpper();
-    auto resType = cast<BaseTensorType>(op.getType());
-    if (!resType.hasDtype()) {
-      return rewriter.notifyMatchFailure(op, "result should have dtype");
-    }
-
-    bool training;
-    if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
-      return rewriter.notifyMatchFailure(op, "training should be a constant");
-    }
-
-    Value constantZeroFloat =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
-    Value constantOneFloat =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
-    Value constantTwoFloat =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
-
-    Value alpha;
-    if (training) {
-      Value none = rewriter.create<ConstantNoneOp>(loc);
-      Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-          loc, resType, self, constantZeroFloat, /*dtype=*/none,
-          /*layout=*/none,
-          /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
-      alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
-                                             /*from=*/lower, /*to=*/upper,
-                                             /*generator=*/none);
-    } else {
-      Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
-                                              lower, upper);
-      alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
-                                         constantTwoFloat);
-    }
-
-    Value zeroTensor =
-        createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
-    Value positiveOutput =
-        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
-
-    Value scaledSelf;
-    if (training) {
-      scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
-      auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(),
-                                                      rewriter.getI1Type());
-      Value oneTensor =
-          createRank0Tensor(rewriter, loc, resType, constantOneFloat);
-      Value not_positive = rewriter.create<AtenLtScalarOp>(
-          loc, boolResType, self, constantZeroFloat);
-      noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
-                                               alpha, oneTensor);
-    } else {
-      scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
-    }
-
-    Value negativeOutput =
-        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
-    Value rreluOutput = rewriter.create<AtenAddTensorOp>(
-        loc, resType, positiveOutput, negativeOutput, constantOneFloat);
-    rewriter.replaceOp(op, rreluOutput);
-    return success();
-  }
-};
-} // namespace
-
 // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
 namespace {
 class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
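
The forward pattern deleted above computes `max(0, self) + min(0, alpha * self)`, with `alpha` a tensor sampled from `uniform(lower, upper)` in training (via `AtenFullLikeOp` plus `AtenUniformOp`) or the scalar midpoint `(lower + upper) / 2` in eval, and rewrites `noise` to hold `alpha` where `self < 0` and 1 elsewhere. A scalar sketch (illustrative; the pattern operates on whole tensors):

    #include <algorithm>
    #include <random>

    // Scalar model of the removed forward decomposition. `noise` is an
    // out-parameter standing in for the AtenWhereSelfOp rewrite.
    double rreluWithNoise(double self, double &noise, double lower,
                          double upper, bool training, std::mt19937 &rng) {
      double alpha;
      if (training) {
        alpha = std::uniform_real_distribution<double>(lower, upper)(rng);
        noise = self < 0 ? alpha : 1.0;  // record the slope actually applied
      } else {
        alpha = (lower + upper) / 2.0;   // deterministic midpoint in eval
      }
      // max(0, x) + min(0, alpha * x): identity for x >= 0, alpha * x below.
      return std::max(0.0, self) + std::min(0.0, alpha * self);
    }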
@@ -11319,9 +11190,6 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
-        patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 0 additions & 2 deletions
@@ -500,8 +500,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
     target.addIllegalOp<AtenPadOp>();
     target.addIllegalOp<AtenPreluOp>();
     target.addIllegalOp<AtenRreluOp>();
-    target.addIllegalOp<AtenRreluWithNoiseOp>();
-    target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
     target.addIllegalOp<AtenCeluOp>();
     target.addIllegalOp<AtenToDtypeLayoutOp>();
     target.addIllegalOp<AtenToDeviceOp>();
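
These two lines pair with the pattern registrations reverted in DecomposeComplexOps.cpp above: the backend-contract pass uses this illegality list to decide whether the IR still needs decomposition, so removing an op here together with its pattern means `aten.rrelu_with_noise` and its backward op are once again treated as legal and reach backend lowering undecomposed.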
