
Commit 4e3e288

[OpOptimization] Further optimize BatchMatMulBroadcast and add OpenMP tests.
1 parent: ff2049a

4 files changed (+76, -28 lines)

README.md (+2)
@@ -227,13 +227,15 @@ $ mkdir build && cd build
 $ cmake -G Ninja .. \
     -DCMAKE_BUILD_TYPE=RELEASE \
     -DOP_OPTIMIZATION_BENCHMARKS=ON \
+    -DCMAKE_CXX_COMPILER=clang++ \
     -DBUDDY_MLIR_BUILD_DIR=/PATH/TO/BUDDY-MLIR/BUILD/
 $ ninja <your target operation benchmark>

 // Operation benchmarks supported include:
 // - conv2d-nchw-fchw-benchmark
 // - matmul-benchmark
 ```
+OpenMP is required for matmul-benchmark; make sure `libomp` and `libomp-dev` (on Ubuntu and Debian) or `libomp-devel` (on Red Hat and SUSE) are installed.

 Run TVM operation optimization benchmark cases.
 - Install TVM ([steps](./thirdparty/README.md#tvm)).
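
A minimal install sketch for the OpenMP prerequisite added in the hunk above, assuming apt on Ubuntu/Debian and dnf or zypper on Red Hat/SUSE; package names follow the README note and may differ by distribution release:

$ sudo apt-get install libomp-dev      # Ubuntu / Debian
$ sudo dnf install libomp-devel        # Red Hat / Fedora
$ sudo zypper install libomp-devel     # SUSE / openSUSE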

benchmarks/OpOptimization/MatMul/BatchMatMulBroadcast.mlir (+20, -21)
@@ -1,6 +1,10 @@
 // The MLIR prototype of batchmatmul-optimize in buddy-opt.

 #map = affine_map<(d0) -> (d0 ceildiv STEP_PLACEHOLDER)>
+#tail_len_map = affine_map<(d0) -> (d0 mod STEP_PLACEHOLDER)>
+#if_set = affine_set<(d0)[s0] : (s0 - d0 * STEP_PLACEHOLDER >= STEP_PLACEHOLDER)>
+#b_col_idx_tail_map = affine_map<(d0) -> (d0 * STEP_PLACEHOLDER)>
+
 func.func @batch_matmul_broadcast_STEP_PLACEHOLDER(%a : memref<?x?x?xf32>, %b : memref<?x?x?xf32>, %c : memref<?x?x?xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -15,32 +19,27 @@ func.func @batch_matmul_broadcast_STEP_PLACEHOLDER(%a : memref<?x?x?xf32>, %b :
   %b_col = memref.dim %b, %c2 : memref<?x?x?xf32>
   %batch = memref.dim %a, %c0 : memref<?x?x?xf32>

+  %tail_len = affine.apply #tail_len_map(%b_col)
+  %mask_vec = vector.create_mask %tail_len : vector<STEP_PLACEHOLDERxi1>
+
   affine.parallel (%batch_idx) = (0) to (%batch){ // Affine.parallel can be lowered to the omp dialect, which enables batch-level parallelization.
     affine.prefetch %a[%batch_idx, %a_row, %a_col], read, locality<3>, data : memref<?x?x?xf32> // Explicitly prefetch, about 5% faster on X86.
     affine.for %b_row_idx = 0 to %b_row {
+      affine.for %b_col_idx = 0 to #map(%b_col) {
+        %b_vec = affine.vector_load %b[%batch_idx, %b_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
+        %b_col_idx_tail = affine.apply #b_col_idx_tail_map(%b_col_idx)
         affine.for %a_row_idx = 0 to %a_row {
-          affine.for %b_col_idx = 0 to #map(%b_col) {
-            %a_ele = affine.load %a[%batch_idx, %a_row_idx, %b_row_idx] : memref<?x?x?xf32>
-            %a_vec = vector.broadcast %a_ele : f32 to vector<STEP_PLACEHOLDERxf32>
-            // Check tail.
-            %b_col_cur = arith.muli %b_col_idx, %step : index
-            %tail_len = arith.subi %b_col, %b_col_cur : index
-            %tail_flag = arith.cmpi sge, %tail_len, %step : index
-            scf.if %tail_flag {
-              %b_vec = affine.vector_load %b[%batch_idx, %b_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
-              %c_vec = affine.vector_load %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
-              %result_vec = vector.fma %a_vec, %b_vec, %c_vec : vector<STEP_PLACEHOLDERxf32>
-              affine.vector_store %result_vec, %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
-            } else {
-              %mask_vec = vector.create_mask %tail_len : vector<STEP_PLACEHOLDERxi1>
-              %b_col_idx_tail = arith.muli %b_col_idx, %step : index
-              %b_vec_tail = vector.maskedload %b[%batch_idx, %b_row_idx, %b_col_idx_tail], %mask_vec, %c0_f32_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32> into vector<STEP_PLACEHOLDERxf32>
-              %c_vec_tail = vector.maskedload %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %c0_f32_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32> into vector<STEP_PLACEHOLDERxf32>
-              %result_vec_tail = vector.fma %a_vec, %b_vec_tail, %c_vec_tail : vector<STEP_PLACEHOLDERxf32>
-              vector.maskedstore %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %result_vec_tail : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32>
-            }
-          }
+          %a_ele = affine.load %a[%batch_idx, %a_row_idx, %b_row_idx] : memref<?x?x?xf32>
+          %a_vec = vector.broadcast %a_ele : f32 to vector<STEP_PLACEHOLDERxf32>
+          %c_vec = affine.vector_load %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
+          %result_vec = vector.fma %a_vec, %b_vec, %c_vec : vector<STEP_PLACEHOLDERxf32>
+          affine.if #if_set(%b_col_idx)[%b_col] {
+            affine.vector_store %result_vec, %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
+          } else {
+            vector.maskedstore %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %result_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32>
+          }
         }
+      }
     }
   }
   return
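
For illustration, with STEP_PLACEHOLDER = 64 and a hypothetical %b_col = 200 (not a value taken from the benchmark): #map yields ceildiv(200, 64) = 4 column tiles; #if_set (200 - 64 * d0 >= 64) holds for tiles 0 through 2, which take the unmasked affine.vector_store path, while tile 3 falls into the else branch and stores through the mask of length %tail_len = 200 mod 64 = 8. The mask is now built once per call via #tail_len_map rather than inside the innermost loop, and the %b_vec load is hoisted out of the %a_row_idx loop, so each B tile is loaded once per (%b_row_idx, %b_col_idx) pair.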

benchmarks/OpOptimization/MatMul/CMakeLists.txt (+31, -1)
@@ -125,6 +125,7 @@ function(build_batch_matmul_broadcast step)
   ${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
     -batchmatmul-optimize="step-placeholder=${step}"
     -expand-strided-metadata
+    -affine-super-vectorize
     -lower-affine
     -convert-vector-to-llvm
     -finalize-memref-to-llvm
@@ -144,12 +145,40 @@ endfunction()

 build_batch_matmul_broadcast(64)

+function(build_batch_matmul_broadcast_omp step)
+  add_custom_command(OUTPUT batch-matmul-broadcast-${step}-omp.o
+    COMMAND cat ${BUDDY_SOURCE_DIR}/benchmarks/OpOptimization/MatMul/BatchMatMulBroadcast.mlir |
+            sed 's/batch_matmul_broadcast_STEP_PLACEHOLDER/batch_matmul_broadcast_STEP_PLACEHOLDER_omp/g' |
+            sed 's/STEP_PLACEHOLDER/${step}/g' |
+            ${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
+              -expand-strided-metadata
+              -affine-super-vectorize
+              -lower-affine
+              -convert-scf-to-openmp
+              -convert-vector-to-llvm
+              -finalize-memref-to-llvm
+              -convert-scf-to-cf
+              -convert-linalg-to-llvm
+              -llvm-request-c-wrappers
+              -convert-openmp-to-llvm
+              -convert-func-to-llvm
+              -reconcile-unrealized-casts |
+            ${LLVM_MLIR_BINARY_DIR}/mlir-translate --mlir-to-llvmir |
+            ${CMAKE_CXX_COMPILER} -c -x ir -O3 --target=${BUDDY_OPT_TRIPLE} -fopenmp -march=native -flto
+              -o ${BUDDY_BINARY_DIR}/../benchmarks/OpOptimization/MatMul/batch-matmul-broadcast-${step}-omp.o -
+  )
+  add_library(BatchMatMulBroadcast${step}OMP STATIC batch-matmul-broadcast-${step}-omp.o)
+  set_target_properties(BatchMatMulBroadcast${step}OMP PROPERTIES LINKER_LANGUAGE CXX)
+endfunction()
+
+build_batch_matmul_broadcast_omp(64)
+
 add_executable(matmul-benchmark
   Main.cpp
   MatMulBenchmark.cpp
 )

-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -fopenmp -flto")

 target_link_libraries(matmul-benchmark
   GoogleBenchmark
@@ -163,4 +192,5 @@ target_link_libraries(matmul-benchmark
   MatMulScalar
   BatchMatMulScalar
   BatchMatMulBroadcast64
+  BatchMatMulBroadcast64OMP
 )
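
To pick up the new BatchMatMulBroadcast64OMP library, rebuild the benchmark target from the existing build tree; a minimal sketch, assuming the Ninja setup from the README above (the target name comes from add_executable):

$ cd build
$ ninja matmul-benchmark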

benchmarks/OpOptimization/MatMul/MatMulBenchmark.cpp (+23, -6)
@@ -28,10 +28,10 @@
 #define M 64
 #define N 3136
 #define K 576
-#define BATCH_M 16
+#define BATCH_M 128
 #define BATCH_N 784
-#define BATCH_K 144
-#define BATCH 64
+#define BATCH_K 72
+#define BATCH 16

 // Helper functions and variables.
 namespace {
@@ -72,6 +72,9 @@ void _mlir_ciface_batch_matmul_scalar(MemRef<float, 3> *A, MemRef<float, 3> *B,
 void _mlir_ciface_batch_matmul_broadcast_64(MemRef<float, 3> *A,
                                             MemRef<float, 3> *B,
                                             MemRef<float, 3> *C);
+void _mlir_ciface_batch_matmul_broadcast_64_omp(MemRef<float, 3> *A,
+                                                MemRef<float, 3> *B,
+                                                MemRef<float, 3> *C);
 }

 #define DEFINE_MATMUL_BENCHMARK(name, func) \
@@ -115,6 +118,8 @@ DEFINE_MATMUL_BENCHMARK(SCALAR, _mlir_ciface_matmul_scalar)
 DEFINE_BATCH_MATMUL_BENCHMARK(SCALAR, _mlir_ciface_batch_matmul_scalar)
 DEFINE_BATCH_MATMUL_BENCHMARK(BROADCAST_64,
                               _mlir_ciface_batch_matmul_broadcast_64)
+DEFINE_BATCH_MATMUL_BENCHMARK(BROADCAST_64_OMP,
+                              _mlir_ciface_batch_matmul_broadcast_64_omp)
 } // namespace

 // Register benchmark cases.
@@ -129,6 +134,7 @@ BENCHMARK(BM_MATMUL_BROADCAST_256)->Unit(benchmark::kMillisecond);
 BENCHMARK(BM_MATMUL_BROADCAST_256)->Unit(benchmark::kMillisecond);
 BENCHMARK(BM_BATCH_MATMUL_SCALAR)->Unit(benchmark::kMillisecond);
 BENCHMARK(BM_BATCH_MATMUL_BROADCAST_64)->Unit(benchmark::kMillisecond);
+BENCHMARK(BM_BATCH_MATMUL_BROADCAST_64_OMP)->Unit(benchmark::kMillisecond);

 // Correctness Verification
 // The verification does not affect the performance.
@@ -237,7 +243,6 @@ void matmul_verification() {
                     ? PASS
                     : FAIL)
             << std::endl;
-
   std::cout << "-----------------------------------------------------------"
             << std::endl;
 }
@@ -274,23 +279,35 @@ void batch_matmul_verification() {
   const int outputSize = BATCH * (BATCH_M) * (BATCH_N);
   MemRef<float, 3> outputScalar(sizesC, 0);
   MemRef<float, 3> outputBroadcast64(sizesC, 0);
+  MemRef<float, 3> outputBroadcast64OMP(sizesC, 0);

   // Perform all the matmul implementation.
   _mlir_ciface_batch_matmul_scalar(&inputAMemRef, &inputBMemRef, &outputScalar);
   _mlir_ciface_batch_matmul_broadcast_64(&inputAMemRef, &inputBMemRef,
                                          &outputBroadcast64);
+  _mlir_ciface_batch_matmul_broadcast_64_omp(&inputAMemRef, &inputBMemRef,
+                                             &outputBroadcast64OMP);

   // Get the result array.
   auto resultScalar = outputScalar.getData();
-  auto resultBroadcast16 = outputBroadcast64.getData();
+  auto resultBroadcast64 = outputBroadcast64.getData();
+  auto resultBroadcast64OMP = outputBroadcast64OMP.getData();

   // Print the verification result.
   std::cout << "Batch Matmul Broadcast 64 case: "
-            << (areArraysEqual(resultScalar, resultBroadcast16,
+            << (areArraysEqual(resultScalar, resultBroadcast64,
+                               outputSize / BATCH)
+                    ? PASS
+                    : FAIL)
+            << std::endl;
+
+  std::cout << "Batch Matmul Broadcast 64 OpenMP case: "
+            << (areArraysEqual(resultScalar, resultBroadcast64OMP,
                                outputSize / BATCH)
                     ? PASS
                     : FAIL)
             << std::endl;
+
   std::cout << "-----------------------------------------------------------"
             << std::endl;
 }
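
To run only the new OpenMP case and control its thread count, OpenMP's OMP_NUM_THREADS variable can be combined with Google Benchmark's regex filter; a sketch, assuming the matmul-benchmark binary lands in the build tree under benchmarks/OpOptimization/MatMul (the exact output path may differ):

$ cd build
$ OMP_NUM_THREADS=8 ./benchmarks/OpOptimization/MatMul/matmul-benchmark --benchmark_filter=BM_BATCH_MATMUL_BROADCAST_64_OMP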
