|
| 1 | +import json |
| 2 | +import os |
| 3 | +import tempfile |
| 4 | +from unittest.mock import patch |
| 5 | + |
| 6 | +import torch |
| 7 | +from torch.nn.attention.flex_attention import flex_attention, FlexKernelOptions, create_block_mask |
| 8 | +from attn_gym.utils import benchmark_cuda_function_in_microseconds |
| 9 | +from attn_gym.masks import causal_mask |
| 10 | + |
| 11 | +torch.compiler.config.force_disable_caches = True |
| 12 | +torch._functorch.config.donated_buffer = False |
| 13 | + |
| 14 | + |
| 15 | +def run_autotune(log_file: str): |
| 16 | + """ |
| 17 | + Runs flex_attention with max-autotune to generate kernel tuning logs. |
| 18 | + """ |
| 19 | + print("Running autotuning phase...") |
| 20 | + query = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True) |
| 21 | + key = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True) |
| 22 | + value = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True) |
| 23 | + |
| 24 | + block_mask = torch.compile(create_block_mask)( |
| 25 | + causal_mask, None, None, 8192, 8192, device="cuda" |
| 26 | + ) |
| 27 | + |
| 28 | + compiled_flex = torch.compile(flex_attention, mode="max-autotune-no-cudagraphs") |
| 29 | + |
| 30 | + with patch.dict(os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file}): |
| 31 | + out = compiled_flex( |
| 32 | + query=query, |
| 33 | + key=key, |
| 34 | + value=value, |
| 35 | + block_mask=block_mask, |
| 36 | + ) |
| 37 | + out.sum().backward() |
| 38 | + print(f"Autotuning logs saved to {log_file}.json") |
| 39 | + |
| 40 | + |
| 41 | +def parse_log_and_get_best_options(log_file: str) -> FlexKernelOptions: |
| 42 | + """ |
| 43 | + Parses the autotuning log and returns the best kernel options. |
| 44 | + """ |
| 45 | + print("\nParsing autotuning logs...") |
| 46 | + json_file = log_file + ".json" |
| 47 | + with open(json_file) as f: |
| 48 | + log_data = json.load(f) |
| 49 | + |
| 50 | + best_options: FlexKernelOptions = {} |
| 51 | + for entry in log_data: |
| 52 | + dims_key, choices = next(iter(entry.items())) |
| 53 | + kernel_type = eval(dims_key)[0] # 'forward' or 'backward' |
| 54 | + best_choice = choices[0] # The list is sorted by time |
| 55 | + |
| 56 | + prefix = "fwd_" if kernel_type == "forward" else "bwd_" |
| 57 | + |
| 58 | + for key, value in best_choice.items(): |
| 59 | + if key not in ["type", "time"]: |
| 60 | + # Ensure the key is valid for FlexKernelOptions |
| 61 | + if key in FlexKernelOptions.__annotations__: |
| 62 | + best_options[f"{prefix}{key}"] = value |
| 63 | + print("Best kernel options extracted from logs:") |
| 64 | + print(json.dumps(best_options, indent=2)) |
| 65 | + return best_options |
| 66 | + |
| 67 | + |
| 68 | +def run_with_best_options(kernel_options: FlexKernelOptions): |
| 69 | + """ |
| 70 | + Runs flex_attention with the provided kernel options. |
| 71 | + """ |
| 72 | + print("\nRunning with pre-compiled best options...") |
| 73 | + query = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True) |
| 74 | + key = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True) |
| 75 | + value = torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True) |
| 76 | + |
| 77 | + block_mask = torch.compile(create_block_mask)( |
| 78 | + causal_mask, None, None, 8192, 8192, device="cuda" |
| 79 | + ) |
| 80 | + |
| 81 | + # Note: We are not using max-autotune here |
| 82 | + compiled_flex = torch.compile(flex_attention) |
| 83 | + |
| 84 | + # Make sure we are not logging this run |
| 85 | + if "TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE" in os.environ: |
| 86 | + del os.environ["TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE"] |
| 87 | + |
| 88 | + def run_fwd(): |
| 89 | + return compiled_flex( |
| 90 | + query=query, |
| 91 | + key=key, |
| 92 | + value=value, |
| 93 | + block_mask=block_mask, |
| 94 | + kernel_options=kernel_options, |
| 95 | + ) |
| 96 | + |
| 97 | + # Warmup |
| 98 | + for _ in range(3): |
| 99 | + run_fwd() |
| 100 | + fwd_time = benchmark_cuda_function_in_microseconds(run_fwd) |
| 101 | + print(f"Execution time with best options: {fwd_time / 1000:.3f} ms") |
| 102 | + |
| 103 | + out = run_fwd() |
| 104 | + loss = out.sum() |
| 105 | + |
| 106 | + def run_bwd(): |
| 107 | + loss.backward(retain_graph=True) |
| 108 | + |
| 109 | + # Warmup |
| 110 | + for _ in range(3): |
| 111 | + run_bwd() |
| 112 | + |
| 113 | + bwd_time = benchmark_cuda_function_in_microseconds(run_bwd) |
| 114 | + print(f"Backward execution time with best options: {bwd_time / 1000:.3f} ms") |
| 115 | + |
| 116 | + |
| 117 | +def main(): |
| 118 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 119 | + log_file_path = os.path.join(tmpdir, "flex_attention_configs") |
| 120 | + |
| 121 | + # 1. Run autotuning |
| 122 | + run_autotune(log_file_path) |
| 123 | + |
| 124 | + # 2. Parse the log file to get the best options |
| 125 | + best_kernel_options = parse_log_and_get_best_options(log_file_path) |
| 126 | + |
| 127 | + # 3. Run with the best options |
| 128 | + run_with_best_options(best_kernel_options) |
| 129 | + |
| 130 | + |
| 131 | +if __name__ == "__main__": |
| 132 | + main() |
0 commit comments