Commit 24e807b

Add grabber script
1 parent 6e36dd9 commit 24e807b

File tree

2 files changed: +141 -0 lines changed

attn_gym/utils.py

Lines changed: 9 additions & 0 deletions

@@ -16,10 +16,19 @@
 except ImportError:
     from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
 from contextlib import nullcontext
+from torch._inductor.utils import do_bench_using_profiling
+from collections.abc import Callable

 Tensor = torch.Tensor


+def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
+    """Thin wrapper around do_bench_using_profiling"""
+    no_args = lambda: func(*args, **kwargs)
+    time = do_bench_using_profiling(no_args)
+    return time * 1e3
+
+
 def create_score_mod(
     query: torch.Tensor,
     key: torch.Tensor,
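
A quick usage sketch of the new helper, for orientation: it assumes a CUDA device is available, and the matmul workload and shapes are illustrative only, not part of the commit.

import torch
from attn_gym.utils import benchmark_cuda_function_in_microseconds

# Illustrative workload: time an fp16 matmul on the GPU.
a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

# Warm up so one-time initialization does not skew the profiled timing.
for _ in range(3):
    torch.mm(a, b)

time_us = benchmark_cuda_function_in_microseconds(torch.mm, a, b)
print(f"matmul: {time_us:.1f} us ({time_us / 1e3:.3f} ms)")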

examples/flex_autotune_replay.py

Lines changed: 132 additions & 0 deletions (new file)

import json
import os
import tempfile
from unittest.mock import patch

import torch
from torch.nn.attention.flex_attention import flex_attention, FlexKernelOptions, create_block_mask
from attn_gym.utils import benchmark_cuda_function_in_microseconds
from attn_gym.masks import causal_mask

torch.compiler.config.force_disable_caches = True
torch._functorch.config.donated_buffer = False


def run_autotune(log_file: str):
    """
    Runs flex_attention with max-autotune to generate kernel tuning logs.
    """
    print("Running autotuning phase...")
    query = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    key = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    value = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True)

    block_mask = torch.compile(create_block_mask)(
        causal_mask, None, None, 8192, 8192, device="cuda"
    )

    compiled_flex = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs")

    with patch.dict(os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file}):
        out = compiled_flex(
            query=query,
            key=key,
            value=value,
            block_mask=block_mask,
        )
        out.sum().backward()
    print(f"Autotuning logs saved to {log_file}.json")


def parse_log_and_get_best_options(log_file: str) -> FlexKernelOptions:
    """
    Parses the autotuning log and returns the best kernel options.
    """
    print("\nParsing autotuning logs...")
    json_file = log_file + ".json"
    with open(json_file) as f:
        log_data = json.load(f)

    best_options: FlexKernelOptions = {}
    for entry in log_data:
        dims_key, choices = next(iter(entry.items()))
        kernel_type = eval(dims_key)[0]  # 'forward' or 'backward'
        best_choice = choices[0]  # The list is sorted by time

        prefix = "fwd_" if kernel_type == "forward" else "bwd_"

        for key, value in best_choice.items():
            if key not in ["type", "time"]:
                # Ensure the key is valid for FlexKernelOptions
                if key in FlexKernelOptions.__annotations__:
                    best_options[f"{prefix}{key}"] = value
    print("Best kernel options extracted from logs:")
    print(json.dumps(best_options, indent=2))
    return best_options


def run_with_best_options(kernel_options: FlexKernelOptions):
    """
    Runs flex_attention with the provided kernel options.
    """
    print("\nRunning with pre-compiled best options...")
    query = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    key = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    value = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True)

    block_mask = torch.compile(create_block_mask)(
        causal_mask, None, None, 8192, 8192, device="cuda"
    )

    # Note: We are not using max-autotune here
    compiled_flex = torch.compile(flex_attention)

    # Make sure we are not logging this run
    if "TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE" in os.environ:
        del os.environ["TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE"]

    def run_fwd():
        return compiled_flex(
            query=query,
            key=key,
            value=value,
            block_mask=block_mask,
            kernel_options=kernel_options,
        )

    # Warmup
    for _ in range(3):
        run_fwd()
    fwd_time = benchmark_cuda_function_in_microseconds(run_fwd)
    print(f"Execution time with best options: {fwd_time / 1000:.3f} ms")

    out = run_fwd()
    loss = out.sum()

    def run_bwd():
        loss.backward(retain_graph=True)

    # Warmup
    for _ in range(3):
        run_bwd()

    bwd_time = benchmark_cuda_function_in_microseconds(run_bwd)
    print(f"Backward execution time with best options: {bwd_time / 1000:.3f} ms")


def main():
    with tempfile.TemporaryDirectory() as tmpdir:
        log_file_path = os.path.join(tmpdir, "flex_attention_configs")

        # 1. Run autotuning
        run_autotune(log_file_path)

        # 2. Parse the log file to get the best options
        best_kernel_options = parse_log_and_get_best_options(log_file_path)

        # 3. Run with the best options
        run_with_best_options(best_kernel_options)


if __name__ == "__main__":
    main()
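
To make the parsing step in parse_log_and_get_best_options concrete, here is a self-contained sketch of how one logged entry is mapped to prefixed kernel options. The entry below is hypothetical: the commit does not show the on-disk log format, only the structure the parser expects (each entry is a dict keyed by a stringified tuple whose first element is 'forward' or 'backward', with candidate choices sorted fastest-first). The field names and timings are invented, and a plain set stands in for FlexKernelOptions.__annotations__.

import json
from ast import literal_eval

# Hypothetical log entry; field names and values are invented for illustration.
entry = {
    "('forward', 2, 2, 8192, 64)": [
        {"type": "triton", "time": 0.41, "BLOCK_M": 128, "BLOCK_N": 64, "num_warps": 4},
        {"type": "triton", "time": 0.55, "BLOCK_M": 64, "BLOCK_N": 64, "num_warps": 8},
    ]
}
# Stand-in for FlexKernelOptions.__annotations__ used in the real script.
allowed_keys = {"BLOCK_M", "BLOCK_N", "num_warps", "num_stages"}

dims_key, choices = next(iter(entry.items()))
kernel_type = literal_eval(dims_key)[0]  # 'forward' or 'backward'
prefix = "fwd_" if kernel_type == "forward" else "bwd_"
best_choice = choices[0]  # choices are sorted by time, fastest first

best_options = {
    f"{prefix}{key}": value
    for key, value in best_choice.items()
    if key not in ("type", "time") and key in allowed_keys
}
print(json.dumps(best_options, indent=2))
# prints: {"fwd_BLOCK_M": 128, "fwd_BLOCK_N": 64, "fwd_num_warps": 4}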
