-
Notifications
You must be signed in to change notification settings - Fork 741
/
mamf-finder.py
executable file
·373 lines (288 loc) · 13 KB
/
mamf-finder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
#!/usr/bin/env python
"""
This is Maximum Achievable Matmul FLOPS (MAMF) Finder
For discussion and multiple important nuances please refer to
https://github.com/stas00/ml-engineering/tree/master/compute/accelerator/benchmarks#maximum-achievable-matmul-flops-finder
Credits:
- Parts of this benchmark have been derived from https://github.com/EleutherAI/cookbook/tree/main/benchmarks/sizing (highly recommended!)
- Imtiaz Sajwani: HPU porting
- Xiaoyu Zhang https://github.com/BBuf - flexible dtype support
- Oren Leung https://github.com/OrenLeung - flagging the lack of cache/dest-matrix reset and suggesting a fix
"""
from pathlib import Path
import argparse
import datetime
import numpy as np
import os
import platform
import re
import shlex
import signal
import sys
import time
import torch
# important: when changing how the benchmark measures things bump up its version, so that the old
# reports could be differentiated from the new ones
benchmark_version = 2

# Intel Gaudi support is optional: it is switched on only when the habana bridge can be
# imported and it reports an available device
has_hpu = False
try:
    import habana_frameworks.torch as ht
    has_hpu = bool(torch.hpu.is_available())
except ModuleNotFoundError:
    pass

