@@ -193,6 +193,9 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
193193 ArrayRef<int64_t > yDims,
194194 SmallVector<int64_t > &xIndices,
195195 SmallVector<int64_t > &yIndices) {
196+ if (xDims.empty () || yDims.empty ())
197+ return failure ();
198+
196199 auto isValidReduction = [](int64_t expectedReductionProduct,
197200 ArrayRef<int64_t > arrayToReduce) -> bool {
198201 if (llvm::count (arrayToReduce, kUnknownSize ) > 0 ||
@@ -255,13 +258,34 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
255258 return success ();
256259 }
257260
261+ // If one of the two dims arrays has size 0 and the other array only
262+ // has dims of size 1, a mapping is created from no dimensions to
263+ // all the dimensions of the other array.
264+ static LogicalResult mapTrailingSizeOneDims (ArrayRef<int64_t > xDims,
265+ ArrayRef<int64_t > yDims,
266+ SmallVector<int64_t > &xIndices,
267+ SmallVector<int64_t > &yIndices) {
268+ SmallVector<int64_t > ignoredIndices;
269+ if (xDims.empty ()) {
270+ return mapAllDimsToSingleDim (ArrayRef<int64_t >({1 }), yDims,
271+ ignoredIndices, yIndices);
272+ } else if (yDims.empty ()) {
273+ return mapAllDimsToSingleDim (xDims, ArrayRef<int64_t >({1 }), xIndices,
274+ ignoredIndices);
275+ } else {
276+ return failure ();
277+ }
278+ }
279+
258280 // Calculates the size of a dynamic dimension if all other dimensions are
259281 // statically known, and rewrites that dynamic dimension with the static size.
260282 //
261283 // Note: this function assumes that all the dimensions in `inputShape` map to
262284 // all the dimensions in `outputShape`.
263285 static void calculateSingleDynamicSize (MutableArrayRef<int64_t > inputShape,
264286 MutableArrayRef<int64_t > outputShape) {
287+ if (inputShape.empty () || outputShape.empty ())
288+ return ;
265289 int64_t inputDynamicDimCount = llvm::count (inputShape, kUnknownSize );
266290 int64_t outputDynamicDimCount = llvm::count (outputShape, kUnknownSize );
267291 if (inputDynamicDimCount + outputDynamicDimCount != 1 )
@@ -420,7 +444,7 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
420444 for (auto [nextUnchangedInput, nextUnchangedOutput] : unchangedDims) {
421445 // Used for ensuring that we don't have an ambiguous expansion
422446 bool assumedDynamicDimNotSplit = false ;
423- while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
447+ while (inputDim < nextUnchangedInput || outputDim < nextUnchangedOutput) {
424448 auto inputShapeSlice =
425449 MutableArrayRef<int64_t >(inputShape)
426450 .slice (inputDim, nextUnchangedInput - inputDim);
@@ -441,9 +465,15 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
441465 " (e.g. [-1, -1] -> [-1, -1, -1])" );
442466 }
443467
444- if (succeeded (mapAllDimsToSingleDim (inputShapeSlice, outputShapeSlice,
445- inputSliceIndices,
446- outputSliceIndices))) {
468+ if (succeeded (mapTrailingSizeOneDims (inputShapeSlice, outputShapeSlice,
469+ inputSliceIndices,
470+ outputSliceIndices))) {
471+ } else if (outputShapeSlice.empty ()) {
472+ inputSliceIndices.assign (
473+ llvm::to_vector (llvm::seq<int64_t >(0 , inputShapeSlice.size ())));
474+ } else if (succeeded (mapAllDimsToSingleDim (
475+ inputShapeSlice, outputShapeSlice, inputSliceIndices,
476+ outputSliceIndices))) {
447477 calculateSingleDynamicSize (inputShapeSlice, outputShapeSlice);
448478 // Update shape to pass the tensor.expand_shape and
449479 // tensor.collapse_shape verifiers. If one of the dimensions of the
@@ -462,7 +492,8 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
462492 // / `mapStaticallyKnownDims` maps the smallest number of
463493 // / input and output dimensions in the slice statically
464494 // / known to have the same number of elements.
465- } else if (inputShapeSlice[0 ] == kUnknownSize ) {
495+ } else if (inputShapeSlice.size () > 0 &&
496+ inputShapeSlice[0 ] == kUnknownSize ) {
466497 // If the input is dynamic, assume it is not split
467498 checkDimEqualHelper (rewriter, loc, inputSize[inputDim],
468499 outputSizeInt[outputDim]);
@@ -478,8 +509,14 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
478509 " in `aten.view`" );
479510 }
480511
481- inputAssociations.emplace_back ();
482- outputAssociations.emplace_back ();
512+ // If one of the slices is empty, this means we are handling
513+ // the case of trailing dimensions, which does not require a
514+ // new reassociation; the trailing dimensions get added to the
515+ // last reassociation created.
516+ if (inputShapeSlice.size () > 0 && outputShapeSlice.size () > 0 ) {
517+ inputAssociations.emplace_back ();
518+ outputAssociations.emplace_back ();
519+ }
483520 for (int64_t inputSliceIndex : inputSliceIndices)
484521 inputAssociations.back ().push_back (inputSliceIndex + inputDim);
485522 for (int64_t outputSliceIndex : outputSliceIndices)
0 commit comments