Skip to content

Commit 7e7c64c

Browse files
committed
[OpOptimization] Add BatchMatMul benchmark.
1 parent 5ab2b09 commit 7e7c64c

File tree

5 files changed

+198
-10
lines changed

5 files changed

+198
-10
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Baseline batch matmul benchmark kernel.
// The function name `bm_batch_matmul` is rewritten by sed in the CMake rules
// to produce uniquely named scalar/optimized variants from this one file.
module {
  func.func @bm_batch_matmul(%a : memref<?x?x?xf32>, %b : memref<?x?x?xf32>, %c : memref<?x?x?xf32>) {
    linalg.batch_matmul
      ins(%a, %b : memref<?x?x?xf32>, memref<?x?x?xf32>)
      outs(%c : memref<?x?x?xf32>)
    return
  }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// The MLIR prototype of the batchmatmul-optimize pass in buddy-opt.
// STEP_PLACEHOLDER is textually substituted with the vector step size
// (e.g. 64) before this file is compiled.

// Number of vector iterations over the B/C column dimension.
#map = affine_map<(d0) -> (d0 ceildiv STEP_PLACEHOLDER)>

func.func @batch_matmul_broadcast_STEP_PLACEHOLDER(%a : memref<?x?x?xf32>, %b : memref<?x?x?xf32>, %c : memref<?x?x?xf32>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %step = arith.constant STEP_PLACEHOLDER : index
  %c0_f32 = arith.constant 0.0 : f32
  // Passthrough value for masked (tail) loads.
  %c0_f32_vec = vector.splat %c0_f32 : vector<STEP_PLACEHOLDERxf32>

  %a_row = memref.dim %a, %c1 : memref<?x?x?xf32>
  %a_col = memref.dim %a, %c2 : memref<?x?x?xf32>
  %b_row = memref.dim %b, %c1 : memref<?x?x?xf32>
  %b_col = memref.dim %b, %c2 : memref<?x?x?xf32>
  %batch = memref.dim %a, %c0 : memref<?x?x?xf32>

  // Each batch is independent, so the outermost loop is parallel.
  affine.parallel (%batch_idx) = (0) to (%batch) {
    // Prefetch the A slice for this batch; measured ~3% faster.
    affine.prefetch %a[%batch_idx, %a_row, %a_col], read, locality<3>, data : memref<?x?x?xf32>
    affine.for %b_row_idx = 0 to %b_row {
      affine.for %a_row_idx = 0 to %a_row {
        affine.for %b_col_idx = 0 to #map(%b_col) {
          // Broadcast one A element across a vector lane.
          %a_ele = affine.load %a[%batch_idx, %a_row_idx, %b_row_idx] : memref<?x?x?xf32>
          %a_vec = vector.broadcast %a_ele : f32 to vector<STEP_PLACEHOLDERxf32>
          // Tail check: full vector step vs. masked remainder.
          %b_col_cur = arith.muli %b_col_idx, %step : index
          %tail_len = arith.subi %b_col, %b_col_cur : index
          %tail_flag = arith.cmpi sge, %tail_len, %step : index
          scf.if %tail_flag {
            // Full-width path: C += broadcast(A) * B, one FMA per vector.
            %b_vec = affine.vector_load %b[%batch_idx, %b_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
            %c_vec = affine.vector_load %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
            %result_vec = vector.fma %a_vec, %b_vec, %c_vec : vector<STEP_PLACEHOLDERxf32>
            affine.vector_store %result_vec, %c[%batch_idx, %a_row_idx, %b_col_idx * STEP_PLACEHOLDER] : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxf32>
          } else {
            // Tail path: mask off lanes past the end of the column.
            %mask_vec = vector.create_mask %tail_len : vector<STEP_PLACEHOLDERxi1>
            %b_col_idx_tail = arith.muli %b_col_idx, %step : index
            %b_vec_tail = vector.maskedload %b[%batch_idx, %b_row_idx, %b_col_idx_tail], %mask_vec, %c0_f32_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32> into vector<STEP_PLACEHOLDERxf32>
            %c_vec_tail = vector.maskedload %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %c0_f32_vec : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32> into vector<STEP_PLACEHOLDERxf32>
            %result_vec_tail = vector.fma %a_vec, %b_vec_tail, %c_vec_tail : vector<STEP_PLACEHOLDERxf32>
            vector.maskedstore %c[%batch_idx, %a_row_idx, %b_col_idx_tail], %mask_vec, %result_vec_tail : memref<?x?x?xf32>, vector<STEP_PLACEHOLDERxi1>, vector<STEP_PLACEHOLDERxf32>
          }
        }
      }
    }
  }
  return
}

benchmarks/OpOptimization/MatMul/CMakeLists.txt

+49
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,53 @@ add_custom_command(OUTPUT matmul-scalar.o
9797
add_library(MatMulScalar STATIC matmul-scalar.o)
9898
set_target_properties(MatMulScalar PROPERTIES LINKER_LANGUAGE CXX)
9999

100+
# Build the scalar (unoptimized) batch matmul kernel object.
# The generic bm_batch_matmul entry point in BatchMatMul.mlir is renamed via
# sed so that every generated object exposes a unique symbol.
# NOTE: the `g` flag makes this consistent with the broadcast rule below,
# which uses `sed .../g` — all occurrences are renamed, not just the first.
add_custom_command(OUTPUT batch-matmul-scalar.o
  COMMAND cat ${BUDDY_SOURCE_DIR}/benchmarks/OpOptimization/MatMul/BatchMatMul.mlir |
          sed 's/bm_batch_matmul/batch_matmul_scalar/g' |
          ${LLVM_MLIR_BINARY_DIR}/mlir-opt
            -convert-linalg-to-loops
            -lower-affine
            -convert-scf-to-cf
            -convert-vector-to-llvm
            -finalize-memref-to-llvm
            -convert-arith-to-llvm
            -llvm-request-c-wrappers
            -convert-func-to-llvm
            -reconcile-unrealized-casts |
          ${LLVM_MLIR_BINARY_DIR}/mlir-translate --mlir-to-llvmir |
          ${LLVM_MLIR_BINARY_DIR}/llc -O3 -mtriple=${BUDDY_OPT_TRIPLE}
            -mattr=${BUDDY_OPT_ATTR} --filetype=obj
            -o ${BUDDY_BINARY_DIR}/../benchmarks/OpOptimization/MatMul/batch-matmul-scalar.o
)
add_library(BatchMatMulScalar STATIC batch-matmul-scalar.o)
set_target_properties(BatchMatMulScalar PROPERTIES LINKER_LANGUAGE CXX)
120+
121+
# Build one vectorized batch matmul object for the given vector step size.
# Renames the kernel, runs the buddy-opt batchmatmul-optimize pipeline, then
# lowers to an object file and wraps it in a static library
# (BatchMatMulBroadcast${step}).
function(build_batch_matmul_broadcast step)
  add_custom_command(OUTPUT batch-matmul-broadcast-${step}.o
    COMMAND cat ${BUDDY_SOURCE_DIR}/benchmarks/OpOptimization/MatMul/BatchMatMul.mlir |
            sed 's/bm_batch_matmul/batch_matmul_broadcast_${step}/g' |
            ${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
              -batchmatmul-optimize="step-placeholder=${step}"
              -expand-strided-metadata
              -lower-affine
              -convert-vector-to-llvm
              -finalize-memref-to-llvm
              -convert-scf-to-cf
              -convert-linalg-to-llvm
              -llvm-request-c-wrappers
              -convert-func-to-llvm
              -reconcile-unrealized-casts |
            ${LLVM_MLIR_BINARY_DIR}/mlir-translate --mlir-to-llvmir |
            ${LLVM_MLIR_BINARY_DIR}/llc -O3 -mtriple=${BUDDY_OPT_TRIPLE}
              -mattr=${BUDDY_OPT_ATTR} --filetype=obj
              -o ${BUDDY_BINARY_DIR}/../benchmarks/OpOptimization/MatMul/batch-matmul-broadcast-${step}.o
  )
  add_library(BatchMatMulBroadcast${step} STATIC batch-matmul-broadcast-${step}.o)
  set_target_properties(BatchMatMulBroadcast${step} PROPERTIES LINKER_LANGUAGE CXX)
endfunction()

build_batch_matmul_broadcast(64)
146+
100147
add_executable(matmul-benchmark
101148
Main.cpp
102149
MatMulBenchmark.cpp
@@ -114,4 +161,6 @@ target_link_libraries(matmul-benchmark
114161
MatMulBroadcast128
115162
MatMulBroadcast256
116163
MatMulScalar
164+
BatchMatMulScalar
165+
BatchMatMulBroadcast64
117166
)

benchmarks/OpOptimization/MatMul/Main.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020

2121
#include <benchmark/benchmark.h>
2222

23-
void verification();
23+
void matmul_verification();
24+
void batch_matmul_verification();
2425

2526
int main(int argc, char **argv) {
2627
// Run benchmark.
2728
::benchmark::Initialize(&argc, argv);
2829
::benchmark::RunSpecifiedBenchmarks();
2930
// Run correctness verification.
30-
verification();
31+
matmul_verification();
32+
batch_matmul_verification();
3133
return 0;
3234
}

benchmarks/OpOptimization/MatMul/MatMulBenchmark.cpp

+90-8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//
1919
//===----------------------------------------------------------------------===//
2020

21+
#include <array>
2122
#include <benchmark/benchmark.h>
2223
#include <buddy/Core/Container.h>
2324
#include <iostream>
@@ -27,6 +28,10 @@
2728
#define M 64
2829
#define N 3136
2930
#define K 576
31+
#define BATCH_M 16
32+
#define BATCH_N 784
33+
#define BATCH_K 144
34+
#define BATCH 64
3035

3136
// Helper functions and variables.
3237
namespace {
@@ -62,6 +67,11 @@ void _mlir_ciface_matmul_broadcast_256(MemRef<float, 2> *A, MemRef<float, 2> *B,
6267
MemRef<float, 2> *C);
6368
void _mlir_ciface_matmul_scalar(MemRef<float, 2> *A, MemRef<float, 2> *B,
6469
MemRef<float, 2> *C);
70+
void _mlir_ciface_batch_matmul_scalar(MemRef<float, 3> *A, MemRef<float, 3> *B,
71+
MemRef<float, 3> *C);
72+
void _mlir_ciface_batch_matmul_broadcast_64(MemRef<float, 3> *A,
73+
MemRef<float, 3> *B,
74+
MemRef<float, 3> *C);
6575
}
6676

6777
#define DEFINE_MATMUL_BENCHMARK(name, func) \
@@ -79,6 +89,21 @@ void _mlir_ciface_matmul_scalar(MemRef<float, 2> *A, MemRef<float, 2> *B,
7989
} \
8090
}
8191

92+
// Generates a Google Benchmark body for a batch matmul kernel.
// `name` is the benchmark suffix, `func` the MLIR C-interface kernel to time.
// NOTE(review): the generated symbol is literally `BBATCH_M_MATMUL_<name>` —
// `BATCH_M` inside an identifier is NOT macro-expanded; the BENCHMARK()
// registrations must use the same spelling. The output C accumulates across
// iterations (C += A * B each run), matching the 2-D matmul benchmarks.
#define DEFINE_BATCH_MATMUL_BENCHMARK(name, func)                              \
  void BBATCH_M_MATMUL_##name(benchmark::State &state) {                       \
    intptr_t sizesA[3] = {BATCH, BATCH_M, BATCH_K};                            \
    intptr_t sizesB[3] = {BATCH, BATCH_K, BATCH_N};                            \
    intptr_t sizesC[3] = {BATCH, BATCH_M, BATCH_N};                            \
                                                                               \
    MemRef<float, 3> A(sizesA, 1.0);                                           \
    MemRef<float, 3> B(sizesB, 1.0);                                           \
    MemRef<float, 3> C(sizesC, 0);                                             \
                                                                               \
    for (auto _ : state) {                                                     \
      func(&A, &B, &C);                                                        \
    }                                                                          \
  }
106+
82107
// Instantiate the 2-D matmul benchmark bodies.
DEFINE_MATMUL_BENCHMARK(OCV, _mlir_ciface_matmul_ocv)
DEFINE_MATMUL_BENCHMARK(TRANSFORM, _mlir_ciface_matmul_transform)
DEFINE_MATMUL_BENCHMARK(BROADCAST_16, _mlir_ciface_matmul_broadcast_16)
DEFINE_MATMUL_BENCHMARK(BROADCAST_64, _mlir_ciface_matmul_broadcast_64)
DEFINE_MATMUL_BENCHMARK(BROADCAST_128, _mlir_ciface_matmul_broadcast_128)
DEFINE_MATMUL_BENCHMARK(BROADCAST_256, _mlir_ciface_matmul_broadcast_256)
DEFINE_MATMUL_BENCHMARK(SCALAR, _mlir_ciface_matmul_scalar)
// Instantiate the batch matmul benchmark bodies.
DEFINE_BATCH_MATMUL_BENCHMARK(SCALAR, _mlir_ciface_batch_matmul_scalar)
DEFINE_BATCH_MATMUL_BENCHMARK(BROADCAST_64,
                              _mlir_ciface_batch_matmul_broadcast_64)
} // namespace
91119

92120
// Register benchmark cases.
@@ -98,15 +126,18 @@ BENCHMARK(BM_MATMUL_BROADCAST_32)->Unit(benchmark::kMillisecond);
98126
// Register benchmark cases (reported in milliseconds).
// Fix: the commit accidentally registered BM_MATMUL_BROADCAST_256 twice;
// a duplicate BENCHMARK() registration runs and reports the case twice.
BENCHMARK(BM_MATMUL_BROADCAST_64)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_MATMUL_BROADCAST_128)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_MATMUL_BROADCAST_256)->Unit(benchmark::kMillisecond);
// Batch matmul cases.
BENCHMARK(BBATCH_M_MATMUL_SCALAR)->Unit(benchmark::kMillisecond);
BENCHMARK(BBATCH_M_MATMUL_BROADCAST_64)->Unit(benchmark::kMillisecond);
101132

