diff --git a/docs/rfcs/XeGPU.md b/docs/rfcs/XeGPU.md
index 16e52b48f..d261bfd41 100644
--- a/docs/rfcs/XeGPU.md
+++ b/docs/rfcs/XeGPU.md
@@ -2,7 +2,9 @@
## Summary
The XeGPU dialect provides an abstraction that closely models Xe instructions to support high-performance GEMM code generation.
-The matrix instructions at this level exactly match the hardware instructions’ semantics including the matrix sizes.
XeGPU operations are designed to support tile-based programming. The same set of operations works at multiple levels: workgroup, subgroup, and work item.
A workgroup-level operation can be decomposed and unrolled into multiple subgroup-level XeGPU operations, which can be further decomposed to the work-item level.
Along the way, the tensor is partitioned into smaller tiles, and the subgroup- and work-item-level XeGPU operations exactly match the hardware instructions’ semantics, including the matrix sizes.
The lowering and optimizations built on top of the XeGPU dialect are target-specific.

## Proposal
@@ -12,7 +14,7 @@ XeGPU operations are introduced when there is a special Xe instruction not model
load and store. In some cases, one XeGPU op may lower to a sequence of instructions for a dedicated and performance-critical
function. For example, create_tdesc is mapped to a fixed sequence of instructions to create an address description.

-Below is a summary.
Below is a summary of the operation definitions. The definitions are general and work at the workgroup, subgroup, or work item level.

| Ops | Syntax | Example |
| :--- | :---- | :--- |
@@ -34,10 +36,6 @@ Below is a summary.
|nbarrier_wait | operation ::= xegpu.nbarrier_wait $nbarrier : type($nbarrier) | xegpu.nbarrier_wait %nbarrier : !xegpu.nbarrier |
|fence | operation ::= xegpu.fence attr-dict | xegpu.fence {scope = gpu, memory_kind = global} |

-The XeGPU dialect supports lowering from [XeTile dialects]{./XeTile.md}. The tile-based XeTile operation can be further decomposed to
-multiple XeGPU ops. For example, XeTile.load_tile operation is lowered to XeGPU’s load_nd or load operations. Compared with the
-XeTile dialect, the XeGPU dialect works with even smaller matrix sizes, since XeGPU operations map to one hardware instruction in most cases.
-
XeGPU supports two flavors of load/store operations: n-dimension load (nd load) and scattered load. Both need a tensor descriptor to
describe the addresses/offsets to a data block. The descriptor is used for load/store/prefetch, and then updated for reuse with the
next data block. Nd_load can be used to map to 1D load, 2D load, or nd load. Scattered load requires a special tensor descriptor, which
@@ -61,9 +59,6 @@ innermost dimension (base_stride[0]) must be 1.
    into tensor_desc<8x16xbf16>
```

-XeGPU op is carried out by all the work items within a subgroup. The `sg_map` attribute specifies the mapping of each work item to the
-data fragments and will be introduced in the next section in details. XeGPU operation without `sg_map` attribute works on the vectors as a whole.
-
`create_nd_tdesc` can also accept an optional `block_tdesc_attr` to extend its capability. The `block_tdesc_attr` could encode the following
optional attributes:
- `memory_space`. It describes where the data block being described is located. `global` means device memory, or `slm` means shared local memory.
@@ -303,7 +298,6 @@ In case that certain Xe GPU target does not support atomic operation for a certa
  xegpu.alloc_nbarrier %total_nbarrier_num: i8
```
-
`init_nbarrier` returns one named barrier with the specified barrier ID to the current thread. Multiple threads may bind to the same
named barrier, and the input specifies the number of total participant threads. The returned nbarrier object holds a description of the specified barrier,
which encodes all the barrier information.
@@ -327,11 +321,9 @@ which encodes all the barrier information.
Attribute `scope` describes the scope of fence. "workgroup" means that the scope is within each work group. "gpu" means the scope is across work groups within the gpu.
Attribute `Memory_kind` describes the memory kind. "global" means the global memory, "shared" means the shared local memory.

-`nbarrier` and `fence` operations lower to uniform instructions, so there is no need to specify the `sg_map`.
-
## XeGPU Attributes to support Work Item Level semantics

-**Attribute xegpu.sg_map**
**Attribute xegpu.sg_map (To be deprecated)**

xegpu.sg_map specifies how a 2D tensor (defined by the tensor descriptor) is partitioned among work items (WIs) within a subgroup. sg_map consists of two parameters:
* wi_layout: Defines the 2D arrangement of WIs within the subgroup.
* wi_data: Specifies the shape of the tensor fragment that each WI owns.
@@ -488,7 +480,7 @@ The load with chunk_size pack the low-precision data to 32-bit data using wi_dat
User must use legal sg_map value for the WI data distribution for certain operations on PVC and ARC. It includes load_nd/store_nd, load/store with chunk_size, and DPAS.

-## Rules of sg_map setting for load and store on PVC and ARC
**Rules of sg_map setting for load and store on PVC and ARC**
The WI data distribution requires the following sg_map for the 2D block load and store to work with DPAS on PVC. Not using the sg_map value defined here leads to undefined behavior.
```mlir
# assert (wi_layout[0] x wi_layout[1] == subgroup_size) // PVC subgroup_size = 16
@@ -643,9 +635,7 @@ users must use for the WI data distribution of 1D block load and regular load wi
  #sg_map_t = xegpu.sg_map // for 8-bit data element like uint8, sint8
```
-
-
-## sg_map use case - 2D load
**sg_map use case - 2D load**

An example on how to load a 2D block, perform dpas, and store back to memory.
@@ -678,7 +668,7 @@ An example on how to load a 2D block, perform dpas, and store back to memory.

```

-## sg_map use case - regular load:
**sg_map use case - regular load**
An example on how to perform transpose using load with chunk_size in SIMT flavor.

```mlir
@@ -701,6 +691,404 @@ An example on how to perform transpose using load with chunk_size in SIMT flavor
```

## Layout attributes to support workgroup level semantics

Allowing XeGPU to operate on workgroup-level data sizes gives a tensor compiler a concise IR, instead of a multi-level nested-loop IR for the subgroup- and work-item-level operations. To enable XeGPU to operate at the workgroup level, we introduce the `layout` attribute, which specifies how data is distributed across subgroups. `layout` lets a tensor compiler express cooperative operations among subgroups by describing how data is partitioned among them, without manipulating a nested IR representation. The attribute also lets the tensor compiler control the tile sizes for both the workgroup and the subgroup and perform autotuning, since the number of subgroups, the layout, and the tile sizes are critical performance knobs.

The `layout` attribute also includes parameters that support the subgroup-to-work-item distribution, and it supports tensors with more than two dimensions.
**xegpu.layout**

`layout` specifies how an n-d tensor (defined by the tensor descriptor) is partitioned among subgroups and work items within a workgroup. `layout` consists of six parameters:
 * sg_layout: Defines the n-d arrangement of subgroups within the workgroup.
 * sg_data: Specifies the shape of the tensor for each subgroup after decomposition.
 * inst_data: Specifies the shape of the tensor for each instruction at the subgroup level. It may be identical to, or a fraction of, sg_data.
 * lane_layout: Defines the n-d arrangement of work items within the subgroup. It is also known as sg_map's `wi_layout`.
 * lane_data: Specifies the shape of the tensor fragment that each work item owns. The lane_data must be contiguous. Each lane may own multiple lane_data fragments within one instruction. It is also known as sg_map's `wi_data`.
 * order: The dimension order used to linearize n-d subgroup ids and lane ids. The first dimension in the order list is the fastest-changing dimension.

Example of linearized subgroup ids for order [1, 0] vs. order [0, 1]:
| sg_layout[4, 4] | order[1, 0] | order[0, 1] |
| :---- | :---- | :---- |
| [0, 0], [0, 1], [0, 2], [0, 3] | 0, 1, 2, 3 | 0, 4, 8, 12 |
| [1, 0], [1, 1], [1, 2], [1, 3] | 4, 5, 6, 7 | 1, 5, 9, 13 |
| [2, 0], [2, 1], [2, 2], [2, 3] | 8, 9, 10, 11 | 2, 6, 10, 14 |
| [3, 0], [3, 1], [3, 2], [3, 3] | 12, 13, 14, 15 | 3, 7, 11, 15 |

For a subgroup in a 3-d sg_layout [dim_0, dim_1, dim_2], order[2, 1, 0] maps a subgroup with 3-d index [x, y, z] to the linear subgroup index [z + dim_2 * y + dim_2 * dim_1 * x], and order[1, 2, 0] maps to [y + dim_1 * z + dim_1 * dim_2 * x]. The same order applies to lane ids with the same formula.

Users may specify all of these parameters and expect XeGPU to lower them mechanically and gradually to the XeVM dialect. After subgroup distribution, `sg_layout` and `sg_data` are dropped. After work item distribution, `lane_layout`, `lane_data`, and `order` are dropped.

Users may also specify only the `sg_layout`, `sg_data`, and `order` parameters, and use XeGPU passes to automatically fill in the remaining parameters before lowering.

**distribution rule**

The workgroup-level tile, referred to as wg_data, is initially partitioned into sg_data at the subgroup level. It is then further blocked into inst_data to align with the instruction data size, followed by a final distribution into lane_data fragments at the work item level.

**Constraints**

Given these definitions:
```mlir
lane_data_size = lane_data[0] × lane_data[1]
sg_size = lane_layout[0] × lane_layout[1]
sg_tile_size = sg_data[0] × sg_data[1]
wg_size = sg_layout[0] × sg_layout[1]
wg_tile_size = wg_data[0] × wg_data[1]
lane_distribution_unit_size = lane_layout x lane_data
sg_distribution_unit_size = sg_layout x sg_data
```

The following conditions must hold:
```mlir
* sg_size must represent the number of work items (lanes) in a subgroup for a kernel.
* wg_size must represent the number of subgroups in a workgroup for a kernel. The user may explicitly specify the number of subgroups and their id ranges.
* For any dimension i, wg_data[i] must be either evenly divisible by sg_distribution_unit_size[i] (i.e., wg_data[i] % (sg_layout[i] x sg_data[i]) == 0), or equal to sg_data[i].
* For any dimension i, sg_data[i] must be evenly divisible by inst_data[i].
* For any dimension i, inst_data[i] must be evenly divisible by lane_distribution_unit_size[i] (i.e., inst_data[i] % (lane_layout[i] x lane_data[i]) == 0).
* When lane_data contains multiple elements, they must be contiguous and come from a single dimension.
```
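To make the six parameters and the constraints above concrete, the sketch below spells out one fully specified layout for a hypothetical [256, 128] f16 workgroup tile on a 16-lane subgroup. The tile shape and all parameter values are illustrative assumptions, not values required by the dialect.

```mlir
// Hypothetical, fully specified layout for a [256, 128] workgroup tile (illustrative only).
#layout_wg = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16],
                           lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>

// Checking the constraints for this choice:
//   sg_size = 1 x 16 = 16 lanes, i.e. the subgroup size
//   wg_size = 8 x 4  = 32 subgroups in the workgroup
//   wg_data[0] = 256 is divisible by sg_layout[0] x sg_data[0] = 8 x 32 = 256
//   wg_data[1] = 128 is divisible by sg_layout[1] x sg_data[1] = 4 x 32 = 128
//   sg_data  [32, 32] is divisible by inst_data [8, 16] in every dimension
//   inst_data [8, 16] is divisible by lane_layout x lane_data = [1, 16]
//   lane_data = [1, 1] holds a single element, so the contiguity rule holds trivially
```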
***workgroup to subgroup distribution rule***

The tensor is first distributed into sg_data tiles along each dimension in a round-robin fashion. If sg_data[i] x sg_layout[i] < wg_data[i], then after all subgroups have been assigned a tile in the first round, the remaining data wraps around and is assigned again starting from the first subgroup, until the data is completely assigned. If sg_data[i] is equal to wg_data[i], the tensor data is broadcast to all subgroups along dimension i.
Below is the pseudo code to compute the offsets of the subgroup tiles for a given subgroup.
```mlir
  // Each subgroup may receive multiple distributions of the sg_data tile.
  // sg_dist[i] represents the distribution count along the ith dimension.
  // When the wg_data size matches sg_data, the data is broadcast to all subgroups along that dimension.

  sg_dist[i] = wg_data[i]==sg_data[i] ? 1 : wg_data[i] / (sg_layout[i] * sg_data[i])

  offset[i] = wg_data[i]==sg_data[i] ? 0 : sg_id[i] * sg_data[i] + sg_layout[i] * sg_data[i] * [0:sg_dist[i])
```

***blocking rule***

Within each subgroup tile, the data may be decomposed into even smaller tiles. The offsets computed from the following rules are relative to the
subgroup offsets.
```mlir
  inst_dist[i] = sg_data[i] / inst_data[i]

  offset[i] = inst_data[i] * [0:inst_dist[i])
```

***subgroup to work item distribution rule***

Since inst_data is evenly divisible by lane_distribution_unit_size, each work item receives the distribution unit multiple times, with each unit containing lane_data_size elements.

```mlir
  lane_dist[i] = inst_data[i] / (lane_layout[i] * lane_data[i])

  offset[i] = lane_id[i] * lane_data[i] + lane_layout[i] * lane_data[i] * [0:lane_dist[i])
```

The distribution process can be broken down into two steps. The first step divides the sg_data tensor according to lane_layout to obtain a subtensor. The second step linearizes the elements as a 1D tensor. Each lane gets multiple distributions of data fragments, and these fragments are packed into 1D. The data within lane_data is packed first, followed by the packing of data fragments in row-major order.
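As a worked illustration of the rules above, the pseudo code below applies the workgroup-to-subgroup formula to an assumed wg_data = [128, 128] with sg_layout = [2, 2] and sg_data = [32, 128]. These are the same shapes used in the second example of the next subsection, so the resulting assignment can be cross-checked against the table there.

```mlir
  // Assumed shapes: wg_data = [128, 128], sg_layout = [2, 2], sg_data = [32, 128]
  // dimension 0: wg_data[0] != sg_data[0], so the data is distributed round-robin
  //   sg_dist[0] = 128 / (2 * 32) = 2                  // each subgroup receives two tiles
  //   offset[0]  = sg_id[0] * 32 + 2 * 32 * [0:2)      // e.g. sg_id[0] = 1 owns rows 32:63 and 96:127
  // dimension 1: wg_data[1] == sg_data[1] == 128, so the data is broadcast
  //   sg_dist[1] = 1
  //   offset[1]  = 0                                   // every subgroup sees columns 0:127
```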
**xegpu.slice**

`slice` builds upon the xegpu.layout attribute to define the data distribution following a reduction operation. It consists of a base LayoutAttr, which maintains the same rank as the tensor and ensures an even distribution of the data across all subgroups and work items. In contrast, `slice` modifies this distribution by focusing on the reduced tensor, which is distributed among a subset of subgroups and work items. For the remaining subgroups and work items that are not part of the reduction, `slice` ensures that the data is broadcast accordingly. `slice` is frequently employed in operations like vector.multi_reduction and vector.broadcast.

**Examples of workgroup distribution with xegpu.layout**

The workgroup creates a tensor descriptor [64, 16] and distributes it to 32 subgroups with `sg_layout` [4, 8]; each subgroup gets `sg_data` [16, 16]. The first dimension is split and distributed to 4 subgroups, and the second dimension is broadcast to 8 subgroups. The data is then further distributed to the work item level, where each work item holds a vector of 16 bf16 values.

```mlir
  #layout_a = #xegpu.layout
  %wg_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xbf16> -> tensor_desc<64x16xbf16, #layout_a>
  %wg_vec = load_nd %wg_tdesc : tensor_desc<64x16xbf16, #layout_a> -> vector<64x16xbf16>
// after subgroup distribution
// #layout_a_sg = #xegpu.layout
// %sg_vec_sg = load_nd %sg_tdesc : tensor_desc<16x16xbf16, #layout_a_sg> -> vector<16x16xbf16>
// after work item distribution
// %sg_vec_wi_0 = load_nd %sg_tdesc : tensor_desc<8x16xbf16> -> vector<8xbf16>
// %sg_vec_wi_1 = load_nd %sg_tdesc : tensor_desc<8x16xbf16> -> vector<8xbf16>
// %sg_vec_wi = vector.shuffle %sg_vec_wi_0, %sg_vec_wi_1, [0..15] : vector<16xbf16>
```

The example below shows a workgroup tensor [128, 128] being distributed to 4 subgroups with `sg_layout` [2, 2], with each subgroup assigned `sg_data` [32, 128].
```mlir
  #layout_a = #xegpu.layout
  %wg_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> tensor_desc<128x128xf16, #layout_a>
```
The table below illustrates the result tensor for each subgroup and its linear subgroup id.

| subgroup tensor | 2D subgroup id | Linearized subgroup id |
| :--- | :---- | :---- |
| [ 0:31, 0:127] | [0, 0], [0, 1] | 0, 1 |
| [ 32:63, 0:127] | [1, 0], [1, 1] | 2, 3 |
| [ 64:95, 0:127] | [0, 0], [0, 1] | 0, 1 |
| [ 96:127, 0:127] | [1, 0], [1, 1] | 2, 3 |

The `layout` attribute propagates from the matrix multiplication ops to other ops. Since the `layout` attribute cannot be attached to the MLIR vector data type, we attach the attribute to vector type-based operations. The `layout` attribute propagation can be performed from output to input, or in the other direction. We describe below the propagation rules from output to input for typical operations including dpas, reduction, broadcast, shape_cast, and transpose.

For `dpas`, the `layout` attributes of the input operands must have the same `sg_layout` as the output, the same `sg_data` as the output for the m and n dimensions, and the same `sg_data` for the k dimension between operands A and B. `order` must be the same as the output.
```mlir
  #layout_d = #xegpu.layout
  %vector_d = xegpu.dpas %vector_a, %vector_b, %vector_c {#layout_d}:
     vector<256x32xbf16>, vector<32x256xbf16>, vector<256x256xfloat>
     into vector<256x256xfloat>
  //derived layout for input operands
  #layout_a = #xegpu.layout //layout for %vector_a
  #layout_b = #xegpu.layout //layout for %vector_b
  #layout_c = #xegpu.layout //layout for %vector_c
```

For `reduction`, `xegpu.slice` is introduced to represent the `layout` of the reduced tensor. It inherits a regular `layout` and specifies the dimension being reduced.
It conceptually squeezes the threads along that dimension, and all threads share the same mapping as their representative thread.

```mlir
  #layout_a = #xegpu.layout
  #layout_a_reduce = #xegpu.slice<#layout_a, 1>
  %vector_a = vector.multi_reduction %vector_b, %cst_0 [1] {#layout_a_reduce}: vector<256x128xfloat> into vector<256xfloat>

  //derived layout for input operand
  #layout_b = #xegpu.layout
```

The rule also applies to reduction from 3d to 2d.
```mlir
  #layout_a = #xegpu.layout
  #layout_a_reduce = #xegpu.slice<#layout_a, 1>
  %vector_a = vector.multi_reduction %vector_b, %cst_0 [1] {#layout_a_reduce}: vector<8x32x128xf32> to vector<8x128xf32>

  //derived layout for input operand
  #layout_b = #xegpu.layout
```
For `shape_cast`, the propagation first determines the dimensions being reduced or expanded. The input's `layout` then reduces or expands its values accordingly for the related dimensions in `sg_layout` and `sg_data`. `order` should be consistent between input and output.
```mlir
  #layout_a = #xegpu.layout
  %vector_a = vector.shape_cast %vector_b {#layout_a} : vector<256x128xf32> to vector<8x32x128xf32>

  //derived layout for input operand
  #layout_b = #xegpu.layout
```
```mlir
  #layout_a = #xegpu.layout
  %vector_a = vector.shape_cast %vector_b {#layout_a} : vector<256x128xf32> to vector<256x1x128xf32>

  //derived layout for input operand
  #layout_b = #xegpu.layout
  #layout_b_reduce = #xegpu.slice<#layout_b, 1>
```

For `broadcast`, the `layout` of the input operand has one less dimension, since the broadcast dimension is not present on the input. `sg_layout` for that dimension must be `1` in the output layout and must be removed for the input operand. The corresponding dimension in `sg_data` and `order` must be removed as well.

```mlir
  #layout_a = #xegpu.layout
  #layout_a_reduce = #xegpu.slice<#layout_a, 1>
  %vector_a = vector.broadcast %vector_b [1] {#layout_a_reduce}: vector<256xfloat> into vector<256x256xfloat>

  //derived layout for input operand
  #layout_b = #xegpu.layout
```

For `transpose`, the values in `layout` must be swapped for the two dimensions being transposed, including `sg_layout`, `sg_data`, and `order`.
```mlir
  #layout_a = #xegpu.layout
  %vector_a = vector.transpose %vector_b {#layout_a}: vector<512x128xfloat> into vector<128x512xfloat>

  //derived layout for input operand
  #layout_b = #xegpu.layout
```

A `layout` may be assigned to certain operations before the workgroup layout propagation. For example, a cooperative-load pass may specify the `layout` for certain loads that are to be cooperated on. In this case, the propagation may insert an operation to express the conversion from one `layout` to the other.

`convert_layout` is introduced to convert between two inconsistent `layout` attributes.

```mlir
  #layout_b = #xegpu.layout // used for cooperative load/prefetch
  #layout_a = #xegpu.layout // used as mma's input matrix A
  %vector_a = xegpu.convert_layout %vector_b {#layout_a #layout_b}: vector<256x256xfloat> into vector<256x256xfloat>
```
The `layout` conversion can be lowered to a store to and a load from shared local memory. It can be conceptually viewed as a composition of two operations: 1) store the vector to shared local memory with #layout_b, and 2) load the data back from shared local memory using the #layout_a mapping.
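The sketch below illustrates that conceptual lowering for an assumed vector<256x256xf32> value. The SLM buffer `%slm`, the element type, and the exact op spellings are placeholders; the real sequence is produced by the XeGPU lowering passes, and the descriptors would additionally carry the `slm` memory_space attribute described earlier.

```mlir
  // Conceptual sketch only: convert_layout realized through shared local memory (SLM).
  // Step 1: store the vector to an SLM buffer using the source mapping #layout_b.
  %slm_b = create_nd_tdesc %slm[%c0, %c0] : memref<256x256xf32> -> tensor_desc<256x256xf32, #layout_b>
  store_nd %slm_b, %vector_b : (tensor_desc<256x256xf32, #layout_b>, vector<256x256xf32>)
  // Make the SLM writes visible to the whole workgroup before re-reading.
  xegpu.fence {scope = workgroup, memory_kind = shared}
  gpu.barrier
  // Step 2: reload the same buffer with the destination mapping #layout_a.
  %slm_a = create_nd_tdesc %slm[%c0, %c0] : memref<256x256xf32> -> tensor_desc<256x256xf32, #layout_a>
  %vector_a = load_nd %slm_a : tensor_desc<256x256xf32, #layout_a> -> vector<256x256xf32>
```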
## Appendix 1 - Code examples for workgroup level XeGPU using the layout attribute

## Appendix 1.1 Simple gemm with prefetch
The first example shows a simple gemm. It demonstrates the different layouts used for prefetch and load. The layouts omit lane_layout and lane_data for simplicity.

```mlir
Pseudo code for simple gemm
C[4096, 4096] = matmul (A[4096, 4096], B[4096, 4096])
```

```mlir
#mp_a = #layout
#mp_a_pfh = #layout
#mp_b = #layout
#mp_b_pfh = #layout
#mp_c = #layout

func.func @test_gemm(%a : memref<4096x4096xf16>,
                     %b : memref<4096x4096xf16>,
                     %c : memref<4096x4096xf32>) {
  scf.for %i = %c0 to %c4096 step %c256 {
    scf.for %j = %c0 to %c4096 step %c256 {
      %1 = create_nd_tdesc %a[%i, %c0] : memref<4096x4096xf16> -> tensor_desc<256x32xf16, #mp_a>        // sg_layout=[8,4], sg_data=[32,32]
      %2 = create_nd_tdesc %b[%c0, %j] : memref<4096x4096xf16> -> tensor_desc<32x256xf16, #mp_b>        // sg_layout=[8,4], sg_data=[32,64]
      %1p = create_nd_tdesc %a[%i, %c96] : memref<4096x4096xf16> -> tensor_desc<256x32xf16, #mp_a_pfh>  // sg_layout=[32,1]
      %2p = create_nd_tdesc %b[%c96, %j] : memref<4096x4096xf16> -> tensor_desc<32x256xf16, #mp_b_pfh>  // sg_layout=[4,8]

      %3 = create_nd_tdesc %c[%i, %j] : memref<4096x4096xf32> -> tensor_desc<256x256xf32, #mp_c>        // sg_layout=[32, 1]

      scf.for %k = %c0 to %c4096 step %c32 {
        %4 = load_nd %1 : tensor_desc<256x32xf16, #mp_a> -> vector<256x32xf16>      // sg_layout=[8,4], sg_data=[32,32]
        %10 = load_nd %2 : tensor_desc<32x256xf16, #mp_b> -> vector<32x256xf16>     // sg_layout=[8,4], sg_data=[32,64]

        prefetch_nd %1p : tensor_desc<256x32xf16, #mp_a_pfh>                        // sg_layout=[32,1]
        prefetch_nd %2p : tensor_desc<32x256xf16, #mp_b_pfh>                        // sg_layout=[4,8]
        %6 = dpas %4, %10 {#mp_a #mp_b #mp_c} : (vector<256x32xf16>, vector<32x256xf16>) -> vector<256x256xf32>  // sg_layout=[8,4]
        %1 = update_nd_offset %1, %c0, %c32 : tensor_desc<256x32xf16, #mp_a>
        %2 = update_nd_offset %2, %c32, %c0 : tensor_desc<32x256xf16, #mp_b>
        %1p = update_nd_offset %1p, %c0, %c32 : tensor_desc<256x32xf16, #mp_a_pfh>
        %2p = update_nd_offset %2p, %c32, %c0 : tensor_desc<32x256xf16, #mp_b_pfh>
      }
      store_nd %3, %6 : (tensor_desc<256x256xf32, #mp_c>, vector<256x256xf32>)      // sg_layout=[8, 4]
    }
  }
}
```
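For reference, the sketch below shows roughly what the workgroup-level A-tile load in this example decomposes into for one subgroup after the workgroup-to-subgroup distribution (sg_layout = [8,4], sg_data = [32,32]). The subgroup-local offsets %sg_m/%sg_k and the reduced layout #mp_a_sg are assumed names for illustration only, not part of the example above.

```mlir
  // Illustrative subgroup-level view of %1/%4 from the gemm above (one subgroup shown).
  // %sg_m and %sg_k would be derived from the linearized subgroup id and #mp_a's sg_layout/order.
  %1_sg = create_nd_tdesc %a[%sg_m, %sg_k] : memref<4096x4096xf16> -> tensor_desc<32x32xf16, #mp_a_sg>
  %4_sg = load_nd %1_sg : tensor_desc<32x32xf16, #mp_a_sg> -> vector<32x32xf16>
  // #mp_a_sg retains only the instruction- and lane-level parameters; sg_layout and sg_data
  // are dropped once the subgroup distribution has been applied.
```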
## Appendix 1.2 Gemm with transpose, broadcast, and reduction
The second example contains transpose, broadcast, and reduction. The layouts omit lane_layout and lane_data for simplicity.
```mlir
Pseudo code for the original problem
C[4096, 4096] = matmul (A[4096, 4096], BT[4096, 4096]) + broadcast(bcast[4096], dim=0)
Reduce[4096] = reduce_add(C[4096, 4096], dim=1)
```

```mlir
#mp_a = #layout
#mp_a_pfh = #layout
#mp_b = #layout
#mp_bt = #layout
#mp_bt_pfh = #layout
#mp_c = #layout

#mp_bcast = #layout
#mp_bcast2 = #layout
#mp_reduce = #layout
#mp_reduce2 = #layout

func.func @test_gemm(%a : memref<4096x4096xf16>,
                     %bt : memref<4096x4096xf16>,
                     %bcast : memref<4096xf32>,
                     %res : memref<4096xf32>) {
  scf.for %i = %c0 to %c4096 step %c256 {
    scf.for %j = %c0 to %c4096 step %c256 {
      %1 = create_nd_tdesc %a[%i, %c0] : memref<4096x4096xf16> -> tensor_desc<256x32xf16, #mp_a>          // sg_layout=[8,4], sg_data=[32,32]
      %2 = create_nd_tdesc %bt[%j, %c0] : memref<4096x4096xf16> -> tensor_desc<256x32xf16, #mp_bt>        // sg_layout=[4,8], sg_data=[64,32]
      %1p = create_nd_tdesc %a[%i, %c192] : memref<4096x4096xf16> -> tensor_desc<256x32xf16, #mp_a_pfh>   // sg_layout=[32,1]
      %2p = create_nd_tdesc %bt[%j, %c192] : memref<4096x4096xf16> -> tensor_desc<256x32xf16, #mp_bt_pfh> // sg_layout=[32,1]

      %bcast' = memref.cast %bcast : memref<4096xf32> -> memref<1x4096xf32>
      %7 = create_nd_tdesc %bcast'[%j] : memref<1x4096xf32> -> tensor_desc<1x256xf32, #mp_bcast>          // sg_layout=[4, 8], sg_data=[1,32]

      %res' = memref.cast %res : memref<4096xf32> -> memref<4096x1xf32>
      %3 = create_nd_tdesc %res'[%i] : memref<4096x1xf32> -> tensor_desc<256x1xf32, #mp_reduce>           // sg_layout=[32, 1]

      scf.for %k = %c0 to %c4096 step %c32 {
        %4 = load_nd %1 : tensor_desc<256x32xf16, #mp_a> -> vector<256x32xf16>       // sg_layout=[8,4], sg_data=[32,32]
        %10 = load_nd %2 : tensor_desc<256x32xf16, #mp_bt> -> vector<256x32xf16>     // sg_layout=[4,8], sg_data=[64,32]
        %5 = vector.transpose %10 {#mp_bt #mp_b} : vector<256x32xf16> -> vector<32x256xf16>  // sg_layout=[4,8] -> sg_layout=[8,4]

        prefetch_nd %1p : tensor_desc<256x32xf16, #mp_a_pfh>                         // sg_layout=[32,1]
        prefetch_nd %2p : tensor_desc<256x32xf16, #mp_bt_pfh>                        // sg_layout=[32,1]
        %6 = dpas %4, %5 {#mp_a #mp_b #mp_c} : (vector<256x32xf16>, vector<32x256xf16>) -> vector<256x256xf32>  // sg_layout=[8,4]
        %1 = update_nd_offset %1, %c0, %c32 : tensor_desc<256x32xf16, #mp_a>
        %2 = update_nd_offset %2, %c0, %c32 : tensor_desc<256x32xf16, #mp_bt>
        %1p = update_nd_offset %1p, %c0, %c32 : tensor_desc<256x32xf16, #mp_a_pfh>
        %2p = update_nd_offset %2p, %c0, %c32 : tensor_desc<256x32xf16, #mp_bt_pfh>
      }

      %12 = load_nd %7 : tensor_desc<256xf32, #mp_bcast2> -> vector<256xf32>                     // sg_layout=[32], sg_data=[8]
      %12' = convert_layout {#mp_bcast2 #mp_bcast} %12 : vector<256xf32>                         // sg_layout=[32] -> sg_layout=[8, 4]
      %13 = vector.broadcast {#mp_c} %12' [0] : vector<256xf32> => vector<256x256xf32>           // sg_layout=[8, 4], sg_data=[32,64]
      %14 = add %6, %13 : vector<256x256xf32>
      %14' = convert_layout {#mp_c #mp_reduce2} %14 : vector<256x256xf32>                        // sg_layout=[8, 4] -> sg_layout=[32]
      %16 = vector.reduction {#mp_reduce2 #mp_reduce} %14' [1] : vector<256x256xf32> => vector<256xf32>  // sg_layout=[32]
      store_nd %3, %16 : (tensor_desc<256xf32, #mp_reduce>, vector<256xf32>)                     // sg_layout=[32]
    }
  }
}
```

## Appendix 1.3 Gemm implementation with two cache levels
The GPU supports high-performance prefetching through two levels of cache.
```mlir
#mp_a = #layout
#mp_b = #layout
#mp_c = #layout

#mp_a_copl2 = #layout
#mp_b_copl2 = #layout<sg_layout=[16,2], sg_data=[8,128], order=[1,0]>

#mp_a_copl1 = #layout
#mp_b_copl1 = #layout<sg_layout=[4, 8], sg_data=[8,32], order=[1,0]>

func.func @test_gemm(%a : memref<4096x4096xf16>,
                     %b : memref<4096x4096xf16>,
                     %c : memref<4096x4096xf32>) {
  scf.for %i = %c0 to %c4096 step %c256 {
    scf.for %j = %c0 to %c4096 step %c256 {
      %a1_l2 = create_nd_tdesc %a[%i, %c0] : memref<4096x4096xf16> -> tensor_desc<512x128xf16, #mp_a_copl2>
      %b1_l2 = create_nd_tdesc %b[%c0, %j] : memref<4096x4096xf16> -> tensor_desc<128x256xf16, #mp_b_copl2>
      %a2_l2 = create_nd_tdesc %a[%i, %c256] : memref<4096x4096xf16> -> tensor_desc<512x128xf16, #mp_a_copl2>
      %b2_l2 = create_nd_tdesc %b[%c256, %j] : memref<4096x4096xf16> -> tensor_desc<128x256xf16, #mp_b_copl2>

      prefetch_nd %a1_l2 locality<2> : tensor_desc<512x128xf16, #mp_a_copl2>
      prefetch_nd %b1_l2 locality<2> : tensor_desc<128x256xf16, #mp_b_copl2>
      prefetch_nd %a2_l2 locality<2> : tensor_desc<512x128xf16, #mp_a_copl2>
      prefetch_nd %b2_l2 locality<2> : tensor_desc<128x256xf16, #mp_b_copl2>
      %a2_l2' = update_nd_offset %a2_l2, %c0, %c32 : tensor_desc<512x128xf16, #mp_a_copl2>
      %b2_l2' = update_nd_offset %b2_l2, %c32, %c0 : tensor_desc<128x256xf16, #mp_b_copl2>

      %a1_l1 = create_nd_tdesc %a[%i, %c0] : memref<4096x4096xf16> -> tensor_desc<512x32xf16, #mp_a_copl1>
      %b1_l1 = create_nd_tdesc %b[%c0, %j] : memref<4096x4096xf16> -> tensor_desc<32x256xf16, #mp_b_copl1>
      %a2_l1 = create_nd_tdesc %a[%i, %c32] : memref<4096x4096xf16> -> tensor_desc<512x32xf16, #mp_a_copl1>
      %b2_l1 = create_nd_tdesc %b[%c32, %j] : memref<4096x4096xf16> -> tensor_desc<32x256xf16, #mp_b_copl1>
      %a3_l1 = create_nd_tdesc %a[%i, %c64] : memref<4096x4096xf16> -> tensor_desc<512x32xf16, #mp_a_copl1>
      %b3_l1 = create_nd_tdesc %b[%c64, %j] : memref<4096x4096xf16> -> tensor_desc<32x256xf16, #mp_b_copl1>
      %a4_l1 = create_nd_tdesc %a[%i, %c96] : memref<4096x4096xf16> -> tensor_desc<512x32xf16, #mp_a_copl1>
      %b4_l1 = create_nd_tdesc %b[%c96, %j] : memref<4096x4096xf16> -> tensor_desc<32x256xf16, #mp_b_copl1>

      prefetch_nd %a1_l1 locality<3> : tensor_desc<512x32xf16, #mp_a_copl1>
      prefetch_nd %b1_l1 locality<3> : tensor_desc<32x256xf16, #mp_b_copl1>
      prefetch_nd %a2_l1 locality<3> : tensor_desc<512x32xf16, #mp_a_copl1>
      prefetch_nd %b2_l1 locality<3> : tensor_desc<32x256xf16, #mp_b_copl1>
      prefetch_nd %a3_l1 locality<3> : tensor_desc<512x32xf16, #mp_a_copl1>
      prefetch_nd %b3_l1 locality<3> : tensor_desc<32x256xf16, #mp_b_copl1>
      prefetch_nd %a4_l1 locality<3> : tensor_desc<512x32xf16, #mp_a_copl1>
      prefetch_nd %b4_l1 locality<3> : tensor_desc<32x256xf16, #mp_b_copl1>
      %a4_l1' = update_nd_offset %a4_l1, %c0, %c128 : tensor_desc<512x32xf16, #mp_a_copl1>
      %b4_l1' = update_nd_offset %b4_l1, %c128, %c0 : tensor_desc<32x256xf16, #mp_b_copl1>

      %a1_load = create_nd_tdesc %a[%i, %c0] : memref<4096x4096xf16> -> tensor_desc<512x32xf16, #mp_a>
      %b1_load = create_nd_tdesc %b[%c0, %j] : memref<4096x4096xf16> -> tensor_desc<32x256xf16, #mp_b>

      %c_tile = create_nd_tdesc %c[%i, %j] : memref<4096x4096xf32> -> tensor_desc<512x256xf32, #mp_c>

      scf.for %k = %c0 to %c4096 step %c32 {
        %a1_r = load_nd %a1_load : tensor_desc<512x32xf16, #mp_a> -> vector<512x32xf16>
        %b1_r = load_nd %b1_load : tensor_desc<32x256xf16, #mp_b> -> vector<32x256xf16>

        scf.if (%k % 128 == 0) {   // pseudo condition: every 4th k iteration
          gpu.barrier
          prefetch_nd %a2_l2' locality<2> : tensor_desc<512x128xf16, #mp_a_copl2>
          prefetch_nd %b2_l2' locality<2> : tensor_desc<128x256xf16, #mp_b_copl2>
          %a2_l2' = update_nd_offset %a2_l2', %c0, %c128 : tensor_desc<512x128xf16, #mp_a_copl2>
          %b2_l2' = update_nd_offset %b2_l2', %c128, %c0 : tensor_desc<128x256xf16, #mp_b_copl2>
        }
        prefetch_nd %a4_l1' locality<3> : tensor_desc<512x32xf16, #mp_a_copl1>
        prefetch_nd %b4_l1' locality<3> : tensor_desc<32x256xf16, #mp_b_copl1>
        %a4_l1' = update_nd_offset %a4_l1', %c0, %c32 : tensor_desc<512x32xf16, #mp_a_copl1>
        %b4_l1' = update_nd_offset %b4_l1', %c32, %c0 : tensor_desc<32x256xf16, #mp_b_copl1>

        %a1_load = update_nd_offset %a1_load, %c0, %c32 : tensor_desc<512x32xf16, #mp_a>
        %b1_load = update_nd_offset %b1_load, %c32, %c0 : tensor_desc<32x256xf16, #mp_b>

        %6 = dpas %a1_r, %b1_r {#mp_a #mp_b #mp_c} : (vector<512x32xf16>, vector<32x256xf16>) -> vector<512x256xf32>
      }
      store_nd %c_tile, %6 : (tensor_desc<512x256xf32, #mp_c>, vector<512x256xf32>)
    }
  }
}
```

## Notes

Currently, there is no lower-level dialect for the Intel GPU compiler toolchain to represent GPU ops with values based on LLVM data types such as NVVM