[DEBUG] Try run FlexAttn in parallel #4376
base: main
Conversation
Signed-off-by: Anatoly Myachev <[email protected]>
Good results! We can add a dedicated workflow for Flex Attention. The number of workers needs to be a parameter so it can be adjusted for client GPUs. It would also be nice to identify the root cause of the failures; they look like accuracy errors in most cases, and I am not sure whether they are caused by parallelism.
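A minimal sketch of what that worker-count parameter could look like in a dedicated Flex Attention job, assuming pytest-xdist (-n) and pytest-rerunfailures (--reruns) as used elsewhere in this thread; the FLEX_ATTN_WORKERS variable name is hypothetical:

```bash
# Hypothetical parameterization of the xdist worker count so a dedicated
# Flex Attention workflow can be tuned for client GPUs.
FLEX_ATTN_WORKERS="${FLEX_ATTN_WORKERS:-16}"   # default aimed at CI-class GPUs

pytest test/inductor/test_flex_attention.py \
  -n "${FLEX_ATTN_WORKERS}" \
  --reruns 2
```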
I see part of these problems in our usual run as well (https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15373802137/job/43256311378#step:8:14641). The following tests failed consistently:

- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s0_v_s0_do_s0_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s0_v_s0_do_s1_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s0_v_s0_do_s2_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s1_v_s1_do_s0_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s1_v_s1_do_s1_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s1_v_s1_do_s2_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s2_v_s2_do_s0_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s2_v_s2_do_s1_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s2_v_s2_do_s2_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s3_v_s3_do_s0_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s3_v_s3_do_s1_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s0_k_s3_v_s3_do_s2_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s0_v_s0_do_s0_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s0_v_s0_do_s1_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s0_v_s0_do_s2_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s1_v_s1_do_s0_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s1_v_s1_do_s1_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s1_v_s1_do_s2_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s2_v_s2_do_s0_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s2_v_s2_do_s1_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s2_v_s2_do_s2_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s3_v_s3_do_s0_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s3_v_s3_do_s1_xpu_float16
- test/inductor/test_flex_attention.py::TestFlexAttentionXPU::test_strided_inputs_q_s1_k_s3_v_s3_do_s2_xpu_float16
- test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:256_headdim:16_dtype:float16_mode_default_xpu
- test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:256_headdim:16_dtype:float16_mode_max-autotune-no-cudagraphs_xpu
- test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:277_headdim:16_dtype:float16_mode_default_xpu
- test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:277_headdim:16_dtype:float16_mode_max-autotune-no-cudagraphs_xpu
- test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:37_headdim:16_dtype:float16_mode_default_xpu
- test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_head_specific_gate_batch:2_head:4_seq_len:37_headdim:16_dtype:float16_mode_max-autotune-no-cudagraphs_xpu
- test/inductor/test_flex_attention.py::TestLearnableBiasesXPU::test_relative_1d_bias_batch:2_head:4_seq_len:256_headdim:16_dtype:float32_mode_max-autotune-no-cudagraphs_xpu
@chengjunlu investigated the above failures before and believes they are issues in SYCL.
Looks much better with the latest PyTorch pin update (https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15416558343/job/43380600584):

This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
====== 7 failed, 633 passed, 15 skipped, 1 xfailed in 2571.39s (0:42:51) =======
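As an aside, the notice quoted above can be silenced via the environment variable it names; a minimal example (the test path and worker count are simply the ones used elsewhere in this thread):

```bash
# Suppress PyTorch's "print repro on failure" notice for this test run.
PYTORCH_PRINT_REPRO_ON_FAILURE=0 pytest test/inductor/test_flex_attention.py -n 16
```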
If you get intermittent failures, you can try rerunning only the failed tests (in parallel again, or perhaps sequentially). If you rerun in parallel, you may have to rerun more than once.
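A sketch of that rerun strategy, assuming pytest-xdist is installed and pytest's cache from the previous run is still present (--last-failed selects only the tests that failed last time; -n 0 falls back to a sequential run):

```bash
# First retry the previous failures in parallel; if anything still fails,
# retry them once more sequentially to rule out parallelism as the cause.
pytest test/inductor/test_flex_attention.py -n 16 --last-failed || \
  pytest test/inductor/test_flex_attention.py -n 0 --last-failed
```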
…el#4376) There was no test coverage for this. I discovered this while implementing the CPU backend.
PyTorch CI, settings tried per run:
- -n 16
- -n 8 --reruns 2
- -n 8 --reruns 2 (py3.9)
- -n 32 (py3.10)
- -n 16 (then usual run)
- -n 16 --reruns 2 (then usual run)
- -n 16 --reruns 2
- with new PyTorch pin

Seems quite fast (without decoding). Before this, the tests were apparently running in a single process.
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
===== 101 failed, 496 passed, 58 skipped, 1 xfailed in 2611.01s (0:43:31) ======
That said, I see a lot of errors, perhaps due to parallelism. We can experiment and choose the most successful combination. (It might also be good to enable automatic reruns when tests fail.)
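For reference, a rough, purely illustrative sketch of how such an experiment could be scripted (the worker counts and rerun count are arbitrary; assumes pytest-xdist and pytest-rerunfailures are installed):

```bash
# Try a few worker counts and keep each run's final summary line for comparison.
for workers in 8 16 32; do
  echo "=== -n ${workers} ==="
  pytest test/inductor/test_flex_attention.py -n "${workers}" --reruns 2 | tail -n 1
done
```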
FYI @pbchekin @vlad-penkin @alexbaden @chengjunlu @whitneywhtsang @etiotto