18
18
//
19
19
// ===----------------------------------------------------------------------===//
20
20
21
#include <array>
#include <benchmark/benchmark.h>
#include <buddy/Core/Container.h>
#include <iostream>
#include <vector>
27
28
#define M 64
28
29
#define N 3136
29
30
#define K 576
31
+ #define BATCH_M 16
32
+ #define BATCH_N 784
33
+ #define BATCH_K 144
34
+ #define BATCH 64
30
35
31
36
// Helper functions and variables.
32
37
namespace {
@@ -62,6 +67,11 @@ void _mlir_ciface_matmul_broadcast_256(MemRef<float, 2> *A, MemRef<float, 2> *B,
62
67
MemRef<float , 2 > *C);
63
68
void _mlir_ciface_matmul_scalar (MemRef<float , 2 > *A, MemRef<float , 2 > *B,
64
69
MemRef<float , 2 > *C);
70
+ void _mlir_ciface_batch_matmul_scalar (MemRef<float , 3 > *A, MemRef<float , 3 > *B,
71
+ MemRef<float , 3 > *C);
72
+ void _mlir_ciface_batch_matmul_broadcast_64 (MemRef<float , 3 > *A,
73
+ MemRef<float , 3 > *B,
74
+ MemRef<float , 3 > *C);
65
75
}
66
76
67
77
#define DEFINE_MATMUL_BENCHMARK (name, func ) \
@@ -79,6 +89,21 @@ void _mlir_ciface_matmul_scalar(MemRef<float, 2> *A, MemRef<float, 2> *B,
79
89
} \
80
90
}
81
91
92
// Generates a Google Benchmark function named BM_BATCH_MATMUL_<name> that
// times `func` on rank-3 float memrefs:
//   lhs: BATCH x BATCH_M x BATCH_K, every element initialised to 1.0
//   rhs: BATCH x BATCH_K x BATCH_N, every element initialised to 1.0
//   out: BATCH x BATCH_M x BATCH_N, every element initialised to 0
// The containers are built once, outside the timed loop; each benchmark
// iteration performs exactly one call `func(&lhs, &rhs, &out)`.
#define DEFINE_BATCH_MATMUL_BENCHMARK(name, func)                              \
  void BM_BATCH_MATMUL_##name(benchmark::State &state) {                       \
    intptr_t lhsShape[3] = {BATCH, BATCH_M, BATCH_K};                          \
    intptr_t rhsShape[3] = {BATCH, BATCH_K, BATCH_N};                          \
    intptr_t outShape[3] = {BATCH, BATCH_M, BATCH_N};                          \
                                                                               \
    MemRef<float, 3> lhs(lhsShape, 1.0);                                       \
    MemRef<float, 3> rhs(rhsShape, 1.0);                                       \
    MemRef<float, 3> out(outShape, 0);                                         \
                                                                               \
    for (auto _ : state) {                                                     \
      func(&lhs, &rhs, &out);                                                  \
    }                                                                          \
  }
