18
18
//
19
19
// ===----------------------------------------------------------------------===//
20
20
21
+ #include < array>
21
22
#include < benchmark/benchmark.h>
22
23
#include < buddy/Core/Container.h>
23
24
#include < iostream>
27
28
#define M 64
28
29
#define N 3136
29
30
#define K 576
31
// Problem size of the batch matmul cases. The benchmark and verification code
// build the operand shapes from these constants:
//   A: BATCH x BATCH_M x BATCH_K
//   B: BATCH x BATCH_K x BATCH_N
//   C: BATCH x BATCH_M x BATCH_N
+ #define BATCH_M 16
32
+ #define BATCH_N 784
33
+ #define BATCH_K 144
34
+ #define BATCH 64
30
35
31
36
// Helper functions and variables.
32
37
namespace {
@@ -62,6 +67,11 @@ void _mlir_ciface_matmul_broadcast_256(MemRef<float, 2> *A, MemRef<float, 2> *B,
62
67
MemRef<float , 2 > *C);
63
68
void _mlir_ciface_matmul_scalar (MemRef<float , 2 > *A, MemRef<float , 2 > *B,
64
69
MemRef<float , 2 > *C);
70
+ void _mlir_ciface_batch_matmul_scalar (MemRef<float , 3 > *A, MemRef<float , 3 > *B,
71
+ MemRef<float , 3 > *C);
72
+ void _mlir_ciface_batch_matmul_broadcast_64 (MemRef<float , 3 > *A,
73
+ MemRef<float , 3 > *B,
74
+ MemRef<float , 3 > *C);
65
75
}
66
76
67
77
#define DEFINE_MATMUL_BENCHMARK (name, func ) \
@@ -79,6 +89,21 @@ void _mlir_ciface_matmul_scalar(MemRef<float, 2> *A, MemRef<float, 2> *B,
79
89
} \
80
90
}
81
91
92
// Generates a Google Benchmark function named BBATCH_M_MATMUL_<name> that
// times `func` on rank-3 float MemRefs:
//   - left operand:  BATCH x BATCH_M x BATCH_K, constructed with 1.0
//   - right operand: BATCH x BATCH_K x BATCH_N, constructed with 1.0
//   - result:        BATCH x BATCH_M x BATCH_N, constructed with 0
// The containers are built once, outside the timed loop; each benchmark
// iteration performs exactly one call of `func`.
// NOTE(review): the `BBATCH_M_MATMUL_` prefix looks like an accidental text
// substitution. It is kept because the BENCHMARK registrations refer to the
// generated names; rename both together if desired.
#define DEFINE_BATCH_MATMUL_BENCHMARK(name, func)                              \
  void BBATCH_M_MATMUL_##name(benchmark::State &state) {                       \
    intptr_t lhsShape[3] = {BATCH, BATCH_M, BATCH_K};                          \
    intptr_t rhsShape[3] = {BATCH, BATCH_K, BATCH_N};                          \
    intptr_t outShape[3] = {BATCH, BATCH_M, BATCH_N};                          \
    MemRef<float, 3> lhs(lhsShape, 1.0);                                       \
    MemRef<float, 3> rhs(rhsShape, 1.0);                                       \
    MemRef<float, 3> out(outShape, 0);                                         \
    for (auto _ : state)                                                       \
      func(&lhs, &rhs, &out);                                                  \
  }
