
Commit a2e6a3d

pytorch benchmark and example updates
- added much more explanation to the benchmark outputs
- added flash attention benchmark which tests the 4 different dot product attention algorithms on both the cpu and the gpu
- modified the simple cpu vs gpu benchmark to also output meaningful results for both the cpu and the gpus

Signed-off-by: Mika Laitio <[email protected]>
1 parent aa3cf8d commit a2e6a3d

21 files changed: +613 -284 lines changed
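The four dot product attention algorithms exercised by the new benchmark map to PyTorch's SDPBackend enum. As a minimal sketch of the selection mechanism (assuming PyTorch 2.3+ where torch.nn.attention is available, a cuda/rocm device, and purely illustrative tensor shapes):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# illustrative shapes: (batch, num_heads, sequence_len, head_dim)
q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# "Default": pytorch picks the fastest available backend by itself
out_default = F.scaled_dot_product_attention(q, k, v)

# explicit selection; SDPBackend.FLASH_ATTENTION and
# SDPBackend.EFFICIENT_ATTENTION are chosen the same way
with sdpa_kernel(SDPBackend.MATH):
    out_math = F.scaled_dot_product_attention(q, k, v)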

docs/examples/pytorch/.ipynb_checkpoints/pytorch_amd_gpu_intro-checkpoint.ipynb

Lines changed: 0 additions & 142 deletions
This file was deleted.
Binary file not shown.
flash_attention_dot_product_benchmark.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.nn.attention import SDPBackend, sdpa_kernel

# copyright (C) Mika Laitio, [email protected]
# dot product calculation benchmark test based on the documentation at
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html

dev_type_arr = ["cuda:0", "cpu"]
solver_name_arr = ["Default", "Math", "Flash Attention", "Memory Efficient"]

# quick sanity test of scaled_dot_product_attention on the first device
print("Pytorch version: " + torch.__version__)
print("dot product calculation test")
query, key, value = (torch.randn(2, 3, 8, device=dev_type_arr[0]),
                     torch.randn(2, 3, 8, device=dev_type_arr[0]),
                     torch.randn(2, 3, 8, device=dev_type_arr[0]))
print(F.scaled_dot_product_attention(query, key, value))

def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    # blocked_autorange() reports the mean in seconds; convert to microseconds
    return t0.blocked_autorange().mean * 1e6

# benchmark parameters
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32
dtype = torch.float16

result_arr = [[0, 0, 0, 0], [0, 0, 0, 0]]

print("")
print("Benchmarking cuda and cpu with Default, Math, Flash Attention and Memory Efficient pytorch backends")

for ii, dev_type_item in enumerate(dev_type_arr):
    if dev_type_item == "cpu":
        print("Device: cpu-" + str(os.cpu_count()))
    else:
        print("Device: " + torch.cuda.get_device_name(device=dev_type_item))
    query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=dev_type_item, dtype=dtype)
    key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=dev_type_item, dtype=dtype)
    value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=dev_type_item, dtype=dtype)

    # Default backend: pytorch selects the implementation by itself
    print("    " + solver_name_arr[0] + " " + dev_type_item + " benchmark:")
    result_arr[ii][0] = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
    print(f"        {result_arr[ii][0]:.3f} microseconds, {(result_arr[ii][0] / 1e6)} sec")

    # explore the speed of each of the three explicitly selected backend implementations
    print("    " + solver_name_arr[1] + " " + dev_type_item + " benchmark:")
    with sdpa_kernel(SDPBackend.MATH):
        result_arr[ii][1] = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
        print(f"        {result_arr[ii][1]:.3f} microseconds, {(result_arr[ii][1] / 1e6)} sec")

    print("    " + solver_name_arr[2] + " " + dev_type_item + " benchmark:")
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        try:
            result_arr[ii][2] = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
            print(f"        {result_arr[ii][2]:.3f} microseconds, {(result_arr[ii][2] / 1e6)} sec")
        except RuntimeError:
            print("    " + solver_name_arr[2] + " " + dev_type_item + " is not supported. See warnings for reasons.")
            result_arr[ii][2] = -1

    print("    " + solver_name_arr[3] + " " + dev_type_item + " benchmark:")
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        try:
            result_arr[ii][3] = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)
            print(f"        {result_arr[ii][3]:.3f} microseconds, {(result_arr[ii][3] / 1e6)} sec")
        except RuntimeError:
            print("    " + solver_name_arr[3] + " " + dev_type_item + " is not supported. See warnings for reasons.")
            result_arr[ii][3] = -1

print("Summary")
print("\nPytorch version: " + torch.__version__)
# torch.version.hip is None on non-rocm builds
print("ROCM HIP version: " + str(torch.version.hip))
for ii, dev_type_item in enumerate(dev_type_arr):
    if dev_type_item == "cpu":
        print("Device: cpu-" + str(os.cpu_count()))
    else:
        print("Device: " + torch.cuda.get_device_name(device=dev_type_item))
    for jj, solver_name in enumerate(solver_name_arr):
        msg_prefix = (solver_name + " " + dev_type_item + ":").rjust(30)
        msg_number = "{:.3f}".format(result_arr[ii][jj]).rjust(20)
        # result values are microseconds; -1 means the backend was not supported
        print(f"{msg_prefix} {msg_number} us")
    print("")
@@ -1,4 +1,4 @@
-# simple pytorch python example to verify that gpu acceleration is enabled
+# pytorch example rocm sdk launcher
 # if test fails on AMD GPU, enable AMD_LOG_LEVEL and HIP_VISIBLE_DEVICES=0 variables
 # to get traces to find the failing code part
 if [ -z $ROCM_HOME ]; then
@@ -7,5 +7,5 @@ if [ -z $ROCM_HOME ]; then
     echo "before running this script"
     exit 1
 fi
-python pytorch_gpu_simple_test.py
 #AMD_LOG_LEVEL=1 HIP_VISIBLE_DEVICES=0 HIP_LAUNCH_BLOCKING=1 python pytorch_gpu_simple_test.py
+python flash_attention_dot_product_benchmark.py
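With this change the launcher runs the new flash attention benchmark instead of the simple GPU verification test; as the script's comments note, the AMD_LOG_LEVEL/HIP_VISIBLE_DEVICES line can be uncommented to get HIP traces when a test fails on an AMD GPU.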
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
import torch

print("1) Expected configuration: enable_flash=True, enable_math=False, enable_mem_efficient=False")
print("Real configuration:")
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
    print("    cuda.flash_sdp_enabled: " + str(torch.backends.cuda.flash_sdp_enabled()))
    # True
    print("    cuda.mem_efficient_sdp_enabled: " + str(torch.backends.cuda.mem_efficient_sdp_enabled()))
    # False
    print("    cuda.math_sdp_enabled: " + str(torch.backends.cuda.math_sdp_enabled()))
    # False

print("")
print("2) Expected configuration: enable_flash=False, enable_math=True, enable_mem_efficient=True")
print("Real configuration:")
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH, torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION]):
    print("    cuda.flash_sdp_enabled: " + str(torch.backends.cuda.flash_sdp_enabled()))
    # False
    print("    cuda.mem_efficient_sdp_enabled: " + str(torch.backends.cuda.mem_efficient_sdp_enabled()))
    # True
    print("    cuda.math_sdp_enabled: " + str(torch.backends.cuda.math_sdp_enabled()))
    # True
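A related pattern, sketched here assuming a cuda/rocm device is present: restricting sdpa_kernel to a single backend makes scaled_dot_product_attention raise RuntimeError when that kernel cannot run, which is how the benchmark earlier on this page detects unsupported backends.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 4, 16, 8, device="cuda", dtype=torch.float16)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        F.scaled_dot_product_attention(q, k, v)
        print("flash attention kernel is usable on this device")
    except RuntimeError:
        print("flash attention kernel is not usable; see warnings for reasons")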
