Skip to content

Commit 4901773

Browse files
authored
add uncovered cases in view lowering (#2524)
removes unnecessary checks from empty strided
1 parent 365655c commit 4901773

File tree

4 files changed

+37
-1
lines changed

4 files changed

+37
-1
lines changed

e2e_testing/xfail_sets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@
154154
'BoolFloatConstantModule_basic',
155155
'BoolIntConstantModule_basic',
156156

157+
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.size
158+
"ViewSizeFromOtherTensor_basic",
159+
157160
# ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function aten.__contains__
158161
'ContainsIntList_False',
159162
'ContainsIntList_True',

lib/Conversion/TorchToLinalg/DataMovement.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,11 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
377377
// collapsed. Note this may technically not always be true.
378378
// TODO: think of a way better way to at least detect when this assumption
379379
// is violated for the cases of dynamic dimensions.
380+
bool inputHasOneDynDim = llvm::count(inputShape, kUnknownSize) == 1;
381+
bool outputHasOneDynDim = llvm::count(outputShape, kUnknownSize) == 1;
382+
bool singleDynDimsAreEqual =
383+
inputHasOneDynDim && outputHasOneDynDim &&
384+
productReduce(inputShape) == productReduce(outputShape);
380385
SmallVector<std::pair<int64_t, int64_t>> unchangedDims;
381386
for (auto [outputDim, outputDimSize] :
382387
llvm::enumerate(outputSizeTorchInt)) {
@@ -385,6 +390,14 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
385390
if (matchPattern(outputDimSize,
386391
m_TorchTensorSizeInt(op.getSelf(), &inputDim))) {
387392
unchangedDims.push_back(std::make_pair(inputDim, outputDim));
393+
} else if (singleDynDimsAreEqual &&
394+
outputShape[outputDim] == kUnknownSize) {
395+
// If the input and output have a single dynamic dimension and the
396+
// product of the other dimensions is the same, then we know that the
397+
// dynamic dimension is unchanged.
398+
inputDim = std::distance(inputShape.begin(),
399+
llvm::find(inputShape, kUnknownSize));
400+
unchangedDims.push_back(std::make_pair(inputDim, outputDim));
388401
}
389402
}
390403
// Mark the end of the input/output shapes

python/torch_mlir_e2e_test/test_suite/reshape_like.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,26 @@ def ViewFlattenAndExpandModule_basic(module, tu: TestUtils):
430430

431431
# ==============================================================================
432432

433+
class ViewSizeFromOtherTensor(torch.nn.Module):
434+
def __init__(self):
435+
super().__init__()
436+
437+
@export
438+
@annotate_args([
439+
None,
440+
([1, -1], torch.float32, True),
441+
([1, -1, 10], torch.float32, True),
442+
])
443+
444+
def forward(self, x, y):
445+
return torch.ops.aten.view(y, (torch.ops.aten.size(x, 1), 10))
446+
447+
@register_test_case(module_factory=lambda: ViewSizeFromOtherTensor())
448+
def ViewSizeFromOtherTensor_basic(module, tu: TestUtils):
449+
module.forward(tu.rand(1, 7), tu.rand(1, 7, 10))
450+
451+
# ==============================================================================
452+
433453
class UnsafeViewExpandModule(torch.nn.Module):
434454
def __init__(self):
435455
super().__init__()

test/Conversion/TorchToLinalg/view.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,4 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -
7373
%0 = torch.prim.ListConstruct %int3, %int2, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
7474
%1 = torch.aten.view %arg0, %0 : !torch.vtensor<[2,6],f32>, !torch.list<int> -> !torch.vtensor<[3,2,2],f32>
7575
return %1 : !torch.vtensor<[3,2,2],f32>
76-
}
76+
}

0 commit comments

Comments
 (0)