106
+
82
107
DEFINE_MATMUL_BENCHMARK (OCV, _mlir_ciface_matmul_ocv)
83
108
DEFINE_MATMUL_BENCHMARK (TRANSFORM, _mlir_ciface_matmul_transform)
84
109
DEFINE_MATMUL_BENCHMARK (BROADCAST_16, _mlir_ciface_matmul_broadcast_16)
@@ -87,6 +112,9 @@ DEFINE_MATMUL_BENCHMARK(BROADCAST_64, _mlir_ciface_matmul_broadcast_64)
87
112
DEFINE_MATMUL_BENCHMARK (BROADCAST_128, _mlir_ciface_matmul_broadcast_128)
88
113
DEFINE_MATMUL_BENCHMARK (BROADCAST_256, _mlir_ciface_matmul_broadcast_256)
89
114
DEFINE_MATMUL_BENCHMARK (SCALAR, _mlir_ciface_matmul_scalar)
115
+ DEFINE_BATCH_MATMUL_BENCHMARK (SCALAR, _mlir_ciface_batch_matmul_scalar) // batch_matmul
116
+ DEFINE_BATCH_MATMUL_BENCHMARK (BROADCAST_64,
117
+ _mlir_ciface_batch_matmul_broadcast_64) // batch_matmul
90
118
} // namespace
91
119
92
120
// Register benchmark cases.
@@ -98,15 +126,18 @@ BENCHMARK(BM_MATMUL_BROADCAST_32)->Unit(benchmark::kMillisecond);
98
126
BENCHMARK(BM_MATMUL_BROADCAST_64)->Unit(benchmark::kMillisecond);
BENCHMARK(BM_MATMUL_BROADCAST_128)->Unit(benchmark::kMillisecond);
// Registered exactly once: a second registration of the same function would
// make the case run and be reported twice.
BENCHMARK(BM_MATMUL_BROADCAST_256)->Unit(benchmark::kMillisecond);
// Batch matmul cases.
BENCHMARK(BBATCH_M_MATMUL_SCALAR)->Unit(benchmark::kMillisecond);
BENCHMARK(BBATCH_M_MATMUL_BROADCAST_64)->Unit(benchmark::kMillisecond);
101
132
102
- // / Correctness Verification
103
- // / The verification does not affect the performance.
104
- // / - Set the scalar case as the criteria.
105
- // / - Input elements are random numbers.
106
- // / - Output elements are initialized to zero.
107
- // / - Compare the output of various optimizations with the scalar version to
108
- // / verify correctness.
109
- void verification () {
133
+ // Correctness Verification
134
+ // The verification does not affect the performance.
135
+ // - Set the scalar case as the criteria.
136
+ // - Input elements are random numbers.
137
+ // - Output elements are initialized to zero.
138
+ // - Compare the output of various optimizations with the scalar version to
139
+ // verify correctness.
140
+ void matmul_verification () {
110
141
// Set the random number generator.
111
142
std::random_device rd;
112
143
std::mt19937 generator (rd ());
@@ -206,6 +237,57 @@ void verification() {
206
237
? PASS
207
238
: FAIL)
208
239
<< std::endl;
240
+ }
241
+
242
+ void batch_matmul_verification () {
243
+ // Set the random number generator.
244
+ std::random_device rd;
245
+ std::mt19937 generator (rd ());
246
+ std::uniform_int_distribution<int > distribution (1 , 100 );
247
+
248
+ // Set the layout sizes of input and output memref container.
249
+ intptr_t sizesA[3 ] = {BATCH, BATCH_M, BATCH_K};
250
+ intptr_t sizesB[3 ] = {BATCH, BATCH_K, BATCH_N};
251
+ intptr_t sizesC[3 ] = {BATCH, BATCH_M, BATCH_N};
252
+
253
+ // Generate input A and input B memref container with random numbers.
254
+ const int inputASize = BATCH * (BATCH_M) * (BATCH_K);
255
+ // float inputARand[inputASize];
256
+ auto inputARand = new std::array<float , inputASize>();
257
+ for (int i = 0 ; i < inputASize; ++i) {
258
+ (*inputARand)[i] = distribution (generator);
259
+ }
260
+ MemRef<float , 3 > inputAMemRef (inputARand->data (), sizesA);
261
+
262
+ const int inputBSize = BATCH * (BATCH_K) * (BATCH_N);
263
+ // float inputBRand[inputBSize];
264
+ auto inputBRand = new std::array<float , inputBSize>();
265
+ for (int i = 0 ; i < inputBSize; ++i) {
266
+ (*inputBRand)[i] = distribution (generator);
267
+ }
268
+ MemRef<float , 3 > inputBMemRef (inputBRand->data (), sizesB);
269
+
270
+ // Generate output memref container with zero.
271
+ const int outputSize = BATCH * (BATCH_M) * (BATCH_N);
272
+ MemRef<float , 3 > outputScalar (sizesC, 0 );
273
+ MemRef<float , 3 > outputBroadcast64 (sizesC, 0 );
274
+
275
+ // Perform all the matmul implementation.
276
+ _mlir_ciface_batch_matmul_scalar (&inputAMemRef, &inputBMemRef, &outputScalar);
277
+ _mlir_ciface_batch_matmul_broadcast_64 (&inputAMemRef, &inputBMemRef,
278
+ &outputBroadcast64);
279
+
280
+ // Get the result array.
281
+ auto resultScalar = outputScalar.getData ();
282
+ auto resultBroadcast16 = outputBroadcast64.getData ();
283
+
284
+ // Print the verfication result.
285
+ std::cout << " Batch Broadcast 64 case: "
286
+ << (areArraysEqual (resultScalar, resultBroadcast16,
287
+ outputSize / BATCH)
288
+ ? PASS
289
+ : FAIL)
290
+ << std::endl;
209
291
std::cout << " -----------------------------------------------------------"
210
292
<< std::endl;
211
293
}
0 commit comments