102-
/// Correctness Verification
103-
/// The verification does not affect the performance.
104-
/// - Set the scalar case as the criteria.
105-
/// - Input elements are random numbers.
106-
/// - Output elements are initialized to zero.
107-
/// - Compare the output of various optimizations with the scalar version to
108-
/// verify correctness.
109-
void verification() {
133+
// Correctness Verification
134+
// The verification does not affect the performance.
135+
// - Set the scalar case as the criteria.
136+
// - Input elements are random numbers.
137+
// - Output elements are initialized to zero.
138+
// - Compare the output of various optimizations with the scalar version to
139+
// verify correctness.
140+
void matmul_verification() {
110141
// Set the random number generator.
111142
std::random_device rd;
112143
std::mt19937 generator(rd());
@@ -206,6 +237,57 @@ void verification() {
206237
? PASS
207238
: FAIL)
208239
<< std::endl;
240+
}
241+
242+
void batch_matmul_verification() {
243+
// Set the random number generator.
244+
std::random_device rd;
245+
std::mt19937 generator(rd());
246+
std::uniform_int_distribution<int> distribution(1, 100);
247+
248+
// Set the layout sizes of input and output memref container.
249+
intptr_t sizesA[3] = {BATCH, BATCH_M, BATCH_K};
250+
intptr_t sizesB[3] = {BATCH, BATCH_K, BATCH_N};
251+
intptr_t sizesC[3] = {BATCH, BATCH_M, BATCH_N};
252+
253+
// Generate input A and input B memref container with random numbers.
254+
const int inputASize = BATCH * (BATCH_M) * (BATCH_K);
255+
// float inputARand[inputASize];
256+
auto inputARand = new std::array<float, inputASize>();
257+
for (int i = 0; i < inputASize; ++i) {
258+
(*inputARand)[i] = distribution(generator);
259+
}
260+
MemRef<float, 3> inputAMemRef(inputARand->data(), sizesA);
261+
262+
const int inputBSize = BATCH * (BATCH_K) * (BATCH_N);
263+
// float inputBRand[inputBSize];
264+
auto inputBRand = new std::array<float, inputBSize>();
265+
for (int i = 0; i < inputBSize; ++i) {
266+
(*inputBRand)[i] = distribution(generator);
267+
}
268+
MemRef<float, 3> inputBMemRef(inputBRand->data(), sizesB);
269+
270+
// Generate output memref container with zero.
271+
const int outputSize = BATCH * (BATCH_M) * (BATCH_N);
272+
MemRef<float, 3> outputScalar(sizesC, 0);
273+
MemRef<float, 3> outputBroadcast64(sizesC, 0);
274+
275+
// Perform all the matmul implementation.
276+
_mlir_ciface_batch_matmul_scalar(&inputAMemRef, &inputBMemRef, &outputScalar);
277+
_mlir_ciface_batch_matmul_broadcast_64(&inputAMemRef, &inputBMemRef,
278+
&outputBroadcast64);
279+
280+
// Get the result array.
281+
auto resultScalar = outputScalar.getData();
282+
auto resultBroadcast16 = outputBroadcast64.getData();
283+
284+
// Print the verfication result.
285+
std::cout << "Batch Broadcast 64 case: "
286+
<< (areArraysEqual(resultScalar, resultBroadcast16,
287+
outputSize / BATCH)
288+
? PASS
289+
: FAIL)
290+
<< std::endl;
209291
std::cout << "-----------------------------------------------------------"
210292
<< std::endl;
211293
}

0 commit comments

Comments
 (0)