Commit cdba452

supported the _logcumsumexp Op

Signed-off-by: sharavana20 <[email protected]>
1 parent: 3398078

File tree: 6 files changed (+111 / -32 lines)

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 19 additions & 3 deletions
@@ -9364,6 +9364,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.logcumsumexp\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten._logcumsumexp\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -12511,9 +12514,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.logcumsumexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
-"    %int1 = torch.constant.int 1\n"
-"    %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
-"    return %0 : !torch.int\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten._logcumsumexp\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
 "    %int4 = torch.constant.int 4\n"

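The C++ entries above are the generated form of the Python shape/dtype functions added to abstract_interp_lib_gen.py further down in this commit; the library file is regenerated from that Python source rather than edited by hand. As a rough, plain-Python sketch of what the two new aten._logcumsumexp entries encode (the helper below is an illustrative stand-in, not the actual library_generator code):

    import torch
    from typing import List, Tuple

    def is_integer_dtype(dtype) -> bool:
        # Illustrative stand-in for the library_generator.is_integer_dtype helper
        # that the generated torch.prim.If guard above calls.
        return dtype in (torch.bool, torch.uint8, torch.int8,
                         torch.int16, torch.int32, torch.int64)

    def logcumsumexp_shape(self: List[int], dim: int) -> List[int]:
        # Shape rule: the result has the same shape as the input.
        return self

    def logcumsumexp_dtype(self_rank_dtype: Tuple[int, int], dim: int):
        # Dtype rule: keep the input dtype but reject integer dtypes; the assert
        # corresponds to the torch.prim.RaiseException in the generated IR above.
        rank, dtype = self_rank_dtype
        assert not is_integer_dtype(dtype)
        return dtype
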
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 33 additions & 27 deletions
@@ -2884,58 +2884,61 @@ class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
 };
 } // namespace
 
-// Decompose AtenLogCumsumExpOp to:
-// AtenExpOp
-// AtenCumsumOp
-// AtenLogOp
+// Decompose AtenLogCumsumExpOp into: AtenExpOp,
+// AtenCumsumOp and AtenLogOp
+// logcumsumexp(x)[i][j] = log(sum_{k=0}^{j} exp(x[i][k]))
+
 namespace {
+template <typename OpTy>
 
-class DecomposeAtenLogCumsumExpOp
-    : public OpRewritePattern<AtenLogcumsumexpOp> {
+class DecomposeAtenLogCumsumExpOp : public OpRewritePattern<OpTy> {
 public:
-  using OpRewritePattern<AtenLogcumsumexpOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenLogcumsumexpOp op,
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value input = op.getSelf();
 
     auto inputType = dyn_cast<BaseTensorType>(input.getType());
-    if (!inputType || !inputType.getDtype())
+    if (!inputType)
      return rewriter.notifyMatchFailure(op, "Supports only tensor type");
 
     if (!inputType.hasDtype() || !isa<mlir::FloatType>(inputType.getDtype()))
-      return rewriter.notifyMatchFailure(op, "Support only floating type");
+      return rewriter.notifyMatchFailure(
+          op, "Currently Support only floating point type");
+
+    int64_t inputRank = inputType.getSizes().size();
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: Only constant dim value is supported");
+    dim = toPositiveDim(dim, inputRank);
+    if (!isValidDim(dim, inputRank))
+      return rewriter.notifyMatchFailure(op, "invalid dim");
 
     Type elementType = inputType.getDtype();
     torch_upstream::ScalarType scalarType;
-    // logcumsumexp is only supported for Float datatype
+    // Currently it supports for float datatype
     if (elementType.isF16())
       scalarType = torch_upstream::ScalarType::Half;
     else if (elementType.isF32())
       scalarType = torch_upstream::ScalarType::Float;
-    else
+    else if (elementType.isF64())
       scalarType = torch_upstream::ScalarType::Double;
+    else
+      return rewriter.notifyMatchFailure(op, "Unsupported data type");
 
     int64_t scalarVal = static_cast<int64_t>(scalarType);
 
     Value dtypeVal = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getType<Torch::IntType>(), scalarVal);
-
-    int64_t inputRank = inputType.getSizes().size();
-    int64_t dim;
-    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
-      return rewriter.notifyMatchFailure(
-          op, "Only constant dim value is supported");
-    dim = toPositiveDim(dim, inputRank);
-    if (!isValidDim(dim, inputRank))
-      return rewriter.notifyMatchFailure(op, "invalid dim");
+        loc, rewriter.getI64IntegerAttr(scalarVal));
 
-    Value expInput = rewriter.create<AtenExpOp>(loc, input.getType(), input);
+    Value expInput = rewriter.create<AtenExpOp>(loc, op.getType(), input);
 
-    Value cumsum = rewriter.create<AtenCumsumOp>(
-        loc, expInput.getType(), expInput, op.getDim(), dtypeVal);
+    Value cumsum = rewriter.create<AtenCumsumOp>(loc, op.getType(), expInput,
+                                                 op.getDim(), dtypeVal);
 
-    Value result = rewriter.create<AtenLogOp>(loc, cumsum.getType(), cumsum);
+    Value result = rewriter.create<AtenLogOp>(loc, op.getType(), cumsum);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -11988,7 +11991,10 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSoftmaxIntOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLogSigmoidOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenLogCumsumExpOp>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenLogCumsumExpOp<AtenLogcumsumexpOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenLogCumsumExpOp<Aten_LogcumsumexpOp>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenHardshrinkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftshrinkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyLikeOp>(patterns);

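In short, the pattern above rewrites both logcumsumexp variants as exp, then a cumulative sum along dim, then log, per the comment logcumsumexp(x)[i][j] = log(sum_{k=0}^{j} exp(x[i][k])). A minimal eager-mode sketch of the same computation, using only public PyTorch APIs:

    import torch

    def decomposed_logcumsumexp(x: torch.Tensor, dim: int) -> torch.Tensor:
        # exp -> cumulative sum -> log, mirroring AtenExpOp, AtenCumsumOp, AtenLogOp.
        return torch.log(torch.cumsum(torch.exp(x), dim=dim))

    x = torch.randn(4, 5, 6)
    assert torch.allclose(decomposed_logcumsumexp(x, 1),
                          torch.logcumsumexp(x, dim=1), atol=1e-6)

Note that this direct form exponentiates before accumulating, so it can overflow for large inputs where a max-shifted formulation would not; that trade-off comes with the decomposition itself, not with this sketch.
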
lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 1 deletion
@@ -375,7 +375,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<Aten_LogSoftmaxOp>();
   target.addIllegalOp<AtenLogSoftmaxIntOp>();
   target.addIllegalOp<AtenLogSigmoidOp>();
-  target.addIllegalOp<AtenLogcumsumexpOp>();
+  target.addIllegalOp<Aten_LogcumsumexpOp, AtenLogcumsumexpOp>();
   target.addIllegalOp<AtenHardshrinkOp>();
   target.addIllegalOp<AtenSoftshrinkOp>();
   target.addIllegalOp<AtenEmptyLikeOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
@@ -3326,6 +3326,8 @@
     # RuntimeError: Given input size: (1x1x1). Calculated output size: (1x0x0). Output size is too small
     "AvgPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
     "MaxPool2dWithoutPadFullDimIndivisibleByStrideModule_basic",
+    "_LogCumsumExpStaticModule_basic",
+    "_LogCumsumExpStaticNegativeDimModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3873,6 +3875,8 @@
     "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScaledDotProductAttentionSameModule_basic",
     "ScaledDotProductAttentionGQAModule_basic",
+    "_LogCumsumExpStaticModule_basic",
+    "_LogCumsumExpStaticNegativeDimModule_basic",
 }
 
 ONNX_TOSA_CRASHING_SET = {
@@ -4938,6 +4942,8 @@
     "_ConvolutionDeprecated2DDeterministicModule_basic",
     "_LogSoftmaxModule_basic",
     "_SoftmaxModule_basic",
+    "_LogCumsumExpStaticModule_basic",
+    "_LogCumsumExpStaticNegativeDimModule_basic",
 }
 
 if torch_version_for_comparison() > version.parse("2.5.1"):

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 19 additions & 1 deletion
@@ -1540,6 +1540,9 @@ def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = Non
 def aten〇logcumsumexp〡shape(self: List[int], dim: int) -> List[int]:
     return self
 
+def aten〇_logcumsumexp〡shape(self: List[int], dim: int) -> List[int]:
+    return self
+
 def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return self
 
@@ -3223,7 +3226,22 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0))
 def aten〇logcumsumexp〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int:
-    return self_rank_dtype[1]
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(
+        tensor_shapes=[(1, 1)],
+        tensor_device="cpu",
+        dim=0,
+        error_types={*all_integer_dtypes()}
+    )
+)
+def aten〇_logcumsumexp〡dtype(self_rank_dtype: Tuple[int, int], dim: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert not is_integer_dtype(self_dtype)
+    return self_dtype
+
 
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +

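The error_types argument above tells the dtype checker to expect a failure for every integer dtype, which is exactly what the assert enforces. A hedged illustration of exercising the new rule directly (assuming the import path below, derived from the file location, and that dtypes are passed as torch.dtype values, as the check helpers do):

    import torch
    # Assumed import path for illustration only.
    from torch_mlir.jit_ir_importer.build_tools.abstract_interp_lib_gen import (
        aten〇_logcumsumexp〡dtype,
    )

    # Floating-point dtypes pass through unchanged.
    assert aten〇_logcumsumexp〡dtype((3, torch.float32), dim=0) == torch.float32

    # Integer dtypes trip the assert, matching error_types={*all_integer_dtypes()}.
    try:
        aten〇_logcumsumexp〡dtype((3, torch.int64), dim=0)
    except AssertionError:
        pass
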
projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 33 additions & 0 deletions
@@ -5111,6 +5111,39 @@ def LogCumsumExpDtypeModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class _LogCumsumExpStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([4, 5, 6], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten._logcumsumexp(x, dim=1)
+
+
+@register_test_case(module_factory=lambda: _LogCumsumExpStaticModule())
+def _LogCumsumExpStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 5, 6))
+
+
+class _LogCumsumExpStaticNegativeDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([6, 2, 3], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.logcumsumexp(x, dim=-1)
+
+
+@register_test_case(module_factory=lambda: _LogCumsumExpStaticNegativeDimModule())
+def _LogCumsumExpStaticNegativeDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(6, 2, 3))
+
+
+# ==============================================================================
+
+
 class CumprodModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

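The first module exercises the private op with a static positive dim; the second uses the public op with a negative dim (the test inputs above follow the shapes and dtypes declared in the @annotate_args decorators). A quick eager-mode sanity check of the behavior these tests rely on (assuming aten._logcumsumexp matches the public logcumsumexp numerically):

    import torch

    x = torch.rand(4, 5, 6)
    # The private op should agree with the public one.
    assert torch.allclose(torch.ops.aten._logcumsumexp(x, 1),
                          torch.logcumsumexp(x, dim=1))

    y = torch.rand(6, 2, 3)
    # dim=-1 normalizes to the last dimension, as toPositiveDim does in the decomposition.
    assert torch.allclose(torch.logcumsumexp(y, dim=-1),
                          torch.logcumsumexp(y, dim=y.dim() - 1))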