
Commit 8e389ff

Implement lowering of torch.aten.exponential (#2680)
#2646 Decompose aten.exponential() into: -ln(1 - x)/lambda, with x ~ uniform(0, 1)
1 parent d560698 commit 8e389ff
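
For intuition, here is a minimal eager-mode sketch of that formula (illustrative only, not code from this commit; exponential_decomposed is a hypothetical helper):

import torch

# Inverse transform sampling: if U ~ Uniform(0, 1), then
# -ln(1 - U) / lambd is exponentially distributed with rate lambd.
def exponential_decomposed(x: torch.Tensor, lambd: float) -> torch.Tensor:
    u = torch.zeros_like(x).uniform_(0.0, 1.0)
    return -torch.log(1.0 - u) / lambd

samples = exponential_decomposed(torch.empty(1_000_000), 3.0)
print(samples.mean().item())  # ~0.333, i.e. 1/lambda for lambda = 3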

8 files changed: +111 -0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 25 additions & 0 deletions
@@ -4739,6 +4739,31 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [
   }];
 }
 
+def Torch_AtenExponentialOp : Torch_Op<"aten.exponential", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::exponential : (Tensor, float, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$lambd,
+    AnyTorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenExponentialOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenExponentialOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 7 additions & 0 deletions
@@ -7580,6 +7580,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " func.func @\"__torch_mlir_shape_fn.aten.uniform\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
 " return %arg0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.exponential\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list<int> {\n"
+" return %arg0 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
 " return %arg0 : !torch.list<int>\n"
 " }\n"
@@ -9382,6 +9385,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.exponential\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" return %0#1 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
 " %int6 = torch.constant.int 6\n"
 " %none = torch.constant.none\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 46 additions & 0 deletions
@@ -3562,6 +3562,51 @@ class DecomposeAtenBernoulliTensorOp
 };
 } // namespace
 
+namespace {
+// Decompose exponential() to do inverse transform sampling.
+// - https://en.wikipedia.org/wiki/Inverse_transform_sampling
+// With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus,
+// exponential() = - ln(1 - uniform(0, 1)) / lambda.
+class DecomposeAtenExponentialOp : public OpRewritePattern<AtenExponentialOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenExponentialOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getGenerator().getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "The generator has to be None because only global default "
+              "generator is supported");
+
+    Location loc = op.getLoc();
+    Type resultType = op.getType();
+
+    // Create a uniform random op with low and high set to 0.0 and 1.0,
+    // respectively.
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value zero =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
+    Value one =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value emptyTensor = rewriter.create<AtenFullLikeOp>(
+        loc, resultType, op.getSelf(), zero, /*dtype=*/none, /*layout=*/none,
+        /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+    Value x = rewriter.create<AtenUniformOp>(loc, resultType, emptyTensor,
+                                             /*from=*/zero, /*to=*/one,
+                                             /*generator=*/none);
+
+    Value negX = rewriter.create<AtenNegOp>(loc, resultType, x);
+    Value oneMinusX =
+        rewriter.create<AtenAddScalarOp>(loc, resultType, negX, one,
+                                         /*alpha=*/one);
+    Value lnOneMinusX = rewriter.create<AtenLogOp>(loc, resultType, oneMinusX);
+    Value negLambda = rewriter.create<AtenNegFloatOp>(loc, op.getLambd());
+    rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, resultType, lnOneMinusX,
+                                                 negLambda);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename OpTy, typename T1T2Op>
 class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
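
The pattern above lowers aten.exponential into a full_like → uniform → neg → add → log → div op sequence. A rough eager-mode PyTorch equivalent of that emitted sequence, as a sketch for illustration (decomposed_ops is not part of the commit):

import torch

def decomposed_ops(t: torch.Tensor, lambd: float) -> torch.Tensor:
    empty = torch.full_like(t, 0.0)           # AtenFullLikeOp, filled with zero
    x = empty.uniform_(0.0, 1.0)              # AtenUniformOp on [0, 1)
    neg_x = torch.neg(x)                      # AtenNegOp
    one_minus_x = torch.add(neg_x, 1.0)       # AtenAddScalarOp with alpha=1
    ln_one_minus_x = torch.log(one_minus_x)   # AtenLogOp
    return torch.div(ln_one_minus_x, -lambd)  # AtenDivScalarOp by -lambd

print(decomposed_ops(torch.empty(4, 4), 3.0))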
@@ -6410,6 +6455,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<
         DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenExponentialOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -427,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
     target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
     target.addIllegalOp<AtenBernoulliPOp>();
     target.addIllegalOp<AtenBernoulliTensorOp>();
+    target.addIllegalOp<AtenExponentialOp>();
     target.addIllegalOp<AtenZeroOp>();
     target.addIllegalOp<AtenEyeOp>();
     target.addIllegalOp<AtenEyeMOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -1397,6 +1397,7 @@
     "CeilFloatModule_basic",
     "DivFloatModule_basic",
     "EqIntModule_basic",
+    "ExponentialModule_basic",
     "GeFloatIntModule_basic",
     "GeFloatModule_basic",
     "GeIntModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 7 additions & 0 deletions
@@ -831,6 +831,9 @@ def aten〇copy〡shape(self: List[int], src: List[int], non_blocking: bool = Fa
 def aten〇uniform〡shape(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]:
     return self
 
+def aten〇exponential〡shape(self: List[int], lambd: float = 1., generator: Any = None) -> List[int]:
+    return self
+
 def aten〇rand〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return size

@@ -2267,6 +2270,10 @@ def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0.,
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇exponential〡dtype(self_rank_dtype: Tuple[int, int], lambd: float = 1., generator: Any = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function([Invocation([1]),
                        Invocation([1], dtype=torch.float16),
                        Invocation([1], dtype=torch.complex64)])
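
Both added functions encode that aten.exponential preserves the input's shape and dtype, which a quick eager check mirrors (hypothetical snippet, not from the commit):

import torch

x = torch.empty(2, 3, dtype=torch.float64)
y = torch.ops.aten.exponential(x, 3.0)
assert y.shape == x.shape and y.dtype == x.dtype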

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -378,6 +378,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
     emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
     emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
+    emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)")
     emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)")
     emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py

Lines changed: 23 additions & 0 deletions
@@ -157,6 +157,29 @@ def UniformNoCorrelationModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class ExponentialModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, x):
+        a = torch.ops.aten.exponential(x, 3.0)
+        mean = torch.mean(a)
+        std = torch.std(a)
+        return mean, std
+
+
+@register_test_case(module_factory=lambda: ExponentialModule())
+def ExponentialModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(512, 512, 16).double())
+
+# ==============================================================================
+
 class BernoulliModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
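
An Exp(lambda) distribution has mean and standard deviation both equal to 1/lambda, so with lambd = 3.0 the test's mean and std outputs should each approach ~0.333 over the 512x512x16 sample tensor. A quick standalone check under that assumption (illustrative, not part of the test suite):

import torch

a = torch.ops.aten.exponential(torch.empty(512, 512, 16, dtype=torch.float64), 3.0)
print(a.mean().item(), a.std().item())  # both ~0.333 for lambda = 3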
