Description
Hi folks,
This is based on iree-org/iree#21393, originally posted by @egebeysel (thanks!). I've rewritten it using "pure" MLIR (i.e. not requiring IREE).
REPRO
```mlir
func.func @unpack(%arg0: tensor<512x?x8x?xf32>, %arg1: tensor<4096x4096xf32>, %arg2: tensor<512x?x8x?xf32>) -> tensor<4096x4096xf32> {
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  %0 = tensor.empty() : tensor<4096x4096xf32>
  %unpack = linalg.unpack %arg2 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, %c8_vscale] into %0 : tensor<512x?x8x?xf32> -> tensor<4096x4096xf32>
  return %unpack : tensor<4096x4096xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module : !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["linalg.unpack"]} in %module
      : (!transform.any_op) -> !transform.any_op
    %tiled_unpack, %loops:2 = transform.structured.tile_using_for %unpack tile_sizes [8, [8]]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
    transform.yield
  }
}
```
After tiling (`bin/mlir-opt --transform-interpreter unpack.mlir -cse -test-transform-dialect-erase-schedule`):

```mlir
#map = affine_map<(d0)[s0] -> (-d0 + 4096, s0)>
#map1 = affine_map<(d0) -> (d0 floordiv 8)>
#map2 = affine_map<(d0)[s0] -> (d0 floordiv s0)>
#map3 = affine_map<(d0)[s0] -> (d0 mod s0)>
#map4 = affine_map<(d0, d1)[s0] -> ((d0 + d1 - 1) floordiv s0 - d0 floordiv s0 + 1)>
// ...
%1 = scf.for %arg3 = %c0 to %c4096 step %c8 iter_args(%arg4 = %0) -> (tensor<4096x4096xf32>) {
  %2 = scf.for %arg5 = %c0 to %c4096 step %c8_vscale iter_args(%arg6 = %arg4) -> (tensor<4096x4096xf32>) {
    %3 = affine.min #map(%arg5)[%c8_vscale]
    %4 = affine.apply #map1(%arg3)
    %5 = affine.apply #map2(%arg5)[%c8_vscale]
    %7 = affine.apply #map4(%arg5, %3)[%c8_vscale]
    %extracted_slice = tensor.extract_slice %arg2[%4, %5, 0, 0] [1, %7, 8, %c8_vscale] [1, 1, 1, 1] : tensor<512x?x8x?xf32> to tensor<1x?x8x?xf32>
    %9 = tensor.empty(%8) : tensor<8x?xf32>
    %unpack = linalg.unpack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, %c8_vscale] into %9 : tensor<1x?x8x?xf32> -> tensor<8x?xf32>
    // ...
```
ISSUE
This expression in the output is overly complex:

```mlir
%7 = affine.apply #map4(%arg5, %3)[%c8_vscale]
// ...
%extracted_slice = tensor.extract_slice %arg2[%4, %5, 0, 0] [1, %7, 8, %c8_vscale] [1, 1, 1, 1] : tensor<512x?x8x?xf32> to tensor<1x?x8x?xf32>
```
Specifically, `%7` should be trivially `1` (see below). As a result, we get the unpack source shape `1x?x8x?` instead of `1x1x8x?`, which makes `linalg.unpack` much trickier to vectorise and to lower (dynamic sizes require masks).
Put differently, the inner tile size of `linalg.unpack` and the corresponding tile size resulting from tiling match (both are `8 * vscale`), so there's no need for this complex calculation of `%7`.
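For comparison, here is a hand-written sketch (not produced by any existing pass, and only a guess at what the fixed output could look like) of the per-tile IR one would hope to get when the tile size and the scalable inner tile size are both `8 * vscale`. The function `@unpack_tile` and its arguments `%i`, `%j` (standing in for the loop indices `%arg3`/`%arg5`) and `%size` (standing in for the tile size `%3`) are hypothetical. The only substantive difference from the actual output is the static `1` in the slice sizes, which yields a `1x1x8x?` source for `linalg.unpack`:

```mlir
// Hypothetical sketch: one tile of the unpack, assuming tiling knew that the
// N tile size (8 * vscale) equals the scalable inner tile size.
func.func @unpack_tile(%src: tensor<512x?x8x?xf32>, %i: index, %j: index, %size: index) -> tensor<8x?xf32> {
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  // Outer-tile indices, computed as in #map1 / #map2 of the tiled output.
  %outer_i = affine.apply affine_map<(d0) -> (d0 floordiv 8)>(%i)
  %outer_j = affine.apply affine_map<(d0)[s0] -> (d0 floordiv s0)>(%j)[%c8_vscale]
  // The slice spans exactly one outer tile along N, so dim 1 is a static 1.
  %slice = tensor.extract_slice %src[%outer_i, %outer_j, 0, 0] [1, 1, 8, %c8_vscale] [1, 1, 1, 1]
      : tensor<512x?x8x?xf32> to tensor<1x1x8x?xf32>
  %dest = tensor.empty(%size) : tensor<8x?xf32>
  %res = linalg.unpack %slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, %c8_vscale]
      into %dest : tensor<1x1x8x?xf32> -> tensor<8x?xf32>
  return %res : tensor<8x?xf32>
}
```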
WHY SHOULD %7 BE 1?
For ease of interpretation:
- `%3`: tile/step size, or the remaining size (last iteration)
- `%4`: index of the loop over M (tile/step size = `8`)
- `%5`: index of the loop over N (tile/step size = `8 * vscale`)
- `%7`: from how many `%5`s we are reading, i.e. how many outer tiles along N the slice spans. This is only != 1 when the inner tile sizes are not aligned with the tile sizes, which is not the case here (`8 * vscale` for both).
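Specifically, given that:

`%7 = ((%arg5 + %3 - 1) floordiv %c8_vscale - %arg5 floordiv %c8_vscale + 1)`,

and since:

1. `%3 <= 8 * vscale` (`%3` is the `affine.min` of `4096 - %arg5` and `%c8_vscale`) --> `%3 - 1 < 8 * vscale` ("less than or equal" --> "less than"),
2. `(%arg5 + %3 - 1) floordiv %c8_vscale == %arg5 floordiv %c8_vscale` (given 1., together with `%arg5` being a multiple of `%c8_vscale`, as the loop starts at 0 with step `%c8_vscale`, and `%3 >= 1`),
3. `(%arg5 + %3 - 1) floordiv %c8_vscale - %arg5 floordiv %c8_vscale == 0` (given 2.),

we can safely conclude that `%7 == 1`.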
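To illustrate the claim that `%7` only differs from 1 when the tile size and the inner tile size are not aligned, here is the same formula evaluated on two hand-picked, purely illustrative sets of values. Writing $o$ for the offset (`%arg5`), $s$ for the slice size (`%3`) and $t$ for the inner tile size (`%c8_vscale`, fixed to 8 here for simplicity):

```math
\begin{aligned}
n(o, s, t) &= \left\lfloor \frac{o + s - 1}{t} \right\rfloor - \left\lfloor \frac{o}{t} \right\rfloor + 1 \\
\text{aligned } (t = 8,\ o = 16,\ s = 8):\quad n &= \lfloor 23/8 \rfloor - \lfloor 16/8 \rfloor + 1 = 2 - 2 + 1 = 1 \\
\text{misaligned } (t = 8,\ o = 12,\ s = 12):\quad n &= \lfloor 23/8 \rfloor - \lfloor 12/8 \rfloor + 1 = 2 - 1 + 1 = 2
\end{aligned}
```

In the misaligned case the slice covers elements 12 to 23, which straddle outer tiles 1 and 2, so a dynamic `%7` is genuinely needed there; in the aligned case it is not.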
PROPOSED SOLUTION
As pointed out by @egebeysel, this goes back to the tiling interface implementation of the unpack op, which is not capable of inferring that the vector tile size and the scalable inner tile size are aligned.