Skip to content

Commit 7c6b9d2

Browse files
authored
[linalg] Fix handling of trailing size-1 dimensions in aten.view (#2474)
This commit adds handling for the following cases to the lowering of `aten.view`: `(..., a.size(i))` -> `(..., a.size(i), 1, ..., 1)` and `(..., a.size(i), 1, ..., 1)` -> `(..., a.size(i))`. Fixes: #2448
1 parent e69266a commit 7c6b9d2

File tree

2 files changed

+79
-8
lines changed

2 files changed

+79
-8
lines changed

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
193193
ArrayRef<int64_t> yDims,
194194
SmallVector<int64_t> &xIndices,
195195
SmallVector<int64_t> &yIndices) {
196+
if (xDims.empty() || yDims.empty())
197+
return failure();
198+
196199
auto isValidReduction = [](int64_t expectedReductionProduct,
197200
ArrayRef<int64_t> arrayToReduce) -> bool {
198201
if (llvm::count(arrayToReduce, kUnknownSize) > 0 ||
@@ -255,13 +258,34 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
255258
return success();
256259
}
257260

261+
// If one of the two dims arrays has size 0 and the other array only
262+
// has dims of size 1, a mapping is created from no dimensions to
263+
// all the dimensions of the other array.
264+
static LogicalResult mapTrailingSizeOneDims(ArrayRef<int64_t> xDims,
265+
ArrayRef<int64_t> yDims,
266+
SmallVector<int64_t> &xIndices,
267+
SmallVector<int64_t> &yIndices) {
268+
SmallVector<int64_t> ignoredIndices;
269+
if (xDims.empty()) {
270+
return mapAllDimsToSingleDim(ArrayRef<int64_t>({1}), yDims,
271+
ignoredIndices, yIndices);
272+
} else if (yDims.empty()) {
273+
return mapAllDimsToSingleDim(xDims, ArrayRef<int64_t>({1}), xIndices,
274+
ignoredIndices);
275+
} else {
276+
return failure();
277+
}
278+
}
279+
258280
// Calculates the size of a dynamic dimension if all other dimensions are
259281
// statically known, and rewrites that dynamic dimension with the static size.
260282
//
261283
// Note: this function assumes that all the dimensions in `inputShape` map to
262284
// all the dimensions in `outputShape`.
263285
static void calculateSingleDynamicSize(MutableArrayRef<int64_t> inputShape,
264286
MutableArrayRef<int64_t> outputShape) {
287+
if (inputShape.empty() || outputShape.empty())
288+
return;
265289
int64_t inputDynamicDimCount = llvm::count(inputShape, kUnknownSize);
266290
int64_t outputDynamicDimCount = llvm::count(outputShape, kUnknownSize);
267291
if (inputDynamicDimCount + outputDynamicDimCount != 1)
@@ -420,7 +444,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
420444
for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) {
421445
// Used for ensuring that we don't have an ambiguous expansion
422446
bool assumedDynamicDimNotSplit = false;
423-
while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
447+
while (inputDim < nextUnchangedInput || outputDim < nextUnchangedOutput) {
424448
auto inputShapeSlice =
425449
MutableArrayRef<int64_t>(inputShape)
426450
.slice(inputDim, nextUnchangedInput - inputDim);
@@ -441,9 +465,15 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
441465
"(e.g. [-1, -1] -> [-1, -1, -1])");
442466
}
443467

444-
if (succeeded(mapAllDimsToSingleDim(inputShapeSlice, outputShapeSlice,
445-
inputSliceIndices,
446-
outputSliceIndices))) {
468+
if (succeeded(mapTrailingSizeOneDims(inputShapeSlice, outputShapeSlice,
469+
inputSliceIndices,
470+
outputSliceIndices))) {
471+
} else if (outputShapeSlice.empty()) {
472+
inputSliceIndices.assign(
473+
llvm::to_vector(llvm::seq<int64_t>(0, inputShapeSlice.size())));
474+
} else if (succeeded(mapAllDimsToSingleDim(
475+
inputShapeSlice, outputShapeSlice, inputSliceIndices,
476+
outputSliceIndices))) {
447477
calculateSingleDynamicSize(inputShapeSlice, outputShapeSlice);
448478
// Update shape to pass the tensor.expand_shape and
449479
// tensor.collapse_shape verifiers. If one of the dimensions of the
@@ -462,7 +492,8 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
462492
/// `mapStaticallyKnownDims` maps the smallest number of
463493
/// input and output dimensions in the slice statically
464494
/// known to have the same number of elements.
465-
} else if (inputShapeSlice[0] == kUnknownSize) {
495+
} else if (inputShapeSlice.size() > 0 &&
496+
inputShapeSlice[0] == kUnknownSize) {
466497
// If the input is dynamic, assume it is not split
467498
checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
468499
outputSizeInt[outputDim]);
@@ -478,8 +509,14 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
478509
"in `aten.view`");
479510
}
480511

481-
inputAssociations.emplace_back();
482-
outputAssociations.emplace_back();
512+
// If one of the slices is empty, this means we are handling
513+
// the case of trailing dimensions, which does not require a
514+
// new reassociation; the trailing dimensions get added to the
515+
// last reassociation created.
516+
if (inputShapeSlice.size() > 0 && outputShapeSlice.size() > 0) {
517+
inputAssociations.emplace_back();
518+
outputAssociations.emplace_back();
519+
}
483520
for (int64_t inputSliceIndex : inputSliceIndices)
484521
inputAssociations.back().push_back(inputSliceIndex + inputDim);
485522
for (int64_t outputSliceIndex : outputSliceIndices)

python/torch_mlir_e2e_test/test_suite/reshape_like.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,40 @@ def forward(self, a):
672672
def ViewNegativeStaticModule_basic(module, tu: TestUtils):
673673
module.forward(tu.rand(1, 128))
674674

675+
class ViewSizeDimFollowedByExpandedOnesModule(torch.nn.Module):
676+
def __init__(self):
677+
super().__init__()
678+
679+
@export
680+
@annotate_args([
681+
None,
682+
([-1], torch.float32, True),
683+
])
684+
685+
def forward(self, a):
686+
return a.view(a.size(0), 1, 1, 1)
687+
688+
@register_test_case(module_factory=lambda: ViewSizeDimFollowedByExpandedOnesModule())
689+
def ViewSizeDimFollowedByExpandedOnesModule_basic(module, tu: TestUtils):
690+
module.forward(tu.rand(128))
691+
692+
class ViewSizeDimFollowedByCollapsedOnesModule(torch.nn.Module):
693+
def __init__(self):
694+
super().__init__()
695+
696+
@export
697+
@annotate_args([
698+
None,
699+
([-1, 1, 1, 1], torch.float32, True),
700+
])
701+
702+
def forward(self, a):
703+
return a.view(a.size(0))
704+
705+
@register_test_case(module_factory=lambda: ViewSizeDimFollowedByCollapsedOnesModule())
706+
def ViewSizeDimFollowedByCollapsedOnesModule_basic(module, tu: TestUtils):
707+
module.forward(tu.rand(128, 1, 1, 1))
708+
675709
# ==============================================================================
676710

677711
class ReshapeAliasExpandModule(torch.nn.Module):
@@ -710,4 +744,4 @@ def forward(self, a):
710744

711745
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
712746
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
713-
module.forward(tu.rand(2, 4))
747+
module.forward(tu.rand(2, 4))

0 commit comments

Comments (0)