[DEBUG] Try run FlexAttn in parallel #4376


Draft: wants to merge 7 commits into main
Conversation

@anmyachev (Contributor) commented May 30, 2025

PyTorch CI:

This seems quite fast (without decoding). Before this change, the tests were apparently running in a single process.

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
===== 101 failed, 496 passed, 58 skipped, 1 xfailed in 2611.01s (0:43:31) ======

That said, I see a lot of errors, possibly caused by the parallelism. We can experiment and choose the most successful configuration. (It might also be good to enable reruns in case of failures.)
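
For reference, a minimal sketch of the kind of parallel invocation being tried here, assuming pytest-xdist supplies the worker processes (the actual script changes live in the PR diff and are not quoted in this thread; the worker count of 8 is an arbitrary example):

# Sketch (assumes pytest-xdist is installed): run the FlexAttention suite
# across several worker processes instead of a single process.
pip install pytest-xdist
pytest -n 8 test/inductor/test_flex_attention.py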

FYI @pbchekin @vlad-penkin @alexbaden @chengjunlu @whitneywhtsang @etiotto

Signed-off-by: Anatoly Myachev <[email protected]>
@anmyachev changed the title from "Try run FlexAttn in parallel" to "[DEBUG] Try run FlexAttn in parallel" on May 30, 2025
@pbchekin (Contributor) commented:

Good results! We can add a dedicated workflow for Flex Attention. The number of workers needs to be a parameter so it can be adjusted for client GPUs. It would also be nice to identify the root cause of the failures; most of them look like accuracy errors, and it is not clear whether they are due to parallelism.
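
A sketch of one way to expose the worker count as a parameter, assuming the suite is launched from a shell step; PYTEST_WORKERS is a hypothetical variable name, not something defined in this repository:

# Hypothetical knob: let the workflow (or a developer on a client GPU)
# override the number of pytest-xdist workers; default to 4 if unset.
PYTEST_WORKERS="${PYTEST_WORKERS:-4}"
pytest -n "${PYTEST_WORKERS}" test/inductor/test_flex_attention.py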

@anmyachev (Contributor, Author) commented:

Some of these failures also appear in our regular run https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15373802137/job/43256311378#step:8:14641:

The following tests failed consistently: ['test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s0_v_s0_do_s0_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s0_v_s0_do_s1_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s0_v_s0_do_s2_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s1_v_s1_do_s0_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s1_v_s1_do_s1_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s1_v_s1_do_s2_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s2_v_s2_do_s0_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s2_v_s2_do_s1_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s2_v_s2_do_s2_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s3_v_s3_do_s0_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s3_v_s3_do_s1_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s3_v_s3_do_s2_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s0_v_s0_do_s0_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s0_v_s0_do_s1_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s0_v_s0_do_s2_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s1_v_s1_do_s0_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s1_v_s1_do_s1_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s1_v_s1_do_s2_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s2_v_s2_do_s0_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s2_v_s2_do_s1_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s2_v_s2_do_s2_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s3_v_s3_do_s0_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s3_v_s3_do_s1_xpu_float16', 'test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s3_v_s3_do_s2_xpu_float16', 'test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:256_headdim:16_dtype:float16_mode_default_xpu', 'test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:256_headdim:16_dtype:float16_mode_max-autotune-no-cudagraphs_xpu', 'test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:277_headdim:16_dtype:float16_mode_default_xpu', 'test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:277_headdim:16_dtype:float16_mode_max-autotune-no-cudagraphs_xpu', 
'test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:37_headdim:16_dtype:float16_mode_default_xpu', 'test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:37_headdim:16_dtype:float16_mode_max-autotune-no-cudagraphs_xpu', 'test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_relative_1d_bias_batch:2_head:4_seq_len:256_headdim:16_dtype:float32_mode_max-autotune-no-cudagraphs_xpu']
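
As a sanity check that such failures are not parallelism artifacts, any one of them can be rerun in isolation in a single process; a sketch, with the test ID taken from the list above:

# Rerun a single consistently failing case without any parallelism; if it
# still fails, the failure is not caused by running tests concurrently.
pytest -v "test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s0_v_s0_do_s0_xpu_float16"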

@whitneywhtsang (Contributor) commented Jun 2, 2025

> Some of these failures also appear in our regular run https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15373802137/job/43256311378#step:8:14641: The following tests failed consistently: [same list as quoted above]

@chengjunlu investigated the failures above earlier and believes they are issues in SYCL.
The TestFlexAttentionXPU::test_strided_inputs failures are fixed in PyTorch main.

@anmyachev (Contributor, Author) commented:

Looks much better with the latest PyTorch pin update (https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15416558343/job/43380600584):

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
====== 7 failed, 633 passed, 15 skipped, 1 xfailed in 2571.39s (0:42:51) =======

@etiotto (Contributor) commented Jun 3, 2025

If you get intermittent failures, you can try rerunning the failed tests (in parallel again, or perhaps sequentially). If you rerun in parallel, you may have to rerun more than once...
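
A sketch of that rerun strategy using pytest's built-in last-failed cache: a first pass in parallel, then a sequential second pass over only the failures (the flags are standard pytest / pytest-xdist options; the worker count is an arbitrary example):

# First pass: run the suite in parallel; if anything fails, rerun only
# the failed tests sequentially using pytest's last-failed cache.
pytest -n 8 test/inductor/test_flex_attention.py || \
  pytest --last-failed test/inductor/test_flex_attention.py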

david-hls pushed a commit to david-hls/intel-xpu-backend-for-triton that referenced this pull request Jun 18, 2025
…el#4376)

There was no test coverage for this. I discovered this while implementing the CPU backend.