@@ -3562,6 +3562,51 @@ class DecomposeAtenBernoulliTensorOp
3562
3562
};
3563
3563
} // namespace
3564
3564
3565
+ namespace {
3566
+ // Decompose exponential() to do inverse transform sampling.
3567
+ // - https://en.wikipedia.org/wiki/Inverse_transform_sampling
3568
+ // With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus,
3569
+ // exponential() = - ln(1 - uniform(0, 1)) / lambda.
3570
+ class DecomposeAtenExponentialOp : public OpRewritePattern <AtenExponentialOp> {
3571
+ public:
3572
+ using OpRewritePattern::OpRewritePattern;
3573
+ LogicalResult matchAndRewrite (AtenExponentialOp op,
3574
+ PatternRewriter &rewriter) const override {
3575
+ if (!op.getGenerator ().getType ().isa <Torch::NoneType>())
3576
+ return rewriter.notifyMatchFailure (
3577
+ op, " The generator has to be None because only global default "
3578
+ " generator is supported" );
3579
+
3580
+ Location loc = op.getLoc ();
3581
+ Type resultType = op.getType ();
3582
+
3583
+ // Create a uniform random op with low and high set to 0.0 and 1.0,
3584
+ // respectively.
3585
+ Value none = rewriter.create <ConstantNoneOp>(loc);
3586
+ Value zero =
3587
+ rewriter.create <ConstantFloatOp>(loc, rewriter.getF64FloatAttr (0.0 ));
3588
+ Value one =
3589
+ rewriter.create <ConstantFloatOp>(loc, rewriter.getF64FloatAttr (1.0 ));
3590
+ Value emptyTensor = rewriter.create <AtenFullLikeOp>(
3591
+ loc, resultType, op.getSelf (), zero, /* dtype=*/ none, /* layout=*/ none,
3592
+ /* device=*/ none, /* pin_memoty=*/ none, /* memory_format=*/ none);
3593
+ Value x = rewriter.create <AtenUniformOp>(loc, resultType, emptyTensor,
3594
+ /* from=*/ zero, /* to=*/ one,
3595
+ /* generator=*/ none);
3596
+
3597
+ Value negX = rewriter.create <AtenNegOp>(loc, resultType, x);
3598
+ Value oneMinusX =
3599
+ rewriter.create <AtenAddScalarOp>(loc, resultType, negX, one,
3600
+ /* alpha=*/ one);
3601
+ Value lnOneMinusX = rewriter.create <AtenLogOp>(loc, resultType, oneMinusX);
3602
+ Value negLambda = rewriter.create <AtenNegFloatOp>(loc, op.getLambd ());
3603
+ rewriter.replaceOpWithNewOp <AtenDivScalarOp>(op, resultType, lnOneMinusX,
3604
+ negLambda);
3605
+ return success ();
3606
+ }
3607
+ };
3608
+ } // namespace
3609
+
3565
3610
namespace {
3566
3611
template <typename OpTy, typename T1T2Op>
3567
3612
class DecomposeAtenAddCLikeOp : public OpRewritePattern <OpTy> {
@@ -6410,6 +6455,7 @@ class DecomposeComplexOpsPass
6410
6455
addPatternIfTargetOpIsIllegal<
6411
6456
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
6412
6457
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
6458
+ addPatternIfTargetOpIsIllegal<DecomposeAtenExponentialOp>(patterns);
6413
6459
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
6414
6460
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
6415
6461
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
0 commit comments