
[mlir] How to simplify this overly complex size calculation after tiling? #150185

@banach-space


Hi folks,

This is based on iree-org/iree#21393, originally posted by @egebeysel (thanks!). I've rewritten it using "pure" MLIR (i.e. not requiring IREE).

REPRO

func.func @unpack(%arg0: tensor<512x?x8x?xf32>, %arg1: tensor<4096x4096xf32>, %arg2: tensor<512x?x8x?xf32>) -> tensor<4096x4096xf32> {
  %c8 = arith.constant 8 : index
  %vscale = vector.vscale
  %c8_vscale = arith.muli %vscale, %c8 : index
  %0 = tensor.empty() : tensor<4096x4096xf32>
  %unpack = linalg.unpack %arg2 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, %c8_vscale] into %0 : tensor<512x?x8x?xf32> -> tensor<4096x4096xf32>
  return %unpack : tensor<4096x4096xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%module : !transform.any_op {transform.readonly}) {
    %unpack = transform.structured.match ops{["linalg.unpack"]} in %module
      : (!transform.any_op) -> !transform.any_op

    %tiled_unpack, %loops:2 = transform.structured.tile_using_for %unpack tile_sizes [8, [8]]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)

    transform.yield
  }
}

After tiling (bin/mlir-opt --transform-interpreter unpack.mlir -cse -test-transform-dialect-erase-schedule):

#map = affine_map<(d0)[s0] -> (-d0 + 4096, s0)>
#map1 = affine_map<(d0) -> (d0 floordiv 8)>
#map2 = affine_map<(d0)[s0] -> (d0 floordiv s0)>
#map3 = affine_map<(d0)[s0] -> (d0 mod s0)>
#map4 = affine_map<(d0, d1)[s0] -> ((d0 + d1 - 1) floordiv s0 - d0 floordiv s0 + 1)>

// ...
    %1 = scf.for %arg3 = %c0 to %c4096 step %c8 iter_args(%arg4 = %0) -> (tensor<4096x4096xf32>) {
      %2 = scf.for %arg5 = %c0 to %c4096 step %c8_vscale iter_args(%arg6 = %arg4) -> (tensor<4096x4096xf32>) {
        %3 = affine.min #map(%arg5)[%c8_vscale]
        %4 = affine.apply #map1(%arg3)
        %5 = affine.apply #map2(%arg5)[%c8_vscale]
        %7 = affine.apply #map4(%arg5, %3)[%c8_vscale]
        %extracted_slice = tensor.extract_slice %arg2[%4, %5, 0, 0] [1, %7, 8, %c8_vscale] [1, 1, 1, 1] : tensor<512x?x8x?xf32> to tensor<1x?x8x?xf32>
        %9 = tensor.empty(%8) : tensor<8x?xf32>
        %unpack = linalg.unpack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, %c8_vscale] into %9 : tensor<1x?x8x?xf32> -> tensor<8x?xf32>
// ...

ISSUE

This expression in the output is overly complex:

%7 = affine.apply #map4(%arg5, %3)[%c8_vscale]
// ...
%extracted_slice = tensor.extract_slice %arg2[%4, %5, 0, 0] [1, %7, 8, %c8_vscale] [1, 1, 1, 1] : tensor<512x?x8x?xf32> to tensor<1x?x8x?xf32>

Specifically, %7 should trivially be 1 (see below). Because it is not folded to a constant, the source of the tiled linalg.unpack has shape 1x?x8x? instead of 1x1x8x?, which makes linalg.unpack much trickier to vectorise and to lower (dynamic sizes require masks).

Put differently, since

  • the inner tile size of linalg.unpack (8 * vscale), and
  • the corresponding tile size resulting from tiling ([8], i.e. 8 * vscale)

match, there's no need for this complex calculation of %7.

WHY SHOULD %7 BE 1?
For ease of interpretation:

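  • %3: the tile size along N, i.e. the step size (8 * vscale), or the remaining size in the last iteration
  • %4: the index of the current iteration of the loop over M (tile/step size = 8)
  • %5: the index of the current iteration of the loop over N (tile/step size = 8 * vscale)
  • %7: the number of source tiles along N (dim 1 of %arg2, starting at index %5) that the slice reads from. This is only != 1 when the inner tile sizes are not aligned with the tile sizes, which is not the case here (both are 8 * vscale).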
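To make these definitions concrete, here is a small Python model of the loop over N from the tiled IR. It is an illustration only: vscale becomes an ordinary integer parameter, and the function name slice_sizes_along_n is made up. For each iteration it evaluates #map, #map2 and #map4 exactly as written above.

```python
# Illustrative model of the N-dimension loop in the tiled IR above.
# In the real IR vscale is a runtime value; here it is a plain parameter.

def slice_sizes_along_n(vscale: int, n: int = 4096):
    """Return (%arg5, %3, %5, %7) for every iteration of the loop over N."""
    c8_vscale = 8 * vscale
    rows = []
    # scf.for %arg5 = %c0 to %c4096 step %c8_vscale
    for arg5 in range(0, n, c8_vscale):
        v3 = min(-arg5 + n, c8_vscale)          # #map:  (-d0 + 4096, s0)
        v5 = arg5 // c8_vscale                  # #map2: d0 floordiv s0
        # #map4: (d0 + d1 - 1) floordiv s0 - d0 floordiv s0 + 1
        # (affine floordiv rounds toward -inf, like Python's //)
        v7 = (arg5 + v3 - 1) // c8_vscale - arg5 // c8_vscale + 1
        rows.append((arg5, v3, v5, v7))
    return rows

for vscale in (1, 2, 3, 16):
    rows = slice_sizes_along_n(vscale)
    # number of iterations, distinct values of %7, and the last iteration
    print(vscale, len(rows), {r[3] for r in rows}, rows[-1])
```

With vscale = 3 the step (24) does not divide 4096, so the last iteration has a remainder of 16, yet %7 is still 1 in every iteration.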

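Specifically, given that %7 = (%arg5 + %3 - 1) floordiv %c8_vscale - %arg5 floordiv %c8_vscale + 1, and since:

  1. %arg5 is a multiple of %c8_vscale (the loop starts at 0 and steps by %c8_vscale), i.e. %arg5 mod %c8_vscale == 0
  2. 1 <= %3 <= 8 * vscale (%3 = min(4096 - %arg5, %c8_vscale) and %arg5 < 4096) --> 0 <= %3 - 1 < 8 * vscale ("less than or equal" --> "less than")
  3. (%arg5 + %3 - 1) floordiv %c8_vscale == %arg5 floordiv %c8_vscale (given 1. and 2.)
  4. (%arg5 + %3 - 1) floordiv %c8_vscale - %arg5 floordiv %c8_vscale == 0 (given 3.)

we can safely conclude that %7 == 1.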
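The argument can be sanity-checked by brute force for small values. In this Python sketch (num_tiles_read and check_aligned are made-up helper names), s plays the role of 8 * vscale, the offset is a multiple of s, and 1 <= size <= s. The last two lines show that without alignment #map4 really can evaluate to 2.

```python
def num_tiles_read(offset: int, size: int, s: int) -> int:
    """#map4: (d0 + d1 - 1) floordiv s0 - d0 floordiv s0 + 1."""
    return (offset + size - 1) // s - offset // s + 1

def check_aligned(max_s: int = 64, max_k: int = 16) -> bool:
    """True if #map4 == 1 for every aligned offset and every 1 <= size <= s."""
    for s in range(1, max_s + 1):          # s ~ 8 * vscale
        for k in range(max_k):             # offset = k * s (step 1. above)
            for size in range(1, s + 1):   # 1 <= %3 <= s (step 2. above)
                if num_tiles_read(k * s, size, s) != 1:
                    return False
    return True

print(check_aligned())            # True
# Misaligned offset: a tile starting at 4 with size 8 straddles two tiles of 8
print(num_tiles_read(4, 8, 8))    # 2
# Tile size larger than the inner tile: 16 elements span two tiles of 8
print(num_tiles_read(0, 16, 8))   # 2
```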

PROPOSED SOLUTION

As pointed out by @egebeysel, this goes back to the tiling interface implementation of linalg.unpack, which is not capable of inferring that the vector tile size and the scalable inner tile size are aligned.
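Below is a hypothetical Python sketch of the kind of alignment check that appears to be missing. It is not MLIR's API and not how MLIR implements tiling: Size and provably_single_inner_tile are made-up names, a tile size is modelled as coeff * vscale (scalable) or coeff (fixed), and it assumes loop offsets are multiples of the tile size, as with the scf.for loops above. When the check succeeds, the corresponding slice size could be emitted as the constant 1.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Size:
    """Hypothetical size model: coeff * vscale if scalable, else coeff."""
    coeff: int
    scalable: bool = False

def provably_single_inner_tile(tile: Size, inner: Size) -> bool:
    """Sufficient (conservative) condition for every loop tile to read from
    exactly one inner tile, for every vscale >= 1.

    Assumes offsets are multiples of `tile` (loop from 0 with step = tile).
    If `inner` is a multiple of `tile`, inner-tile boundaries are also tile
    boundaries, so no tile can straddle two inner tiles.
    """
    if tile.scalable and not inner.scalable:
        # The tile grows with vscale but the inner tile does not.
        return False
    # Both scalable or both fixed: compare coefficients directly.
    # Fixed tile, scalable inner: coeff * vscale is a multiple of tile.coeff
    # for every vscale iff it is for vscale = 1.
    return inner.coeff % tile.coeff == 0

print(provably_single_inner_tile(Size(8), Size(8)))               # M: 8 vs 8 -> True
print(provably_single_inner_tile(Size(8, True), Size(8, True)))   # N: [8] vs 8 * vscale -> True
print(provably_single_inner_tile(Size(16, True), Size(8, True)))  # [16] vs 8 * vscale -> False
```

Returning False only means "cannot prove it", in which case the general #map4 computation would still be needed.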

Labels: mlir, question