@@ -1,6 +1,10 @@
 // The MLIR prototype of batchmatmul-optimize in buddy-opt.
 
 #map = affine_map<(d0) -> (d0 ceildiv STEP_PLACEHOLDER)>
+#tail_len_map = affine_map<(d0) -> (d0 mod STEP_PLACEHOLDER)>
+#if_set = affine_set<(d0)[s0] : (s0 - d0 * STEP_PLACEHOLDER >= STEP_PLACEHOLDER)>
+#b_col_idx_tail_map = affine_map<(d0) -> (d0 * STEP_PLACEHOLDER)>
+
 func.func @batch_matmul_broadcast_STEP_PLACEHOLDER(%a : memref<?x?x?xf32>, %b : memref<?x?x?xf32>, %c : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -15,32 +19,27 @@ func.func @batch_matmul_broadcast_STEP_PLACEHOLDER(%a : memref<?x?x?xf32>, %b :
   %b_col = memref.dim %b, %c2 : memref<?x?x?xf32>
   %batch = memref.dim %a, %c0 : memref<?x?x?xf32>
 
+  %tail_len = affine.apply #tail_len_map(%b_col)
+  %mask_vec = vector.create_mask %tail_len : vector<STEP_PLACEHOLDERxi1>
+
   affine.parallel (%batch_idx) = (0) to (%batch) { // affine.parallel can be lowered to the omp dialect, which enables batch-level parallelization.
     affine.prefetch %a[%batch_idx, %a_row, %a_col], read, locality<3>, data : memref<?x?x?xf32> // Explicit prefetch; about 5% faster on x86.
     affine.for %b_row_idx = 0 to %b_row {
+      affine.for %b_col_idx = 0 to #map(%b_col) {
+        %b_vec = affine.vector_load %b[%batch_idx, %b_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
+        %b_col_idx_tail = affine.apply #b_col_idx_tail_map(%b_col_idx)
         affine.for %a_row_idx = 0 to %a_row {
-          affine.for %b_col_idx = 0 to #map(%b_col) {
-            %a_ele = affine.load %a[%batch_idx, %a_row_idx, %b_row_idx] : memref<?x?x?xf32>
-            %a_vec = vector.broadcast %a_ele : f32 to vector<STEP_PLACEHOLDERxf32>
-            // Check tail.
-            %b_col_cur = arith.muli %b_col_idx, %step : index
-            %tail_len = arith.subi %b_col, %b_col_cur : index
-            %tail_flag = arith.cmpi sge, %tail_len, %step : index
-            scf.if %tail_flag {
-              %b_vec = affine.vector_load %b[%batch_idx, %b_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
-              %c_vec = affine.vector_load %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
-              %result_vec = vector.fma %a_vec, %b_vec, %c_vec : vector<STEP_PLACEHOLDERxf32>
-              affine.vector_store %result_vec, %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
-            } else {
-              %mask_vec = vector.create_mask %tail_len : vector<STEP_PLACEHOLDERxi1>
-              %b_col_idx_tail = arith.muli %b_col_idx, %step : index
-              %b_vec_tail = vector.maskedload %b[%batch_idx, %b_row_idx, %b_col_idx_tail], %mask_vec, %c0_f32_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32> into vector<STEP_PLACEHOLDERxf32>
-              %c_vec_tail = vector.maskedload %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %c0_f32_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32> into vector<STEP_PLACEHOLDERxf32>
-              %result_vec_tail = vector.fma %a_vec, %b_vec_tail, %c_vec_tail : vector<STEP_PLACEHOLDERxf32>
-              vector.maskedstore %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %result_vec_tail : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32>
-            }
-          }
+          %a_ele = affine.load %a[%batch_idx, %a_row_idx, %b_row_idx] : memref<?x?x?xf32>
+          %a_vec = vector.broadcast %a_ele : f32 to vector<STEP_PLACEHOLDERxf32>
+          %c_vec = affine.vector_load %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
+          %result_vec = vector.fma %a_vec, %b_vec, %c_vec : vector<STEP_PLACEHOLDERxf32>
+          affine.if #if_set(%b_col_idx)[%b_col] {
+            affine.vector_store %result_vec, %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
+          } else {
+            vector.maskedstore %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %result_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32>
+          }
         }
+      }
     }
   }
   return
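
To make the new tail handling concrete, here is a worked example of how the hoisted mask and the affine_set guard behave. It is an illustrative sketch only: the concrete values (STEP_PLACEHOLDER substituted with 16, and a B matrix with 100 columns) are assumptions for the example, not part of the commit.

// Illustrative sketch, not part of the commit.
// Assumptions: STEP_PLACEHOLDER = 16 and %b_col = 100.
#map = affine_map<(d0) -> (d0 ceildiv 16)>             // trip count: ceildiv(100, 16) = 7, so %b_col_idx = 0..6
#tail_len_map = affine_map<(d0) -> (d0 mod 16)>        // %tail_len = 100 mod 16 = 4
#if_set = affine_set<(d0)[s0] : (s0 - d0 * 16 >= 16)>  // holds while a full 16-wide vector still fits

// For %b_col_idx = 0..5 the constraint holds (e.g., 100 - 16 * 5 = 20 >= 16),
// so the fast affine.vector_store path runs. For %b_col_idx = 6 it fails
// (100 - 96 = 4 < 16), so the masked path runs, and %mask_vec, built once
// outside all loops by vector.create_mask from %tail_len = 4, enables only
// the first 4 lanes of the store.

Compared with the old scf.if version, the tail length and mask are loop-invariant and hoisted out of the loop nest, the B-row vector load is hoisted above the A-row loop, and the tail branch becomes an affine.if over an affine_set, which the affine dialect can reason about statically, unlike an scf.if on runtime-computed flags. As the inline comment notes, affine.parallel can then be lowered toward OpenMP for batch-level parallelism (in upstream mlir-opt, for example, via -lower-affine followed by -convert-scf-to-openmp; the exact buddy-opt pipeline may differ).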