Skip to content

Commit eb1372f

Browse files
committed
add maxflops for gmma
1 parent 86163a7 commit eb1372f

File tree

133 files changed

+1229
-10
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

133 files changed

+1229
-10
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Source files split for parallel compilation
2+
# Use wildcard to automatically include all size-specific breakdown files
3+
SRC = MaxFlops_gmma.cu $(wildcard kernels/MaxFlops_gmma_*.cu)
4+
5+
EXE = MaxFlops_gmma
6+
7+
# Add include path for CUTLASS
8+
INCLUDE += -I$(GPUAPPS_ROOT)/src/cuda/cutlass-bench/include -I./
9+
10+
# GMMA is only supported in sm_90a
11+
ARCH?=sm_90a
12+
# Clear CUDA_CPPFLAGS (otherwise set based on the CUDA version); the -gencode flags are passed via HOPPER_CUDA_CPPFLAGS below
13+
CUDA_CPPFLAGS=
14+
HOPPER_CUDA_CPPFLAGS=$(foreach arch,$(ARCH),-gencode=arch=compute_$(subst sm_,,$(arch)),code=$(arch))
15+
16+
# CUTLASS cute library requires C++17
17+
NVCC_FLAGS := $(HOPPER_CUDA_CPPFLAGS) -std=c++17
18+
19+
include ../../../common/common.mk
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include <cuda.h>
2+
#include "MaxFlops_gmma.h"
3+
#include "../../../hw_def/hw_def.h"
4+
5+
// Program entry point.
//
// Calls intilizeDeviceProp with 0 and the unmodified command-line arguments,
// then runs the complete WGMMA max-flops sweep. Always returns 0.
int main(int argc, char *argv[])
{
    // NOTE(review): the leading 0 is presumably the device id -- confirm
    // against the helper's declaration in hw_def.h.
    intilizeDeviceProp(0, argc, argv);

    // Sweep all MMA configurations exposed by MaxFlops_gmma.h.
    run_all_wgmma_maxflops_tests();

    return 0;
}
Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
#ifndef MAXFLOPS_GMMA_DEF_H
2+
#define MAXFLOPS_GMMA_DEF_H
3+
4+
#include <cuda.h>
5+
#include <stdio.h>
6+
#include <stdlib.h>
7+
#include <cstdint>
8+
9+
10+
// Function declarations for test suites
11+
// These are defined in separate .cu files for parallel compilation
12+
13+
// F32 accumulator tests - TF32 x TF32 -> F32
14+
void run_f32tf32tf32tf32_64x8x8_test();
15+
void run_f32tf32tf32tf32_64x16x8_test();
16+
void run_f32tf32tf32tf32_64x32x8_test();
17+
void run_f32tf32tf32tf32_64x64x8_test();
18+
void run_f32tf32tf32tf32_64x96x8_test();
19+
void run_f32tf32tf32tf32_64x128x8_test();
20+
void run_f32tf32tf32tf32_64x192x8_test();
21+
void run_f32tf32tf32tf32_64x256x8_test();
22+
23+
// F32 accumulator tests - E4M3 x E4M3 -> F32
24+
void run_f32e4m3e4m3e4m3_64x8x32_test();
25+
void run_f32e4m3e4m3e4m3_64x16x32_test();
26+
void run_f32e4m3e4m3e4m3_64x32x32_test();
27+
void run_f32e4m3e4m3e4m3_64x64x32_test();
28+
void run_f32e4m3e4m3e4m3_64x96x32_test();
29+
void run_f32e4m3e4m3e4m3_64x128x32_test();
30+
void run_f32e4m3e4m3e4m3_64x192x32_test();
31+
void run_f32e4m3e4m3e4m3_64x256x32_test();
32+
33+
// F32 accumulator tests - E4M3 x E5M2 -> F32
34+
void run_f32e4m3e5m2e4m3_64x8x32_test();
35+
void run_f32e4m3e5m2e4m3_64x16x32_test();
36+
void run_f32e4m3e5m2e4m3_64x32x32_test();
37+
void run_f32e4m3e5m2e4m3_64x64x32_test();
38+
void run_f32e4m3e5m2e4m3_64x96x32_test();
39+
void run_f32e4m3e5m2e4m3_64x128x32_test();
40+
void run_f32e4m3e5m2e4m3_64x192x32_test();
41+
void run_f32e4m3e5m2e4m3_64x256x32_test();
42+
43+
// F32 accumulator tests - E5M2 x E4M3 -> F32
44+
void run_f32e5m2e4m3e5m2_64x8x32_test();
45+
void run_f32e5m2e4m3e5m2_64x16x32_test();
46+
void run_f32e5m2e4m3e5m2_64x32x32_test();
47+
void run_f32e5m2e4m3e5m2_64x64x32_test();
48+
void run_f32e5m2e4m3e5m2_64x96x32_test();
49+
void run_f32e5m2e4m3e5m2_64x128x32_test();
50+
void run_f32e5m2e4m3e5m2_64x192x32_test();
51+
void run_f32e5m2e4m3e5m2_64x256x32_test();
52+
53+
// F32 accumulator tests - E5M2 x E5M2 -> F32
54+
void run_f32e5m2e5m2e5m2_64x8x32_test();
55+
void run_f32e5m2e5m2e5m2_64x16x32_test();
56+
void run_f32e5m2e5m2e5m2_64x32x32_test();
57+
void run_f32e5m2e5m2e5m2_64x64x32_test();
58+
void run_f32e5m2e5m2e5m2_64x96x32_test();
59+
void run_f32e5m2e5m2e5m2_64x128x32_test();
60+
void run_f32e5m2e5m2e5m2_64x192x32_test();
61+
void run_f32e5m2e5m2e5m2_64x256x32_test();
62+
63+
// INT32 accumulator tests - INT8 x INT8 -> INT32
64+
void run_int32s8s8s8_64x8x32_test();
65+
void run_int32s8s8s8_64x16x32_test();
66+
void run_int32s8s8s8_64x32x32_test();
67+
void run_int32s8s8s8_64x64x32_test();
68+
void run_int32s8s8s8_64x96x32_test();
69+
void run_int32s8s8s8_64x128x32_test();
70+
void run_int32s8s8s8_64x192x32_test();
71+
void run_int32s8s8s8_64x256x32_test();
72+
73+
// INT32 accumulator tests - INT8 x UINT8 -> INT32
74+
void run_int32s8u8s8_64x8x32_test();
75+
void run_int32s8u8s8_64x16x32_test();
76+
void run_int32s8u8s8_64x32x32_test();
77+
void run_int32s8u8s8_64x64x32_test();
78+
void run_int32s8u8s8_64x96x32_test();
79+
void run_int32s8u8s8_64x128x32_test();
80+
void run_int32s8u8s8_64x192x32_test();
81+
void run_int32s8u8s8_64x256x32_test();
82+
83+
// INT32 accumulator tests - UINT8 x INT8 -> INT32
84+
void run_int32u8s8u8_64x8x32_test();
85+
void run_int32u8s8u8_64x16x32_test();
86+
void run_int32u8s8u8_64x32x32_test();
87+
void run_int32u8s8u8_64x64x32_test();
88+
void run_int32u8s8u8_64x96x32_test();
89+
void run_int32u8s8u8_64x128x32_test();
90+
void run_int32u8s8u8_64x192x32_test();
91+
void run_int32u8s8u8_64x256x32_test();
92+
93+
// INT32 accumulator tests - UINT8 x UINT8 -> INT32
94+
void run_int32u8u8u8_64x8x32_test();
95+
void run_int32u8u8u8_64x16x32_test();
96+
void run_int32u8u8u8_64x32x32_test();
97+
void run_int32u8u8u8_64x64x32_test();
98+
void run_int32u8u8u8_64x96x32_test();
99+
void run_int32u8u8u8_64x128x32_test();
100+
void run_int32u8u8u8_64x192x32_test();
101+
void run_int32u8u8u8_64x256x32_test();
102+
103+
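// FP16/BF16 input tests (F32 and F16 accumulators). NOTE(review): this was labeled "F16 accumulator tests (defined in lat_gmma_f16.cu)"; that file name looks carried over from another benchmark -- confirm the defining file.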
104+
// F32 accumulator tests - FP16 x FP16 -> F32
105+
void run_f32f16f16_64x8x16_test();
106+
void run_f32f16f16_64x16x16_test();
107+
void run_f32f16f16_64x32x16_test();
108+
void run_f32f16f16_64x64x16_test();
109+
void run_f32f16f16_64x96x16_test();
110+
void run_f32f16f16_64x128x16_test();
111+
void run_f32f16f16_64x192x16_test();
112+
void run_f32f16f16_64x256x16_test();
113+
114+
// F32 accumulator tests - BF16 x BF16 -> F32
115+
void run_f32bf16bf16_64x8x16_test();
116+
void run_f32bf16bf16_64x16x16_test();
117+
void run_f32bf16bf16_64x32x16_test();
118+
void run_f32bf16bf16_64x64x16_test();
119+
void run_f32bf16bf16_64x96x16_test();
120+
void run_f32bf16bf16_64x128x16_test();
121+
void run_f32bf16bf16_64x192x16_test();
122+
void run_f32bf16bf16_64x256x16_test();
123+
124+
// F16 accumulator tests - FP16 x FP16 -> F16
125+
void run_f16f16f16_64x8x16_test();
126+
void run_f16f16f16_64x16x16_test();
127+
void run_f16f16f16_64x32x16_test();
128+
void run_f16f16f16_64x64x16_test();
129+
void run_f16f16f16_64x96x16_test();
130+
void run_f16f16f16_64x128x16_test();
131+
void run_f16f16f16_64x192x16_test();
132+
void run_f16f16f16_64x256x16_test();
133+
134+
// F16 accumulator tests - E4M3 x E4M3 -> F16
135+
void run_f16e4m3e4m3_64x8x32_test();
136+
void run_f16e4m3e4m3_64x16x32_test();
137+
void run_f16e4m3e4m3_64x32x32_test();
138+
void run_f16e4m3e4m3_64x64x32_test();
139+
void run_f16e4m3e4m3_64x96x32_test();
140+
void run_f16e4m3e4m3_64x128x32_test();
141+
void run_f16e4m3e4m3_64x192x32_test();
142+
void run_f16e4m3e4m3_64x256x32_test();
143+
144+
// F16 accumulator tests - E4M3 x E5M2 -> F16
145+
void run_f16e4m3e5m2_64x8x32_test();
146+
void run_f16e4m3e5m2_64x16x32_test();
147+
void run_f16e4m3e5m2_64x32x32_test();
148+
void run_f16e4m3e5m2_64x64x32_test();
149+
void run_f16e4m3e5m2_64x96x32_test();
150+
void run_f16e4m3e5m2_64x128x32_test();
151+
void run_f16e4m3e5m2_64x192x32_test();
152+
void run_f16e4m3e5m2_64x256x32_test();
153+
154+
// F16 accumulator tests - E5M2 x E4M3 -> F16
155+
void run_f16e5m2e4m3_64x8x32_test();
156+
void run_f16e5m2e4m3_64x16x32_test();
157+
void run_f16e5m2e4m3_64x32x32_test();
158+
void run_f16e5m2e4m3_64x64x32_test();
159+
void run_f16e5m2e4m3_64x96x32_test();
160+
void run_f16e5m2e4m3_64x128x32_test();
161+
void run_f16e5m2e4m3_64x192x32_test();
162+
void run_f16e5m2e4m3_64x256x32_test();
163+
164+
// F16 accumulator tests - E5M2 x E5M2 -> F16
165+
void run_f16e5m2e5m2_64x8x32_test();
166+
void run_f16e5m2e5m2_64x16x32_test();
167+
void run_f16e5m2e5m2_64x32x32_test();
168+
void run_f16e5m2e5m2_64x64x32_test();
169+
void run_f16e5m2e5m2_64x96x32_test();
170+
void run_f16e5m2e5m2_64x128x32_test();
171+
void run_f16e5m2e5m2_64x192x32_test();
172+
void run_f16e5m2e5m2_64x256x32_test();
173+
174+
// Runs every F16-accumulator test declared in this header, in a fixed order:
// the FP16 x FP16 group first, then the four FP8 input pairings. Each group
// walks the shapes from 64x8 up to 64x256.
//
// Declared inline because the definition lives in a header: a non-inline
// definition would be emitted in every translation unit that includes this
// file and fail at link time with multiple-definition errors.
inline void run_f16accumulator_tests() {
    // FP16 x FP16 -> F16
    run_f16f16f16_64x8x16_test();
    run_f16f16f16_64x16x16_test();
    run_f16f16f16_64x32x16_test();
    run_f16f16f16_64x64x16_test();
    run_f16f16f16_64x96x16_test();
    run_f16f16f16_64x128x16_test();
    run_f16f16f16_64x192x16_test();
    run_f16f16f16_64x256x16_test();

    // E4M3 x E4M3 -> F16
    run_f16e4m3e4m3_64x8x32_test();
    run_f16e4m3e4m3_64x16x32_test();
    run_f16e4m3e4m3_64x32x32_test();
    run_f16e4m3e4m3_64x64x32_test();
    run_f16e4m3e4m3_64x96x32_test();
    run_f16e4m3e4m3_64x128x32_test();
    run_f16e4m3e4m3_64x192x32_test();
    run_f16e4m3e4m3_64x256x32_test();

    // E4M3 x E5M2 -> F16
    run_f16e4m3e5m2_64x8x32_test();
    run_f16e4m3e5m2_64x16x32_test();
    run_f16e4m3e5m2_64x32x32_test();
    run_f16e4m3e5m2_64x64x32_test();
    run_f16e4m3e5m2_64x96x32_test();
    run_f16e4m3e5m2_64x128x32_test();
    run_f16e4m3e5m2_64x192x32_test();
    run_f16e4m3e5m2_64x256x32_test();

    // E5M2 x E4M3 -> F16
    run_f16e5m2e4m3_64x8x32_test();
    run_f16e5m2e4m3_64x16x32_test();
    run_f16e5m2e4m3_64x32x32_test();
    run_f16e5m2e4m3_64x64x32_test();
    run_f16e5m2e4m3_64x96x32_test();
    run_f16e5m2e4m3_64x128x32_test();
    run_f16e5m2e4m3_64x192x32_test();
    run_f16e5m2e4m3_64x256x32_test();

    // E5M2 x E5M2 -> F16
    run_f16e5m2e5m2_64x8x32_test();
    run_f16e5m2e5m2_64x16x32_test();
    run_f16e5m2e5m2_64x32x32_test();
    run_f16e5m2e5m2_64x64x32_test();
    run_f16e5m2e5m2_64x96x32_test();
    run_f16e5m2e5m2_64x128x32_test();
    run_f16e5m2e5m2_64x192x32_test();
    run_f16e5m2e5m2_64x256x32_test();
}
216+
217+
// Runs every F32-accumulator test declared in this header, in a fixed order:
// TF32, FP16, BF16, then the four FP8 input pairings. Each group walks the
// shapes from 64x8 up to 64x256.
//
// Declared inline because the definition lives in a header: a non-inline
// definition would be emitted in every translation unit that includes this
// file and fail at link time with multiple-definition errors.
inline void run_f32accumulator_tests() {
    // TF32 x TF32 -> F32
    run_f32tf32tf32tf32_64x8x8_test();
    run_f32tf32tf32tf32_64x16x8_test();
    run_f32tf32tf32tf32_64x32x8_test();
    run_f32tf32tf32tf32_64x64x8_test();
    run_f32tf32tf32tf32_64x96x8_test();
    run_f32tf32tf32tf32_64x128x8_test();
    run_f32tf32tf32tf32_64x192x8_test();
    run_f32tf32tf32tf32_64x256x8_test();

    // FP16 x FP16 -> F32
    run_f32f16f16_64x8x16_test();
    run_f32f16f16_64x16x16_test();
    run_f32f16f16_64x32x16_test();
    run_f32f16f16_64x64x16_test();
    run_f32f16f16_64x96x16_test();
    run_f32f16f16_64x128x16_test();
    run_f32f16f16_64x192x16_test();
    run_f32f16f16_64x256x16_test();

    // BF16 x BF16 -> F32
    run_f32bf16bf16_64x8x16_test();
    run_f32bf16bf16_64x16x16_test();
    run_f32bf16bf16_64x32x16_test();
    run_f32bf16bf16_64x64x16_test();
    run_f32bf16bf16_64x96x16_test();
    run_f32bf16bf16_64x128x16_test();
    run_f32bf16bf16_64x192x16_test();
    run_f32bf16bf16_64x256x16_test();

    // E4M3 x E4M3 -> F32
    run_f32e4m3e4m3e4m3_64x8x32_test();
    run_f32e4m3e4m3e4m3_64x16x32_test();
    run_f32e4m3e4m3e4m3_64x32x32_test();
    run_f32e4m3e4m3e4m3_64x64x32_test();
    run_f32e4m3e4m3e4m3_64x96x32_test();
    run_f32e4m3e4m3e4m3_64x128x32_test();
    run_f32e4m3e4m3e4m3_64x192x32_test();
    run_f32e4m3e4m3e4m3_64x256x32_test();

    // E4M3 x E5M2 -> F32
    run_f32e4m3e5m2e4m3_64x8x32_test();
    run_f32e4m3e5m2e4m3_64x16x32_test();
    run_f32e4m3e5m2e4m3_64x32x32_test();
    run_f32e4m3e5m2e4m3_64x64x32_test();
    run_f32e4m3e5m2e4m3_64x96x32_test();
    run_f32e4m3e5m2e4m3_64x128x32_test();
    run_f32e4m3e5m2e4m3_64x192x32_test();
    run_f32e4m3e5m2e4m3_64x256x32_test();

    // E5M2 x E4M3 -> F32
    run_f32e5m2e4m3e5m2_64x8x32_test();
    run_f32e5m2e4m3e5m2_64x16x32_test();
    run_f32e5m2e4m3e5m2_64x32x32_test();
    run_f32e5m2e4m3e5m2_64x64x32_test();
    run_f32e5m2e4m3e5m2_64x96x32_test();
    run_f32e5m2e4m3e5m2_64x128x32_test();
    run_f32e5m2e4m3e5m2_64x192x32_test();
    run_f32e5m2e4m3e5m2_64x256x32_test();

    // E5M2 x E5M2 -> F32
    run_f32e5m2e5m2e5m2_64x8x32_test();
    run_f32e5m2e5m2e5m2_64x16x32_test();
    run_f32e5m2e5m2e5m2_64x32x32_test();
    run_f32e5m2e5m2e5m2_64x64x32_test();
    run_f32e5m2e5m2e5m2_64x96x32_test();
    run_f32e5m2e5m2e5m2_64x128x32_test();
    run_f32e5m2e5m2e5m2_64x192x32_test();
    run_f32e5m2e5m2e5m2_64x256x32_test();
}
275+
276+
// Runs every INT32-accumulator test declared in this header, in a fixed
// order: the four signed/unsigned 8-bit input pairings, each walking the
// shapes from 64x8 up to 64x256.
//
// Declared inline because the definition lives in a header: a non-inline
// definition would be emitted in every translation unit that includes this
// file and fail at link time with multiple-definition errors.
inline void run_int32accumulator_tests() {
    // INT8 x INT8 -> INT32
    run_int32s8s8s8_64x8x32_test();
    run_int32s8s8s8_64x16x32_test();
    run_int32s8s8s8_64x32x32_test();
    run_int32s8s8s8_64x64x32_test();
    run_int32s8s8s8_64x96x32_test();
    run_int32s8s8s8_64x128x32_test();
    run_int32s8s8s8_64x192x32_test();
    run_int32s8s8s8_64x256x32_test();

    // INT8 x UINT8 -> INT32
    run_int32s8u8s8_64x8x32_test();
    run_int32s8u8s8_64x16x32_test();
    run_int32s8u8s8_64x32x32_test();
    run_int32s8u8s8_64x64x32_test();
    run_int32s8u8s8_64x96x32_test();
    run_int32s8u8s8_64x128x32_test();
    run_int32s8u8s8_64x192x32_test();
    run_int32s8u8s8_64x256x32_test();

    // UINT8 x INT8 -> INT32
    run_int32u8s8u8_64x8x32_test();
    run_int32u8s8u8_64x16x32_test();
    run_int32u8s8u8_64x32x32_test();
    run_int32u8s8u8_64x64x32_test();
    run_int32u8s8u8_64x96x32_test();
    run_int32u8s8u8_64x128x32_test();
    run_int32u8s8u8_64x192x32_test();
    run_int32u8s8u8_64x256x32_test();

    // UINT8 x UINT8 -> INT32
    run_int32u8u8u8_64x8x32_test();
    run_int32u8u8u8_64x16x32_test();
    run_int32u8u8u8_64x32x32_test();
    run_int32u8u8u8_64x64x32_test();
    run_int32u8u8u8_64x96x32_test();
    run_int32u8u8u8_64x128x32_test();
    run_int32u8u8u8_64x192x32_test();
    run_int32u8u8u8_64x256x32_test();
}
310+
311+
// ============================================================================
312+
// Main Test Function - Run All Configurations
313+
// ============================================================================
314+
315+
// Top-level driver for the sweep: prints an opening banner to stdout, runs
// the F32, F16 and INT32 accumulator suites in that order, then prints a
// closing banner.
inline void run_all_wgmma_maxflops_tests() {
    // Opening banner.
    printf("\n");
    printf("================================================================================\n");
    printf(" SM90 GMMA Max Flops Comprehensive Sweep\n");
    printf("================================================================================\n");
    printf("\n");

    // Suites run back to back; the order fixes the order of the output.
    run_f32accumulator_tests();    // F32 accumulator
    run_f16accumulator_tests();    // F16 accumulator
    run_int32accumulator_tests();  // INT32 accumulator

    // Closing banner.
    printf("================================================================================\n");
    printf(" Sweep Complete\n");
    printf("================================================================================\n");
    printf("\n");
}
336+
337+
// Legacy function signatures for compatibility
338+
// Legacy entry point kept for compatibility: prints a one-line notice and
// runs the complete sweep via run_all_wgmma_maxflops_tests().
//
// Returns: always 0.0f. The value is a constant, not a measured throughput.
//
// Declared inline because the definition lives in a header: a non-inline
// definition would be emitted in every translation unit that includes this
// file and fail at link time with multiple-definition errors.
inline float gmma_maxflops_ss() {
    printf("Running comprehensive WGMMA max flops tests...\n");
    run_all_wgmma_maxflops_tests();
    return 0.0f;
}
343+
344+
#endif

0 commit comments

Comments
 (0)