diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index d47e360e9dc13..41cc684808fe1 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -649,7 +649,10 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) { if (!isLastDimUnitStride()) return false; - auto memrefShape = getShape().take_back(n); + if (n == 1) + return true; + + auto memrefShape = getShape().take_back(n - 1); if (ShapedType::isDynamicShape(memrefShape)) return false; @@ -668,7 +671,7 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) { // Check whether strides match "flattened" dims. SmallVector flattenedDims; auto dimProduct = 1; - for (auto dim : llvm::reverse(memrefShape.drop_front(1))) { + for (auto dim : llvm::reverse(memrefShape)) { dimProduct *= dim; flattenedDims.push_back(dimProduct); } diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index e840dc6bbf224..aa922415f2669 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -188,18 +188,20 @@ func.func @transfer_read_leading_dynamic_dims( // ----- -// One of the dims to be flattened is dynamic - not supported ATM. +// One of the dims to be flattened is dynamic and not the leftmost - not +// possible to reason whether the memref is contiguous as the dynamic dimension +// could be one and the corresponding stride could be arbitrary. func.func @negative_transfer_read_dynamic_dim_to_flatten( %idx_1: index, %idx_2: index, - %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> { + %mem: memref<1x4x?x6xi32>) -> vector<1x2x6xi32> { %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 { in_bounds = [true, true, true] - } : memref<1x?x4x6xi32>, vector<1x2x6xi32> + } : memref<1x4x?x6xi32>, vector<1x2x6xi32> return %res : vector<1x2x6xi32> } @@ -212,6 +214,41 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten( // ----- +// One of the dims to be flattened is dynamic and leftmost. + +func.func @transfer_read_dynamic_leftmost_dim_to_flatten( + %idx_1: index, + %idx_2: index, + %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> { + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 { + in_bounds = [true, true, true] + } : memref<1x?x4x6xi32>, vector<1x2x6xi32> + return %res : vector<1x2x6xi32> +} + +// CHECK-LABEL: func.func @transfer_read_dynamic_leftmost_dim_to_flatten +// CHECK-SAME: %[[IDX_1:arg0]]: index +// CHECK-SAME: %[[IDX_2:arg1]]: index +// CHECK-SAME: %[[MEM:arg2]]: memref<1x?x4x6xi32> +// CHECK-NEXT: %[[C0_I32:.+]] = arith.constant 0 : i32 +// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index +// CHECK-NEXT: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}} +// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32> +// CHECK-NEXT: %[[TMP:.+]] = affine.apply #map{{.*}}()[%[[IDX_1]], %[[IDX_2]]] +// CHECK-NEXT: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]] +// CHECK-SAME: [%[[C0]], %[[TMP]]], %[[C0_I32]] +// CHECK-SAME: {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32> +// CHECK-NEXT: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<12xi32> to vector<1x2x6xi32> +// CHECK-NEXT: return %[[RES]] : vector<1x2x6xi32> + +// CHECK-128B-LABEL: func @transfer_read_dynamic_leftmost_dim_to_flatten +// CHECK-128B-NOT: memref.collapse_shape + +// ----- + // The vector to be read represents a _non-contiguous_ slice of the input // memref. @@ -451,26 +488,61 @@ func.func @transfer_write_leading_dynamic_dims( // ----- -// One of the dims to be flattened is dynamic - not supported ATM. +// One of the dims to be flattened is dynamic and not leftmost. -func.func @negative_transfer_write_dynamic_to_flatten( +func.func @negative_transfer_write_dynamic_dim_to_flatten( %idx_1: index, %idx_2: index, %vec : vector<1x2x6xi32>, - %mem: memref<1x?x4x6xi32>) { + %mem: memref<1x4x?x6xi32>) { %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} : - vector<1x2x6xi32>, memref<1x?x4x6xi32> + vector<1x2x6xi32>, memref<1x4x?x6xi32> return } -// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten +// CHECK-LABEL: func.func @negative_transfer_write_dynamic_dim_to_flatten // CHECK-NOT: memref.collapse_shape // CHECK-NOT: vector.shape_cast -// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten +// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_dim_to_flatten +// CHECK-128B-NOT: memref.collapse_shape + +// ----- + +// One of the dims to be flattened is dynamic and leftmost. + +func.func @transfer_write_dynamic_leftmost_dim_to_flatten( + %idx_1: index, + %idx_2: index, + %vec : vector<1x2x6xi32>, + %mem: memref<1x?x4x6xi32>) { + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} : + vector<1x2x6xi32>, memref<1x?x4x6xi32> + return +} + +// CHECK-LABEL: func.func @transfer_write_dynamic_leftmost_dim_to_flatten +// CHECK-SAME: %[[IDX_1:arg0]]: index +// CHECK-SAME: %[[IDX_2:arg1]]: index +// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>, +// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32> +// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index +// CHECK-NEXT: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}} +// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32> +// CHECK-NEXT: %[[TMP:.+]] = affine.apply #map{{.*}}()[%[[IDX_1]], %[[IDX_2]]] +// CHECK-NEXT: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32> +// CHECK-NEXT: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] +// CHECK-SAME: [%[[C0]], %[[TMP]]] +// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32> +// CHECK-NEXT: return + +// CHECK-128B-LABEL: func @transfer_write_dynamic_leftmost_dim_to_flatten // CHECK-128B-NOT: memref.collapse_shape // -----