Skip to content

Commit 67c84b7

Browse files
committed
Implement lowering of aten::exponential.
1 parent 11cc92d commit 67c84b7

File tree

6 files changed

+109
-0
lines changed

6 files changed

+109
-0
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4739,6 +4739,31 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [
47394739
}];
47404740
}
47414741

4742+
def Torch_AtenExponentialOp : Torch_Op<"aten.exponential", [
4743+
AllowsTypeRefinement,
4744+
HasValueSemantics,
4745+
ReadOnly
4746+
]> {
4747+
let summary = "Generated op for `aten::exponential : (Tensor, float, Generator?) -> (Tensor)`";
4748+
let arguments = (ins
4749+
AnyTorchTensorType:$self,
4750+
Torch_FloatType:$lambd,
4751+
AnyTorchOptionalGeneratorType:$generator
4752+
);
4753+
let results = (outs
4754+
AnyTorchTensorType:$result
4755+
);
4756+
let hasCustomAssemblyFormat = 1;
4757+
let extraClassDefinition = [{
4758+
ParseResult AtenExponentialOp::parse(OpAsmParser &parser, OperationState &result) {
4759+
return parseDefaultTorchOp(parser, result, 3, 1);
4760+
}
4761+
void AtenExponentialOp::print(OpAsmPrinter &printer) {
4762+
printDefaultTorchOp(printer, *this, 3, 1);
4763+
}
4764+
}];
4765+
}
4766+
47424767
def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [
47434768
AllowsTypeRefinement,
47444769
HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7580,6 +7580,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
75807580
" func.func @\"__torch_mlir_shape_fn.aten.uniform\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
75817581
" return %arg0 : !torch.list<int>\n"
75827582
" }\n"
7583+
" func.func @\"__torch_mlir_shape_fn.aten.exponential\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list<int> {\n"
7584+
" return %arg0 : !torch.list<int>\n"
7585+
" }\n"
75837586
" func.func @\"__torch_mlir_shape_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.list<int> {\n"
75847587
" return %arg0 : !torch.list<int>\n"
75857588
" }\n"
@@ -9382,6 +9385,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
93829385
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
93839386
" return %0#1 : !torch.int\n"
93849387
" }\n"
9388+
" func.func @\"__torch_mlir_dtype_fn.aten.exponential\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
9389+
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
9390+
" return %0#1 : !torch.int\n"
9391+
" }\n"
93859392
" func.func @\"__torch_mlir_dtype_fn.aten.rand\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>) -> !torch.int {\n"
93869393
" %int6 = torch.constant.int 6\n"
93879394
" %none = torch.constant.none\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3591,6 +3591,51 @@ class DecomposeAtenBernoulliTensorOp
35913591
};
35923592
} // namespace
35933593