106
+
82
107
DEFINE_MATMUL_BENCHMARK (OCV, _mlir_ciface_matmul_ocv)
83
108
DEFINE_MATMUL_BENCHMARK (TRANSFORM, _mlir_ciface_matmul_transform)
84
109
DEFINE_MATMUL_BENCHMARK (BROADCAST_16, _mlir_ciface_matmul_broadcast_16)
@@ -87,6 +112,9 @@ DEFINE_MATMUL_BENCHMARK(BROADCAST_64, _mlir_ciface_matmul_broadcast_64)
87
112
DEFINE_MATMUL_BENCHMARK (BROADCAST_128, _mlir_ciface_matmul_broadcast_128)
88
113
DEFINE_MATMUL_BENCHMARK (BROADCAST_256, _mlir_ciface_matmul_broadcast_256)
89
114
DEFINE_MATMUL_BENCHMARK (SCALAR, _mlir_ciface_matmul_scalar)
115
+ DEFINE_BATCH_MATMUL_BENCHMARK (SCALAR, _mlir_ciface_batch_matmul_scalar)
116
+ DEFINE_BATCH_MATMUL_BENCHMARK (BROADCAST_64,
117
+ _mlir_ciface_batch_matmul_broadcast_64)
90
118
} // namespace
91
119
92
120
// Register benchmark cases.
@@ -98,15 +126,18 @@ BENCHMARK(BM_MATMUL_BROADCAST_32)->Unit(benchmark::kMillisecond);
98
126
// Each case is registered exactly once and reported in milliseconds.
// NOTE(review): BM_MATMUL_BROADCAST_256 used to be registered twice here, which
// makes Google Benchmark run and report that case twice. The duplicate may
// have been meant to register BM_MATMUL_SCALAR instead; no registration for
// it is visible in this part of the file, so confirm before adding one.
BENCHMARK(BM_MATMUL_BROADCAST_64)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_MATMUL_BROADCAST_128)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_MATMUL_BROADCAST_256)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_BATCH_MATMUL_SCALAR)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_BATCH_MATMUL_BROADCAST_64)->Unit(benchmark::kMillisecond);
101
132
102
- // / Correctness Verification
103
- // / The verification does not affect the performance.
104
- // / - Set the scalar case as the criteria.
105
- // / - Input elements are random numbers.
106
- // / - Output elements are initialized to zero.
107
- // / - Compare the output of various optimizations with the scalar version to
108
- // / verify correctness.
109
- void verification () {
133
+ // Correctness Verification
134
+ // The verification does not affect the performance.
135
+ // - Set the scalar case as the criteria.
136
+ // - Input elements are random numbers.
137
+ // - Output elements are initialized to zero.
138
+ // - Compare the output of various optimizations with the scalar version to
139
+ // verify correctness.
140
+ void matmul_verification () {
110
141
// Set the random number generator.
111
142
std::random_device rd;
112
143
std::mt19937 generator (rd ());
@@ -206,6 +237,60 @@ void verification() {
206
237
? PASS
207
238
: FAIL)
208
239
<< std::endl;
240
+
241
+ std::cout << " -----------------------------------------------------------"
242
+ << std::endl;
243
+ }
244
+
245
+ void batch_matmul_verification () {
246
+ // Set the random number generator.
247
+ std::random_device rd;
248
+ std::mt19937 generator (rd ());
249
+ std::uniform_int_distribution<int > distribution (1 , 100 );
250
+
251
+ // Set the layout sizes of input and output memref container.
252
+ intptr_t sizesA[3 ] = {BATCH, BATCH_M, BATCH_K};
253
+ intptr_t sizesB[3 ] = {BATCH, BATCH_K, BATCH_N};
254
+ intptr_t sizesC[3 ] = {BATCH, BATCH_M, BATCH_N};
255
+
256
+ // Generate input A and input B memref container with random numbers.
257
+ const int inputASize = BATCH * (BATCH_M) * (BATCH_K);
258
+ // float inputARand[inputASize];
259
+ auto inputARand = new std::array<float , inputASize>();
260
+ for (int i = 0 ; i < inputASize; ++i) {
261
+ (*inputARand)[i] = distribution (generator);
262
+ }
263
+ MemRef<float , 3 > inputAMemRef (inputARand->data (), sizesA);
264
+
265
+ const int inputBSize = BATCH * (BATCH_K) * (BATCH_N);
266
+ // float inputBRand[inputBSize];
267
+ auto inputBRand = new std::array<float , inputBSize>();
268
+ for (int i = 0 ; i < inputBSize; ++i) {
269
+ (*inputBRand)[i] = distribution (generator);
270
+ }
271
+ MemRef<float , 3 > inputBMemRef (inputBRand->data (), sizesB);
272
+
273
+ // Generate output memref container with zero.
274
+ const int outputSize = BATCH * (BATCH_M) * (BATCH_N);
275
+ MemRef<float , 3 > outputScalar (sizesC, 0 );
276
+ MemRef<float , 3 > outputBroadcast64 (sizesC, 0 );
277
+
278
+ // Perform all the matmul implementation.
279
+ _mlir_ciface_batch_matmul_scalar (&inputAMemRef, &inputBMemRef, &outputScalar);
280
+ _mlir_ciface_batch_matmul_broadcast_64 (&inputAMemRef, &inputBMemRef,
281
+ &outputBroadcast64);
282
+
283
+ // Get the result array.
284
+ auto resultScalar = outputScalar.getData ();
285
+ auto resultBroadcast16 = outputBroadcast64.getData ();
286
+
287
+ // Print the verfication result.
288
+ std::cout << " Batch Matmul Broadcast 64 case: "
289
+ << (areArraysEqual (resultScalar, resultBroadcast16,
290
+ outputSize / BATCH)
291
+ ? PASS
292
+ : FAIL)
293
+ << std::endl;
209
294
std::cout << " -----------------------------------------------------------"
210
295
<< std::endl;
211
296
}
0 commit comments