drisspg (Contributor) commented on Oct 18, 2025

See: pytorch/pytorch#165817

Sample output:

❯ python examples/flex_autotune_replay.py
Running autotuning phase...
/home/dev/meta/pytorch/torch/_dynamo/pgo.py:539: UserWarning: dynamo_pgo force disabled by torch.compiler.config.force_disable_caches
  warn_once(
/home/dev/meta/pytorch/torch/__init__.py:1549: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /home/dev/meta/pytorch/aten/src/ATen/Context.cpp:45.)
  return _C._get_float32_matmul_precision()
Autotune Choices Stats:
{"num_choices": 5, "num_triton_choices": 5, "best_kernel": "triton_flex_attention_0", "best_kernel_desc": "BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=64, FLOAT32_PRECISION=\"'ieee'\", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=3, num_warps=4", "best_time": 0.36745598912239075, "best_triton_pos": 0}
AUTOTUNE flex_attention(2x2x8192x64, 2x2x8192x64, 2x2x8192x64, 2x2x8192, 2x2x8192, 1x1x64, 1x1x64x64, 1x1x64, 1x1x64x64)
strides: [1048576, 524288, 64, 1], [1048576, 524288, 64, 1], [1048576, 524288, 64, 1], [16384, 8192, 1], [16384, 8192, 1], [64, 64, 1], [4096, 4096, 64, 1], [64, 64, 1], [4096, 4096, 64, 1]
dtypes: torch.float16, torch.float16, torch.float16, torch.float32, torch.float32, torch.int32, torch.int32, torch.int32, torch.int32
  triton_flex_attention_0 0.3675 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=3, num_warps=4
  triton_flex_attention_3 0.4166 ms 88.2% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=3, num_warps=4
  triton_flex_attention_4 0.4761 ms 77.2% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=3, num_warps=4
  triton_flex_attention_1 0.7097 ms 51.8% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=3, num_warps=4
  triton_flex_attention_2 0.7138 ms 51.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.2192 seconds and 1.9980 seconds precompiling for 5 choices
Autotune Choices Stats:
{"num_choices": 28, "num_triton_choices": 28, "best_kernel": "triton_flex_attention_backward_10", "best_kernel_desc": "BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION=\"'ieee'\", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=3, num_warps=4", "best_time": 1.0371520519256592, "best_triton_pos": 0}
AUTOTUNE flex_attention_backward(2x2x8192x64, 2x2x8192x64, 2x2x8192x64, 2x2x8192, 2x2x8192, 2x2x8192x64, 2x2x8192x64, 2x2x8192x64, 1x1x64, 1x1x64x64, 1x1x64, 1x1x64x64, 1x1x64, 1x1x64x64, 1x1x64, 1x1x64x64)
strides: [1048576, 524288, 64, 1], [1048576, 524288, 64, 1], [1048576, 524288, 64, 1], [16384, 8192, 1], [16384, 8192, 1], [1048576, 524288, 64, 1], [1048576, 524288, 64, 1], [1048576, 524288, 64, 1], [64, 64, 1], [4096, 4096, 64, 1], [64, 64, 1], [4096, 4096, 64, 1], [64, 64, 1], [4096, 4096, 64, 1], [64, 64, 1], [4096, 4096, 64, 1]
dtypes: torch.float16, torch.float16, torch.float16, torch.float32, torch.float32, torch.float16, torch.float16, torch.float16, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32
  triton_flex_attention_backward_10 1.0372 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=3, num_warps=4
  triton_flex_attention_backward_23 1.1644 ms 89.1% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=4, num_warps=4
  triton_flex_attention_backward_22 1.1652 ms 89.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=3, num_warps=4
  triton_flex_attention_backward_15 1.2051 ms 86.1% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=3, num_warps=4
  triton_flex_attention_backward_17 1.2421 ms 83.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=4, num_warps=4
  triton_flex_attention_backward_12 1.2503 ms 82.9% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=5, num_warps=4
  triton_flex_attention_backward_19 1.2534 ms 82.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=5, num_warps=4
  triton_flex_attention_backward_11 1.2780 ms 81.2% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=4, num_warps=4
  triton_flex_attention_backward_18 1.3425 ms 77.3% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=4, num_warps=8
  triton_flex_attention_backward_16 1.3540 ms 76.6% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, OUTPUT_MAX=False, PRESCALE_QK=False, QK_HEAD_DIM=64, QK_HEAD_DIM_ROUNDED=64, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=True, SM_SCALE=0.125, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=64, V_HEAD_DIM_ROUNDED=64, WRITE_DQ=True, num_stages=3, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 1.4277 seconds and 2.9633 seconds precompiling for 28 choices
Autotuning logs saved to /tmp/tmp23wkbqgd/flex_attention_configs.json

Parsing autotuning logs...
Best kernel options extracted from logs:
{
  "fwd_kpack": 2,
  "fwd_num_warps": 4,
  "fwd_BLOCK_N": 64,
  "fwd_matrix_instr_nonkdim": 0,
  "fwd_waves_per_eu": 0,
  "fwd_num_stages": 3,
  "fwd_BLOCK_M": 128,
  "fwd_USE_TMA": false,
  "bwd_BLOCK_M2": 64,
  "bwd_BLOCK_M1": 32,
  "bwd_BLOCK_N1": 64,
  "bwd_BLOCK_N2": 32,
  "bwd_kpack": 2,
  "bwd_num_warps": 4,
  "bwd_matrix_instr_nonkdim": 0,
  "bwd_waves_per_eu": 0,
  "bwd_num_stages": 3
}

Running with pre-compiled best options...
Execution time with best options: 0.151 ms
Backward execution time with best options: 0.332 ms
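
In replay mode the script essentially feeds these extracted options back through `flex_attention`'s `kernel_options` argument, so the second run skips the autotune search. Below is a minimal sketch of that pattern, assuming the `fwd_*`/`bwd_*` keys printed above are accepted as-is by `kernel_options`; the snippet is illustrative, not the script's exact code, and it uses only a subset of the extracted keys.

```python
# Minimal replay sketch (illustrative; not the exact contents of
# examples/flex_autotune_replay.py). Assumes the fwd_*/bwd_* keys shown in the
# extracted JSON above are valid flex_attention kernel_options.
import torch
from torch.nn.attention.flex_attention import flex_attention

best_options = {
    "fwd_BLOCK_M": 128, "fwd_BLOCK_N": 64, "fwd_num_warps": 4, "fwd_num_stages": 3,
    "bwd_BLOCK_M1": 32, "bwd_BLOCK_N1": 64, "bwd_BLOCK_M2": 64, "bwd_BLOCK_N2": 32,
    "bwd_num_warps": 4, "bwd_num_stages": 3,
}

# Same problem size as the autotuned run above: (B=2, H=2, S=8192, D=64), fp16.
q, k, v = (
    torch.randn(2, 2, 8192, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)

compiled_flex = torch.compile(flex_attention)
# Passing kernel_options pins the Triton configs, so this compile does not
# re-run the max-autotune search shown earlier.
out = compiled_flex(q, k, v, kernel_options=best_options)
out.sum().backward()
```

Pinning the configs this way trades the one-off autotune cost for deterministic compile times on subsequent runs of the same shapes.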

drisspg force-pushed the max-autotune-grabber branch from 44a1c2f to 24e807b on October 18, 2025 at 03:49
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request on Oct 28, 2025:
Script: meta-pytorch/attention-gym#169

This feels directionally okay, but there is some bikeshedding left: with a simple cache key, the saved entries could be quite prone to collisions when the mask_mod or score_mod changes.

Usecase: meta-pytorch/attention-gym#169

Pull Request resolved: #165817
Approved by: https://github.com/Chillee
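
To make the key-collision worry concrete: a cache key derived only from tensor shapes and dtypes cannot distinguish two runs that differ only in their mask_mod or score_mod, so both would replay the same saved config. The `naive_cache_key` below is a hypothetical stand-in, not the script's actual key.

```python
# Hypothetical illustration of the collision concern: naive_cache_key is not the
# script's real key, just a stand-in that hashes shapes/dtypes only.
import torch

def naive_cache_key(q, k, v):
    return (tuple(q.shape), tuple(k.shape), tuple(v.shape), str(q.dtype))

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window_mask(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx).abs() <= 256

q = k = v = torch.empty(2, 2, 8192, 64, dtype=torch.float16)

# Both mask_mods map to the same key, so a config tuned for one sparsity
# pattern would silently be replayed for the other.
assert naive_cache_key(q, k, v) == naive_cache_key(q, k, v)
```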