# absolute directory holding this script - the default --output_file lives under it
file_dir = os.path.abspath(os.path.dirname(__file__))
def get_torch_dtype(dtype_str):
    """Convert a torch dtype name (e.g. "bfloat16", "float8_e4m3fn") to its torch.dtype object.

    Args:
        dtype_str: name of an attribute of the `torch` module that is a dtype.

    Returns:
        The matching `torch.dtype`.

    Raises:
        ValueError: if `dtype_str` does not name a torch dtype. This includes names of torch
            attributes that exist but aren't dtypes (e.g. "randn", "Tensor"), which the plain
            getattr() lookup used to let through.
    """
    dtype = getattr(torch, dtype_str, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unsupported dtype: {dtype_str}. Must be a valid torch dtype name.")
    return dtype
### Architecture specific helper classes ###
class Arch:
    """Common base for the accelerator descriptors; concrete backends overwrite ``self.arch``."""

    def __init__(self):
        # placeholder name until a concrete backend sets its own
        self.arch = "unknown"

    def __repr__(self):
        # the backend name doubles as the printable representation
        return self.arch
class CUDAArch(Arch):
    """ shared with CUDA and ROCm: NVIDIA + AMD """

    def __init__(self):
        # ROCm builds of torch reuse the `torch.cuda` namespace and additionally set `torch.version.hip`
        if torch.version.hip is not None:
            self.arch = "rocm"
        else:
            self.arch = "cuda"

    def device(self):
        # the current device is cuda:0 unless the caller picked another one via
        # torch.cuda.set_device(), so this stays backward compatible with the old fixed 'cuda:0'
        return torch.device('cuda', torch.cuda.current_device())

    def name(self):
        return self.arch

    def device_info(self):
        # query the device this object hands out, rather than relying on a module-level
        # `device` global that only exists when the file is run as a script
        return torch.cuda.get_device_properties(self.device())

    def compute_info(self):
        if self.arch == "rocm":
            return f"hip={torch.version.hip}, cuda={torch.version.cuda}"
        else:
            return f"cuda={torch.version.cuda}"

    def event(self, enable_timing=True):
        return torch.cuda.Event(enable_timing=enable_timing)

    def synchronize(self):
        torch.cuda.synchronize()
class HPUArch(Arch):
    """ Intel Gaudi* """

    def __init__(self):
        self.arch = "hpu"

    def device(self):
        return torch.device('hpu')

    def name(self):
        return self.arch

    def device_info(self):
        # query the device this object hands out, rather than relying on a module-level
        # `device` global that only exists when the file is run as a script
        return torch.hpu.get_device_properties(self.device())

    def compute_info(self):
        # NOTE(review): this interpolates the `torch.hpu` module object itself rather than a
        # version string - consider reporting the habana_frameworks version instead
        return f"hpu={torch.hpu}"

    def event(self, enable_timing=True):
        return ht.hpu.Event(enable_timing)

    def synchronize(self):
        ht.hpu.synchronize()
def get_accelerator_arch():
    """
    Pick the accelerator backend available in this process.

    returns: CUDAArch (NVIDIA CUDA or AMD ROCm) or HPUArch (Intel Gaudi) object
    raises: ValueError when neither backend is usable
    """
    if torch.cuda.is_available():  # cuda / rocm
        return CUDAArch()
    if not has_hpu:
        raise ValueError("Currently only cuda, rocm and hpu are supported")
    return HPUArch()


# resolved once at import time and shared by all the helpers below
arch = get_accelerator_arch()
### Helper classes ###
class Tee(object):
    """Stand-in for ``sys.stdout`` that writes to a log file and, optionally, to the console.

    Args:
        filename: path of the log file. Missing parent directories are created and an
            existing file is overwritten.
        verbose: when true, every message is also echoed to whatever ``sys.stdout`` was
            when this object was created.
    """

    # `\r` and `\033[K` (erase to end of line) are nice in the console, but we don't want those
    # in the log file, so each one is turned into a newline there. Compiled once, not per write.
    _console_control = re.compile(r"(\r|\033\[K)")

    def __init__(self, filename, verbose):
        Path(filename).resolve().parent.mkdir(parents=True, exist_ok=True)
        # explicit encoding so the log file doesn't depend on the locale
        self.file = open(filename, "w", encoding="utf-8")
        self.verbose = verbose
        if self.verbose:
            self.stdout = sys.stdout

    def write(self, message):
        if self.verbose:
            self.stdout.write(message)
        self.file.write(self._console_control.sub("\n", message))

    def flush(self):
        self.file.flush()
        if self.verbose:
            self.stdout.flush()

    def close(self):
        """Flush both streams and close the log file (the console stream is left open)."""
        self.flush()
        self.file.close()
def print_benchmark_header(dtype, device, notes="None"):
    """
    Print the report preamble: start time, command line, dtype, platform/device details,
    software versions and user notes, followed by an 80-character separator line.

    `device` is accepted for interface compatibility only; the device details come from `arch`.
    """
    device_info = arch.device_info()
    compute_info = arch.compute_info()

    started_at = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    command_line = " ".join(map(shlex.quote, sys.argv))
    platform_info = " ".join(platform.uname())

    header_lines = [
        "",
        f"Benchmark started on {started_at}",
        "** Command line:",
        f"{sys.executable} {command_line}",
        f"** Dtype: {dtype}",
        "** Platform/Device info:",
        platform_info,
        f"{device_info}",
        "** Critical software versions:",
        f"torch={torch.__version__}",
        f"{compute_info}",
        "** Additional notes:",
        f"benchmark version: {benchmark_version}",
        f"{notes}",
        "-" * 80,
        "",
    ]
    print("\n".join(header_lines))
# Benchmark of a basic GEMM
def benchmark_mm(m, n, k, dtype, device, num_iterations, num_warmup_iterations):
    """
    Benchmark one (m x k) @ (k x n) GEMM and return the achieved TFLOPS.

    Before every iteration a 256MB scratch buffer is overwritten (to emulate a cache reset)
    and the output matrix is re-randomized; only the matmul itself sits between the timing events.

    Args:
        m, n, k: GEMM dimensions - A is (m, k), B is (k, n), the output is (m, n)
        dtype: operand dtype; torch.float8_e4m3fn goes through torch._scaled_mm, everything
            else through torch.mm
        device: device to allocate all tensors on
        num_iterations: number of timed iterations the statistics are computed from
        num_warmup_iterations: number of leading iterations that are run but discarded

    Returns:
        (mean_tflops, median_tflops, max_tflops) - max is derived from the fastest iteration
    """
    # this will be used to write to the accelerator between each benchmark iteration to emulate cache reset.
    # On AMD this will really be an l3/LLC cache - later need to figure out how to get the maximum cache
    # size automatically, according to this table 256MB is the highest value so far across all
    # recent accelerators:
    # https://github.com/stas00/ml-engineering/tree/master/compute/accelerator#caches
    l2_cache_size_in_mbs = 256
    l2_cache = torch.empty(int(l2_cache_size_in_mbs * 2**20 / 4), dtype=torch.int, device=device)

    C = torch.empty(m, n, dtype=dtype, device=device).contiguous()

    # this random matrix will be used in the loop to ensure that C gets actually written to, as
    # otherwise the rerun results will be always the same and no power will be drawn to write - would lead
    # to invalid emulation of a real use case
    C_rand = torch.randn(m, n, device=device).to(dtype=dtype).contiguous()

    def time_op(op, iters):
        """Run `op` `iters` times and return each run's accelerator-side duration in milliseconds."""
        start_events = [arch.event(enable_timing=True) for _ in range(iters)]
        end_events = [arch.event(enable_timing=True) for _ in range(iters)]
        with torch.no_grad():
            for start_event, end_event in zip(start_events, end_events):
                l2_cache.zero_()  # clear accelerator cache
                C.copy_(C_rand)  # re-randomize the target matrix
                start_event.record()
                op()
                end_event.record()
        # the recorded events must have completed before their elapsed times can be read
        arch.synchronize()
        return np.array([s.elapsed_time(e) for s, e in zip(start_events, end_events)])

    if dtype == torch.float8_e4m3fn:
        A = torch.randn(m, k, dtype=torch.float32, device=device).contiguous()
        # transposed view => column-major B
        B = torch.randn(n, k, dtype=torch.float32, device=device).contiguous().t()
        scale = torch.tensor([1.0]).to(device)
        A = A.to(torch.float8_e4m3fn)
        B = B.to(torch.float8_e4m3fn)

        # some torch versions require the scale args, some don't, so discover which is supported.
        # The probe's result is deliberately discarded: assigning it to `C` (as before) rebound the
        # matrix that every iteration re-randomizes, and on torch versions where _scaled_mm returns
        # a tuple that made `C.copy_(C_rand)` fail.
        try:
            torch._scaled_mm(A, B)
        except Exception:  # signature mismatch - fall back to the explicit-scale variant
            def op():
                torch._scaled_mm(A, B, scale, scale)
        else:
            def op():
                torch._scaled_mm(A, B)
    else:
        A = torch.randn(m, k, dtype=dtype, device=device).contiguous()
        # transposed view => column-major B
        B = torch.randn(n, k, dtype=dtype, device=device).contiguous().t()

        def op():
            torch.mm(A, B, out=C)

    total_iterations = num_iterations + num_warmup_iterations
    times = time_op(op, total_iterations)[num_warmup_iterations:]

    # times are in ms; convert to seconds before deriving TFLOPS
    flos = 2 * m * n * k
    mean_elapsed_time = np.mean(times)/1000
    mean_tflops = flos / (mean_elapsed_time * 10**12)
    median_elapsed_time = np.median(times)/1000
    median_tflops = flos / (median_elapsed_time * 10**12)
    min_elapsed_time = np.amin(times)/1000
    max_tflops = flos / (min_elapsed_time * 10**12)

    return mean_tflops, median_tflops, max_tflops
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # --n/--k now use nargs="+" like --m: an empty list used to crash later on n[0]/k[0].
    # The *_range options take exactly 3 values so a malformed range is an argparse error.
    m_group = parser.add_mutually_exclusive_group(required=True)
    m_group.add_argument("--m", nargs="+", type=int, help='The first dimension of the GEMM, enter any number of arguments')
    m_group.add_argument("--m_range", nargs=3, type=int, help="The first dimension of the GEMM, [start,stop,step]")

    n_group = parser.add_mutually_exclusive_group(required=True)
    n_group.add_argument("--n", nargs="+", type=int, help='The last dimension of the GEMM, enter any number of arguments')
    n_group.add_argument("--n_range", nargs=3, type=int, help="The last dimension of the GEMM, [start,stop,step]")

    k_group = parser.add_mutually_exclusive_group(required=True)
    k_group.add_argument("--k", nargs="+", type=int, help='The shared (reduction) dimension of the GEMM, enter any number of arguments')
    k_group.add_argument("--k_range", nargs=3, type=int, help="The shared (reduction) dimension of the GEMM, [start,stop,step]")

    parser.add_argument("--num_iterations", type=int, default=100, help='The number of iterations used to benchmark each GEMM')
    parser.add_argument("--num_warmup_iterations", type=int, default=50, help='The number of warmup iterations')
    parser.add_argument("--cuda_device", type=int, default=0, help="The cuda device to run the benchmark on")
    parser.add_argument("--output_file", type=str, default=f"{file_dir}/results/mm.out")
    parser.add_argument("--notes", type=str, default="", help="benchmark-specific notes to add to the output_file's header")
    parser.add_argument("--verbose", default=True, action=argparse.BooleanOptionalAction, help='log to stdout besides output_file?')
    parser.add_argument("--dtype", type=str, default="bfloat16",
                        help="Data type to use for the benchmark (e.g. float16, bfloat16, float32)")
    args = parser.parse_args()

    def resolve_dims(name, values, value_range):
        """
        Return the sizes to benchmark for one GEMM dimension.

        Explicit `values` win; otherwise `value_range` ([start, stop, step]) is expanded with
        np.arange (stop is exclusive), replacing a 0 start with `step` since a dimension can't be 0.
        Exits via parser.error() on a 0 step, an empty result, or any non-positive size.
        """
        if values is None:
            start, stop, step = value_range
            if step == 0:
                parser.error(f"--{name}_range: step must not be 0")
            if start == 0:  # can't have a 0 dimension
                start = step
            values = np.arange(start, stop, step)
        if len(values) == 0:
            parser.error(f"--{name}/--{name}_range produced no sizes to benchmark")
        if any(v <= 0 for v in values):
            parser.error(f"all --{name} sizes must be positive")
        return values

    m = resolve_dims("m", args.m, args.m_range)
    n = resolve_dims("n", args.n, args.n_range)
    k = resolve_dims("k", args.k, args.k_range)

    # an empty timing window would make the statistics NaN
    if args.num_iterations < 1:
        parser.error("--num_iterations must be at least 1")
    if args.num_warmup_iterations < 0:
        parser.error("--num_warmup_iterations must not be negative")

    try:
        dtype = get_torch_dtype(args.dtype)
    except ValueError as e:
        parser.error(str(e))

    if isinstance(arch, CUDAArch):
        # --cuda_device used to be parsed but ignored (everything ran on cuda:0). Making it the
        # current device also makes the timing events record on that device's current stream.
        num_devices = torch.cuda.device_count()
        if not 0 <= args.cuda_device < num_devices:
            parser.error(f"--cuda_device {args.cuda_device} is out of range (found {num_devices} devices)")
        torch.cuda.set_device(args.cuda_device)
        device = torch.device("cuda", args.cuda_device)
    else:
        device = arch.device()

    sys.stdout = Tee(args.output_file, args.verbose)

    print_benchmark_header(dtype, device, args.notes)

    best_tflops = dict(max=0, median=0, mean=0)
    best_config = dict(max="", median="", mean="")
    num_shapes = 0
    start_time = time.time()

    def finish():
        """Print the best outcomes seen so far and the total elapsed time."""
        time_delta = time.time() - start_time
        time_str = str(datetime.timedelta(seconds=time_delta)).split(".")[0]
        print("", end="\033[K")
        print(f"""
Tried {num_shapes} shapes => the best outcomes were:
mean: {best_tflops["mean"]:.1f} TFLOPS @ {best_config["mean"]}
median: {best_tflops["median"]:.1f} TFLOPS @ {best_config["median"]}
max: {best_tflops["max"]:.1f} TFLOPS @ {best_config["max"]}
""")
        print(f"Elapsed time: {time_str}")

    # this is useful for when one wants to interrupt the run - and still report the best outcome so far.
    # Installed only after `finish` and the state it reads exist, so an early Ctrl-C can't hit a NameError.
    def sigint_handler(signum, frame):
        finish()
        sys.exit(1)

    signal.signal(signal.SIGINT, sigint_handler)

    # XXX: the transpose version seemed to work better for MI300X

    # always start with additional warmup iterations to give fair results, otherwise based on
    # rerunning this benchmark many times - a cold accelerator gives a higher score on say a single
    # shape, than the same shape run after a dozen of other shapes
    accelerator_warmup_seconds = 30
    end_time = time.monotonic() + accelerator_warmup_seconds
    print(f"Warming up the accelerator for {accelerator_warmup_seconds} secs ... ", end="", flush=True)
    while time.monotonic() < end_time:
        _ = benchmark_mm(m[0], n[0], k[0], dtype, device, args.num_iterations, args.num_warmup_iterations)
    print("accelerator warmup finished")

    # loop through all sizes to benchmark
    for M in m:
        for N in n:
            for K in k:
                num_shapes += 1
                mean_tflops, median_tflops, max_tflops = benchmark_mm(M, N, K, dtype, device, args.num_iterations, args.num_warmup_iterations)
                cur_config = f"{M}x{N}x{K}"
                if median_tflops > best_tflops["median"]:
                    best_tflops["median"] = median_tflops
                    best_config["median"] = f"{cur_config} (MxNxK)"
                if mean_tflops > best_tflops["mean"]:
                    best_tflops["mean"] = mean_tflops
                    best_config["mean"] = f"{cur_config} (MxNxK)"
                if max_tflops > best_tflops["max"]:
                    best_tflops["max"] = max_tflops
                    best_config["max"] = f"{cur_config} (MxNxK)"
                # flush: a `\r`-terminated progress line would otherwise sit in a line-buffered console
                print(f"{num_shapes:>6} | {mean_tflops:6.1f}(mean) {median_tflops:6.1f}(median) {max_tflops:6.1f}(max) @ {cur_config:<20} | best: {best_tflops['mean']:6.1f}(mean) {best_tflops['median']:6.1f}(median) {best_tflops['max']:6.1f}(max)TFLOPS", end="\r", flush=True)

    finish()