-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[MLIR] Determine contiguousness of memrefs with a dynamic dimension #140872
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -188,18 +188,20 @@ func.func @transfer_read_leading_dynamic_dims( | |
|
||
// ----- | ||
|
||
// One of the dims to be flattened is dynamic - not supported ATM. | ||
// One of the dims to be flattened is dynamic and not the leftmost - not | ||
// possible to reason whether the memref is contiguous as the dynamic dimension | ||
// could be one and the corresponding stride could be arbitrary. | ||
Comment on lines
+191
to
+193
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From https://mlir.llvm.org/docs/Dialects/Builtin/#memreftype:
To me that reads as: without an explicit layout, it's an identity (i.e. a contiguous MemRef), no? If my reading is correct then the logic in However, in the context of "flattening", the dynamic dims are significant, yes. So one should check for dynamic dims, but probably somewhere in Vector dialect transforms. |
||
|
||
func.func @negative_transfer_read_dynamic_dim_to_flatten( | ||
%idx_1: index, | ||
%idx_2: index, | ||
%mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> { | ||
%mem: memref<1x4x?x6xi32>) -> vector<1x2x6xi32> { | ||
|
||
%c0 = arith.constant 0 : index | ||
%c0_i32 = arith.constant 0 : i32 | ||
%res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 { | ||
in_bounds = [true, true, true] | ||
} : memref<1x?x4x6xi32>, vector<1x2x6xi32> | ||
} : memref<1x4x?x6xi32>, vector<1x2x6xi32> | ||
return %res : vector<1x2x6xi32> | ||
} | ||
|
||
|
@@ -212,6 +214,41 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten( | |
|
||
// ----- | ||
|
||
// One of the dims to be flattened is dynamic and leftmost. | ||
|
||
func.func @transfer_read_dynamic_leftmost_dim_to_flatten( | ||
%idx_1: index, | ||
%idx_2: index, | ||
%mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> { | ||
|
||
%c0 = arith.constant 0 : index | ||
%c0_i32 = arith.constant 0 : i32 | ||
%res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 { | ||
in_bounds = [true, true, true] | ||
} : memref<1x?x4x6xi32>, vector<1x2x6xi32> | ||
return %res : vector<1x2x6xi32> | ||
} | ||
|
||
// CHECK-LABEL: func.func @transfer_read_dynamic_leftmost_dim_to_flatten | ||
// CHECK-SAME: %[[IDX_1:arg0]]: index | ||
// CHECK-SAME: %[[IDX_2:arg1]]: index | ||
// CHECK-SAME: %[[MEM:arg2]]: memref<1x?x4x6xi32> | ||
// CHECK-NEXT: %[[C0_I32:.+]] = arith.constant 0 : i32 | ||
// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index | ||
// CHECK-NEXT: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}} | ||
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32> | ||
// CHECK-NEXT: %[[TMP:.+]] = affine.apply #map{{.*}}()[%[[IDX_1]], %[[IDX_2]]] | ||
// CHECK-NEXT: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]] | ||
// CHECK-SAME: [%[[C0]], %[[TMP]]], %[[C0_I32]] | ||
// CHECK-SAME: {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32> | ||
// CHECK-NEXT: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<12xi32> to vector<1x2x6xi32> | ||
// CHECK-NEXT: return %[[RES]] : vector<1x2x6xi32> | ||
|
||
// CHECK-128B-LABEL: func @transfer_read_dynamic_leftmost_dim_to_flatten | ||
// CHECK-128B-NOT: memref.collapse_shape | ||
|
||
// ----- | ||
|
||
// The vector to be read represents a _non-contiguous_ slice of the input | ||
// memref. | ||
|
||
|
@@ -451,26 +488,61 @@ func.func @transfer_write_leading_dynamic_dims( | |
|
||
// ----- | ||
|
||
// One of the dims to be flattened is dynamic - not supported ATM. | ||
// One of the dims to be flattened is dynamic and not leftmost. | ||
|
||
func.func @negative_transfer_write_dynamic_to_flatten( | ||
func.func @negative_transfer_write_dynamic_dim_to_flatten( | ||
%idx_1: index, | ||
%idx_2: index, | ||
%vec : vector<1x2x6xi32>, | ||
%mem: memref<1x?x4x6xi32>) { | ||
%mem: memref<1x4x?x6xi32>) { | ||
|
||
%c0 = arith.constant 0 : index | ||
%c0_i32 = arith.constant 0 : i32 | ||
vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} : | ||
vector<1x2x6xi32>, memref<1x?x4x6xi32> | ||
vector<1x2x6xi32>, memref<1x4x?x6xi32> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten | ||
// CHECK-LABEL: func.func @negative_transfer_write_dynamic_dim_to_flatten | ||
// CHECK-NOT: memref.collapse_shape | ||
// CHECK-NOT: vector.shape_cast | ||
|
||
// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten | ||
// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_dim_to_flatten | ||
// CHECK-128B-NOT: memref.collapse_shape | ||
|
||
// ----- | ||
|
||
// One of the dims to be flattened is dynamic and leftmost. | ||
|
||
func.func @transfer_write_dynamic_leftmost_dim_to_flatten( | ||
%idx_1: index, | ||
%idx_2: index, | ||
%vec : vector<1x2x6xi32>, | ||
%mem: memref<1x?x4x6xi32>) { | ||
|
||
%c0 = arith.constant 0 : index | ||
%c0_i32 = arith.constant 0 : i32 | ||
vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} : | ||
vector<1x2x6xi32>, memref<1x?x4x6xi32> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @transfer_write_dynamic_leftmost_dim_to_flatten | ||
// CHECK-SAME: %[[IDX_1:arg0]]: index | ||
// CHECK-SAME: %[[IDX_2:arg1]]: index | ||
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>, | ||
// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32> | ||
// CHECK-NEXT: %[[C0:.+]] = arith.constant 0 : index | ||
// CHECK-NEXT: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}} | ||
// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32> | ||
// CHECK-NEXT: %[[TMP:.+]] = affine.apply #map{{.*}}()[%[[IDX_1]], %[[IDX_2]]] | ||
// CHECK-NEXT: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32> | ||
// CHECK-NEXT: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] | ||
// CHECK-SAME: [%[[C0]], %[[TMP]]] | ||
// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32> | ||
// CHECK-NEXT: return | ||
|
||
// CHECK-128B-LABEL: func @transfer_write_dynamic_leftmost_dim_to_flatten | ||
// CHECK-128B-NOT: memref.collapse_shape | ||
|
||
// ----- | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Could you add comments explaining what makes
n == 1
and the lastn-1
dims special? Alternatively, renamen
to e.g.numTrailingDimsToCheck
(or something else self-documenting).