Skip to content

Commit 2e5d650

Browse files
committed
[linalg] Add handling for leading and trailing size-1 dims in ViewOp
This commit adds to the lowering of `aten.view` handling for the following cases: - `(..., a.size(i))` -> `(..., a.size(i), 1, ..., 1)` - `(..., a.size(i), 1, ..., 1)` -> `(..., a.size(i))` - `(a.size(i), ...)` -> `(1, ..., 1, a.size(i), ...)` - `(1, ..., 1, a.size(i), ...)` -> `(a.size(i), ...)`
1 parent 1c508af commit 2e5d650

File tree

2 files changed

+128
-4
lines changed

2 files changed

+128
-4
lines changed

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
193193
ArrayRef<int64_t> yDims,
194194
SmallVector<int64_t> &xIndices,
195195
SmallVector<int64_t> &yIndices) {
196+
if (xDims.empty() || yDims.empty())
197+
return failure();
198+
196199
auto isValidReduction = [](int64_t expectedReductionProduct,
197200
ArrayRef<int64_t> arrayToReduce) -> bool {
198201
if (llvm::count(arrayToReduce, kUnknownSize) > 0 ||
@@ -262,6 +265,8 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
262265
// all the dimensions in `outputShape`.
263266
static void calculateSingleDynamicSize(MutableArrayRef<int64_t> inputShape,
264267
MutableArrayRef<int64_t> outputShape) {
268+
if (inputShape.empty() || outputShape.empty())
269+
return;
265270
int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize);
266271
int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize);
267272
if (inputDynamicDimCount + outputDynamicDimCount != 1)
@@ -488,12 +493,29 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
488493
outputDim = outputAssociations.back().back() + 1;
489494
}
490495

491-
// Append the associations for the dims matching `aten.size.int`
492-
if (nextUnchangedInput != inputRank &&
493-
nextUnchangedOutput != resultRank) {
496+
// Handle any leading or trailing size-1 dimensions and append the
497+
// associations for the dims matching `aten.size.int`.
498+
if (nextUnchangedInput != inputRank) {
499+
assert(nextUnchangedOutput != resultRank &&
500+
"`nextUnchangedInput` and `nextUnchangedOutput` should equal "
501+
"the respective input and output rank at the same time");
494502
inputAssociations.emplace_back();
495503
outputAssociations.emplace_back();
504+
}
505+
while (inputDim <= nextUnchangedInput && inputDim < inputRank) {
506+
if (inputDim != nextUnchangedInput && inputShape[inputDim] != 1) {
507+
return rewriter.notifyMatchFailure(
508+
op, "unimplemented: only collapsing of static size-1 into "
509+
"unchanged dim supported");
510+
}
496511
inputAssociations.back().push_back(inputDim++);
512+
}
513+
while (outputDim <= nextUnchangedOutput && outputDim < resultRank) {
514+
if (outputDim != nextUnchangedOutput && outputShape[outputDim] != 1) {
515+
return rewriter.notifyMatchFailure(
516+
op, "unimplemented: only expanding of static size-1 out of "
517+
"unchanged dim supported");
518+
}
497519
outputAssociations.back().push_back(outputDim++);
498520
}
499521
}

python/torch_mlir_e2e_test/test_suite/reshape_like.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,108 @@ def forward(self, a):
672672
def ViewNegativeStaticModule_basic(module, tu: TestUtils):
673673
module.forward(tu.rand(1, 128))
674674

675+
class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module):
676+
def __init__(self):
677+
super().__init__()
678+
679+
@export
680+
@annotate_args([
681+
None,
682+
([-1], torch.float32, True),
683+
])
684+
685+
def forward(self, a):
686+
return a.view(a.size(0), 1, 1, 1)
687+
688+
@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule())
689+
def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
690+
module.forward(tu.rand(128))
691+
692+
class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module):
693+
def __init__(self):
694+
super().__init__()
695+
696+
@export
697+
@annotate_args([
698+
None,
699+
([-1, 1, 1, 1], torch.float32, True),
700+
])
701+
702+
def forward(self, a):
703+
return a.view(a.size(0))
704+
705+
@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule())
706+
def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
707+
module.forward(tu.rand(128, 1, 1, 1))
708+
709+
class ViewSizeDimLedByExpandedOnesModule(torch.nn.Module):
710+
def __init__(self):
711+
super().__init__()
712+
713+
@export
714+
@annotate_args([
715+
None,
716+
([-1], torch.float32, True),
717+
])
718+
719+
def forward(self, a):
720+
return a.view(1, 1, 1, a.size(0))
721+
722+
@register_test_case(module_factory=lambda: ViewSizeDimLedByExpandedOnesModule())
723+
def ViewSizeDimLedByExpandedOnesModule_basic(module, tu: TestUtils):
724+
module.forward(tu.rand(128))
725+
726+
class ViewSizeDimLedByCollapsedOnesModule(torch.nn.Module):
727+
def __init__(self):
728+
super().__init__()
729+
730+
@export
731+
@annotate_args([
732+
None,
733+
([1, 1, 1, -1], torch.float32, True),
734+
])
735+
736+
def forward(self, a):
737+
return a.view(a.size(3))
738+
739+
@register_test_case(module_factory=lambda: ViewSizeDimLedByCollapsedOnesModule())
740+
def ViewSizeDimLedByCollapsedOnesModule_basic(module, tu: TestUtils):
741+
module.forward(tu.rand(1, 1, 1, 128))
742+
743+
class ViewSizeDimLedAndFollowedByExpandedOnesModule(torch.nn.Module):
744+
def __init__(self):
745+
super().__init__()
746+
747+
@export
748+
@annotate_args([
749+
None,
750+
([-1], torch.float32, True),
751+
])
752+
753+
def forward(self, a):
754+
return a.view(1, 1, 1, a.size(0), 1, 1, 1)
755+
756+
@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByExpandedOnesModule())
757+
def ViewSizeDimLedAndFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
758+
module.forward(tu.rand(128))
759+
760+
class ViewSizeDimLedAndFollowedByCollapsedOnesModule(torch.nn.Module):
761+
def __init__(self):
762+
super().__init__()
763+
764+
@export
765+
@annotate_args([
766+
None,
767+
([1, 1, 1, -1, 1, 1, 1], torch.float32, True),
768+
])
769+
770+
def forward(self, a):
771+
return a.view(a.size(3))
772+
773+
@register_test_case(module_factory=lambda: ViewSizeDimLedAndFollowedByCollapsedOnesModule())
774+
def ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
775+
module.forward(tu.rand(1, 1, 1, 128, 1, 1, 1))
776+
675777
# ==============================================================================
676778

677779
class ReshapeAliasExpandModule(torch.nn.Module):
@@ -710,4 +812,4 @@ def forward(self, a):
710812

711813
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
712814
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
713-
module.forward(tu.rand(2, 4))
815+
module.forward(tu.rand(2, 4))

0 commit comments

Comments
 (0)