3594+
namespace {
3595+
// In general, a function g(x) produces a random variable of probability density
3596+
// function f(X) if g(x) := inverse(F(x)) where
3597+
// - F(x) is the cumulative distribution function; i.e.
3598+
// F(x) := integrate(g(u), {u, 0, x}) and
3599+
// - x is sampled from the uniform distribution [0, 1) (half-closed interval)
3600+
// With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus, we get
3601+
// exponential(x) = - ln(1 - uniform(0, 1)) / lambda.
3602+
class DecomposeAtenExponentialOp : public OpRewritePattern<AtenExponentialOp> {
3603+
public:
3604+
using OpRewritePattern::OpRewritePattern;
3605+
LogicalResult matchAndRewrite(AtenExponentialOp op,
3606+
PatternRewriter &rewriter) const override {
3607+
if (!op.getGenerator().getType().isa<Torch::NoneType>())
3608+
return rewriter.notifyMatchFailure(
3609+
op, "The generator has to be None because only global default "
3610+
"generator is supported");
3611+
3612+
Location loc = op.getLoc();
3613+
Type resultType = op.getType();
3614+
3615+
// Create a uniform random op with low and high set to 0.0 and 1.0,
3616+
// respectively.
3617+
Value none = rewriter.create<ConstantNoneOp>(loc);
3618+
Value zero =
3619+
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
3620+
Value one =
3621+
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
3622+
Value x = rewriter.create<AtenUniformOp>(loc, resultType, op.getSelf(),
3623+
/*from=*/zero, /*to=*/one,
3624+
/*generator=*/none);
3625+
3626+
Value negX = rewriter.create<AtenNegOp>(loc, resultType, x);
3627+
Value oneMinusX =
3628+
rewriter.create<AtenAddScalarOp>(loc, resultType, negX, one,
3629+
/*alpha=*/one);
3630+
Value lnOneMinusX = rewriter.create<AtenLogOp>(loc, resultType, oneMinusX);
3631+
Value negLambda = rewriter.create<AtenNegFloatOp>(loc, op.getLambd());
3632+
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, resultType, lnOneMinusX,
3633+
negLambda);
3634+
return success();
3635+
}
3636+
};
3637+
} // namespace
3638+
35943639
namespace {
35953640
template <typename OpTy, typename T1T2Op>
35963641
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
@@ -6439,6 +6484,7 @@ class DecomposeComplexOpsPass
64396484
addPatternIfTargetOpIsIllegal<
64406485
DecomposeAtenBernoulliLikeOp<AtenBernoulliPOp>>(patterns);
64416486
addPatternIfTargetOpIsIllegal<DecomposeAtenBernoulliTensorOp>(patterns);
6487+
addPatternIfTargetOpIsIllegal<DecomposeAtenExponentialOp>(patterns);
64426488
addPatternIfTargetOpIsIllegal<DecomposeAtenZeroOp>(patterns);
64436489
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeOp>(patterns);
64446490
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,9 @@ def aten〇copy〡shape(self: List[int], src: List[int], non_blocking: bool = Fa
831831
def aten〇uniform〡shape(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]:
832832
return self
833833

834+
def aten〇exponential〡shape(self: List[int], lambd: float = 1., generator: Any = None) -> List[int]:
835+
return self
836+
834837
def aten〇rand〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
835838
return size
836839

@@ -2267,6 +2270,10 @@ def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0.,
22672270
self_rank, self_dtype = self_rank_dtype
22682271
return self_dtype
22692272

2273+
def aten〇exponential〡dtype(self_rank_dtype: Tuple[int, int], lambd: float = 1., generator: Any = None) -> int:
2274+
self_rank, self_dtype = self_rank_dtype
2275+
return self_dtype
2276+
22702277
@check_dtype_function([Invocation([1]),
22712278
Invocation([1], dtype=torch.float16),
22722279
Invocation([1], dtype=torch.complex64)])

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def emit_with_mutating_variants(key, **kwargs):
378378
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
379379
emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
380380
emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
381+
emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)")
381382
emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)")
382383
emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
383384
emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,29 @@ def UniformNoCorrelationModule_basic(module, tu: TestUtils):
157157

158158
# ==============================================================================
159159

160+
class ExponentialModule(torch.nn.Module):
161+
def __init__(self):
162+
super().__init__()
163+
164+
@export
165+
@annotate_args([
166+
None,
167+
([-1, -1, -1], torch.float64, True),
168+
])
169+
def forward(self, x):
170+
a = torch.ops.aten.exponential(x, 3.0)
171+
mean = torch.mean(a)
172+
std = torch.std(a)
173+
return mean, std
174+
175+
176+
@register_test_case(module_factory=lambda: ExponentialModule())
177+
def ExponentialModule_basic(module, tu: TestUtils):
178+
module.forward(
179+
tu.rand(512, 512, 16).double())
180+
181+
# ==============================================================================
182+
160183
class BernoulliModule(torch.nn.Module):
161184
def __init__(self):
162185
super().__init__()

0 commit comments

Comments
 (0)