Commit 3398078: added the code for logcumsumexpOp
1 parent 4e2d0fd
8 files changed: 201 additions, 0 deletions

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Lines changed: 48 additions & 0 deletions

@@ -8625,6 +8625,54 @@ def Torch_AtenCumprodOp : Torch_Op<"aten.cumprod", [
   }];
 }
 
+def Torch_AtenLogcumsumexpOp : Torch_Op<"aten.logcumsumexp", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logcumsumexp : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogcumsumexpOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogcumsumexpOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_Aten_LogcumsumexpOp : Torch_Op<"aten._logcumsumexp", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_logcumsumexp : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_LogcumsumexpOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void Aten_LogcumsumexpOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenFloorDivideScalarOp : Torch_Op<"aten.floor_divide.Scalar", [
     AllowsTypeRefinement,
     HasValueSemantics,
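As a quick sanity check of the registered signature (illustrative only, not part of this commit): the public ATen op takes a tensor and an integer dim and returns a tensor, matching the `aten::logcumsumexp : (Tensor, int) -> (Tensor)` summary above; aten::_logcumsumexp is the internal variant emitted alongside it.

import torch

x = torch.randn(4, 5)
# aten::logcumsumexp : (Tensor, int) -> (Tensor)
y = torch.ops.aten.logcumsumexp(x, 1)
print(y.shape)  # torch.Size([4, 5])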

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Lines changed: 8 additions & 0 deletions

@@ -9361,6 +9361,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
 " return %arg0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.logcumsumexp\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+" return %arg0 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
 " return %arg0 : !torch.list<int>\n"
 " }\n"

@@ -12507,6 +12510,11 @@
 " }\n"
 " return %1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.logcumsumexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+" %int1 = torch.constant.int 1\n"
+" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
+" return %0 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
 " %int4 = torch.constant.int 4\n"
 " %none = torch.constant.none\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Lines changed: 60 additions & 0 deletions

@@ -2884,6 +2884,65 @@ class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
 };
 } // namespace
 
+// Decompose AtenLogcumsumexpOp into AtenExpOp, AtenCumsumOp, and AtenLogOp:
+//   logcumsumexp(x, dim) == log(cumsum(exp(x), dim))
+namespace {
+
+class DecomposeAtenLogCumsumExpOp
+    : public OpRewritePattern<AtenLogcumsumexpOp> {
+public:
+  using OpRewritePattern<AtenLogcumsumexpOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenLogcumsumexpOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+
+    auto inputType = dyn_cast<BaseTensorType>(input.getType());
+    if (!inputType || !inputType.hasDtype() || !inputType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "supports only tensor types with a known dtype and sizes");
+
+    if (!isa<mlir::FloatType>(inputType.getDtype()))
+      return rewriter.notifyMatchFailure(op, "supports only floating-point type");
+
+    Type elementType = inputType.getDtype();
+    torch_upstream::ScalarType scalarType;
+    // logcumsumexp is only supported for Float datatype
+    if (elementType.isF16())
+      scalarType = torch_upstream::ScalarType::Half;
+    else if (elementType.isF32())
+      scalarType = torch_upstream::ScalarType::Float;
+    else
+      scalarType = torch_upstream::ScalarType::Double;
+
+    int64_t scalarVal = static_cast<int64_t>(scalarType);
+
+    Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getType<Torch::IntType>(), scalarVal);
+
+    int64_t inputRank = inputType.getSizes().size();
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(
+          op, "Only constant dim value is supported");
+    dim = toPositiveDim(dim, inputRank);
+    if (!isValidDim(dim, inputRank))
+      return rewriter.notifyMatchFailure(op, "invalid dim");
+
+    Value expInput = rewriter.create<AtenExpOp>(loc, input.getType(), input);
+
+    Value cumsum = rewriter.create<AtenCumsumOp>(
+        loc, expInput.getType(), expInput, op.getDim(), dtypeVal);
+
+    Value result = rewriter.create<AtenLogOp>(loc, cumsum.getType(), cumsum);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenLogSigmoidOp : public OpRewritePattern<AtenLogSigmoidOp> {
 public:

@@ -11929,6 +11988,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenLogCumsumExpOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenHardshrinkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftshrinkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);
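The pattern registered above rewrites aten.logcumsumexp as log(cumsum(exp(x), dim)). A minimal PyTorch sketch of that identity (illustrative, not part of the commit); note that the naive form can overflow exp() for very large inputs, whereas PyTorch's native kernel uses a numerically stable accumulation:

import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 3)
dim = 1

# The decomposition used by the pattern: log(cumsum(exp(x), dim))
decomposed = torch.log(torch.cumsum(torch.exp(x), dim=dim))
reference = torch.logcumsumexp(x, dim=dim)

# For inputs of moderate magnitude the two agree closely.
assert torch.allclose(decomposed, reference, atol=1e-6)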

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Lines changed: 1 addition & 0 deletions

@@ -375,6 +375,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<Aten_LogSoftmaxOp>();
   target.addIllegalOp<AtenLogSoftmaxIntOp>();
   target.addIllegalOp<AtenLogSigmoidOp>();
+  target.addIllegalOp<AtenLogcumsumexpOp>();
   target.addIllegalOp<AtenHardshrinkOp>();
   target.addIllegalOp<AtenSoftshrinkOp>();
   target.addIllegalOp<AtenEmptyLikeOp>();
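Marking aten.logcumsumexp as illegal here is what obliges the decomposition above to fire: LowerToBackendContract keeps re-running the simplification pipeline until the backend contract, which treats the ops listed here as illegal, is satisfied, so an op missing from this list could reach backends undecomposed.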

projects/pt1/e2e_testing/xfail_sets.py
Lines changed: 12 additions & 0 deletions

@@ -2965,6 +2965,10 @@
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
     "LogSoftmaxBackwardModule_basic",
+    "LogCumsumExpModule_basic",
+    "LogCumsumExpStaticModule_basic",
+    "LogCumsumExpStaticNegativeDimModule_basic",
+    "LogCumsumExpDtypeModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",

@@ -3683,6 +3687,10 @@
     "LinalgNormKeepDimComplexModule_basic",
     "LinalgVectorNormComplexModule_basic",
     "LinspaceEmptyModule_basic",
+    "LogCumsumExpModule_basic",
+    "LogCumsumExpStaticModule_basic",
+    "LogCumsumExpStaticNegativeDimModule_basic",
+    "LogCumsumExpDtypeModule_basic",
     "MaskedScatterStaticBasic_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",

@@ -4472,6 +4480,10 @@
     "LinalgVectorNormComplexModule_basic",
     "LogSoftmaxBackwardModule_basic",
     "LogSoftmaxIntModule_basic",
+    "LogCumsumExpModule_basic",
+    "LogCumsumExpStaticModule_basic",
+    "LogCumsumExpStaticNegativeDimModule_basic",
+    "LogCumsumExpDtypeModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
     "MatmulBroadcastBatchDim_basic",
     "MatmulSingleDynamicBatchDim_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
Lines changed: 7 additions & 0 deletions

@@ -1537,6 +1537,9 @@ def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None
 def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
     return self
 
+def aten〇logcumsumexp〡shape(self: List[int], dim: int) -> List[int]:
+    return self
+
 def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return self
 
@@ -3217,6 +3220,10 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
         return torch.int64
     return self_dtype
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
+def aten〇logcumsumexp〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int:
+    return self_rank_dtype[1]
 
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
Lines changed: 2 additions & 0 deletions

@@ -717,6 +717,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
     emit("aten::cumsum : (Tensor, int, int?) -> (Tensor)")
     emit("aten::cumprod : (Tensor, int, int?) -> (Tensor)")
+    emit("aten::logcumsumexp : (Tensor, int) -> (Tensor)")
+    emit("aten::_logcumsumexp : (Tensor, int) -> (Tensor)")
     emit("aten::floor_divide.Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Lines changed: 63 additions & 0 deletions

@@ -5048,6 +5048,69 @@ def CumsumWithDtypeModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class LogCumsumExpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1, -1], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=1)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpModule())
+def LogCumsumExpModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3))
+
+
+class LogCumsumExpStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([1, 2, 3], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=1)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpStaticModule())
+def LogCumsumExpStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 2, 3))
+
+
+class LogCumsumExpStaticNegativeDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([8, 5, 6], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=-2)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpStaticNegativeDimModule())
+def LogCumsumExpStaticNegativeDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(8, 5, 6))
+
+
+class LogCumsumExpDtypeModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([5, 3, 6, 9], torch.float64, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=1)
+
+
+@register_test_case(module_factory=lambda: LogCumsumExpDtypeModule())
+def LogCumsumExpDtypeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, 6, 9).to(torch.float64))
+
+
+# ==============================================================================
+
+
 class CumprodModule(torch.nn.Module):
     def __init__(self):
         super().__init__()