Skip to content

[rls-v3.11] graph: re-enable sdpa training bwd#4825

Open
ElaineBao wants to merge 7 commits intosyurkevi/fused_sdpa_training_backport311from
yixin/sdpa_ukernel_train_backport
Open

[rls-v3.11] graph: re-enable sdpa training bwd#4825
ElaineBao wants to merge 7 commits intosyurkevi/fused_sdpa_training_backport311from
yixin/sdpa_ukernel_train_backport

Conversation

@ElaineBao
Copy link
Contributor

@ElaineBao ElaineBao commented Mar 13, 2026

Description

Implementation of Proposal 2.C in RFC:

Main branch PR:

Currently there are still some correctness issues (with microkernel), but it doesn't seem to be computation error. large partition can pass.

_DNNL_GRAPH_SDPA_FORCE_PRIMITIVE=0 DNNL_VERBOSE=0 ./tests/benchdnn/benchdnn --graph --engine=gpu --case='/home/gta/yixin/oneDNN/tests/benchdnn/i
nputs/graph/complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json'

onednn_verbose,v1,graph,exec,gpu,100002,sdp,bmm1;scale_mul;mask_add;subtract;exp;typecast;bmm_dv;bmm_dprobs;mul_o_do;reducesum_correction;sub_dp_corrected;mul_softmax_bwd;scale_mul;typecast;bmm_dk;bmm_dq,,in0_bf16:100:strided:undef:1x16x384x64:393216s24576s64s1 in1_bf16:101:strided:undef:1x16x384x64:393216s24576s64s1 in2_bf16:102:strided:undef:1:1 in3_bf16:103:strided:undef:1x16x384x384:2359296s147456s384s1 in4_f32:8:strided:undef:1x16x384x1:6144s384s1s1 in5_bf16:105:strided:undef:1x16x384x64:393216s24576s64s1 in6_bf16:105:strided:undef:1x16x384x64:393216s24576s64s1 in7_bf16:104:strided:undef:1x16x384x64:393216s24576s64s1 in8_bf16:10:strided:undef:1x16x384x64:393216s24576s64s1 in9_bf16:105:strided:undef:1x16x384x64:393216s24576s64s1 in10_bf16:102:strided:undef:1:1 in11_bf16:100:strided:undef:1x16x384x64:393216s24576s64s1 in12_bf16:101:strided:undef:1x16x384x64:393216s24576s64s1 out0_bf16:13:strided:undef:1x16x384x64:393216s24576s64s1 out1_bf16:31:strided:undef:1x16x384x64:393216s24576s64s1 out2_bf16:29:strided:undef:1x16x384x64:393216s24576s64s1,fpm:strict,sdp_bwd_primitive_kernel_t,dnnl_backend,0.789062

[  65][0:0:1:1] exp_f32:     1.12938 exp:     1.13281 got:     1.10938 diff:0.0234375 rdiff:0.0206897
[  76][0:0:1:12] exp_f32:    0.589406 exp:    0.589844 got:    0.609375 diff:0.0195312 rdiff:0.0331126
[  78][0:0:1:14] exp_f32:     1.13174 exp:     1.13281 got:     1.14844 diff:0.015625 rdiff:0.0137931
[ 115][0:0:1:51] exp_f32:    0.792262 exp:    0.792969 got:     0.78125 diff:0.0117188 rdiff:0.0147783
[ 129][0:0:2:1] exp_f32:   -0.098105 exp:  -0.0981445 got:   0.0105591 diff:0.108704 rdiff: 1.10759
[ 134][0:0:2:6] exp_f32:    -10.9466 exp:    -10.9375 got:    -11.0625 diff:   0.125 rdiff:0.0114286
[ 144][0:0:2:16] exp_f32:    -6.46232 exp:    -6.46875 got:    -6.40625 diff:  0.0625 rdiff:0.00966184
[ 145][0:0:2:17] exp_f32:     5.20519 exp:     5.21875 got:     5.34375 diff:   0.125 rdiff:0.0239521
[ 149][0:0:2:21] exp_f32:     5.00748 exp:           5 got:      4.9375 diff:  0.0625 rdiff:  0.0125
[ 150][0:0:2:22] exp_f32:     5.35494 exp:     5.34375 got:         5.5 diff: 0.15625 rdiff:0.0292398
[COMPARE_STATS]: trh=0 err_max_diff:     0.5 err_max_rdiff: 794.106 all_max_diff:       1 all_max_rdiff: 794.106
[COMPARE_STATS] Norm check is prohibited; error_to_total_ratio: 17422/393216; allowed_ratio: 384/393216;
Error: Function 'doit' at (/home/gta/yixin/oneDNN/tests/benchdnn/graph/graph.cpp:773) returned '1'
0:FAILED (errors:17422 total:393216) (1435 ms) __REPRO: --graph --engine=gpu --case=/home/gta/yixin/oneDNN/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json
===========================================================
= Failed cases summary (--summary=no-failures to disable) =
===========================================================
0:FAILED (errors:17422 total:393216) (1435 ms) __REPRO: --graph --engine=gpu --case=/home/gta/yixin/oneDNN/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json
============================
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
total: 1.45s; create_pd: 0.73s (50%); create_prim: 0.07s (5%); fill: 0.00s (0%); execute: 0.04s (3%); compute_ref: 0.00s (0%); compare: 0.00s (0%);
_DNNL_GRAPH_SDPA_FORCE_PRIMITIVE=1 DNNL_VERBOSE=0 ./tests/benchdnn/benchdnn --graph --engine=gpu --case='/home/gta/yixin/oneDNN/tests/benchdnn/i
nputs/graph/complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json'

onednn_verbose,v1,graph,exec,gpu,100002,sdp,bmm1;scale_mul;mask_add;subtract;exp;typecast;bmm_dv;bmm_dprobs;mul_o_do;reducesum_correction;sub_dp_corrected;mul_softmax_bwd;scale_mul;typecast;bmm_dk;bmm_dq,,in0_bf16:100:strided:undef:1x16x384x64:393216s24576s64s1 in1_bf16:101:strided:undef:1x16x384x64:393216s24576s64s1 in2_bf16:102:strided:undef:1:1 in3_bf16:103:strided:undef:1x16x384x384:2359296s147456s384s1 in4_f32:8:strided:undef:1x16x384x1:6144s384s1s1 in5_bf16:105:strided:undef:1x16x384x64:393216s24576s64s1 in6_bf16:105:strided:undef:1x16x384x64:393216s24576s64s1 in7_bf16:104:strided:undef:1x16x384x64:393216s24576s64s1 in8_bf16:10:strided:undef:1x16x384x64:393216s24576s64s1 in9_bf16:105:strided:undef:1x16x384x64:393216s24576s64s1 in10_bf16:102:strided:undef:1:1 in11_bf16:100:strided:undef:1x16x384x64:393216s24576s64s1 in12_bf16:101:strided:undef:1x16x384x64:393216s24576s64s1 out0_bf16:13:strided:undef:1x16x384x64:393216s24576s64s1 out1_bf16:31:strided:undef:1x16x384x64:393216s24576s64s1 out2_bf16:29:strided:undef:1x16x384x64:393216s24576s64s1,fpm:strict,larger_partition_kernel_t,dnnl_backend,56.4551

0:PASSED (1580 ms) __REPRO: --graph --engine=gpu --case=/home/gta/yixin/oneDNN/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json
tests:1 passed:1 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 1.58s; create_pd: 0.73s (46%); create_prim: 0.07s (5%); fill: 0.00s (0%); execute: 0.04s (2%); compute_ref: 0.00s (0%); compare: 0.00s (0%);

@ElaineBao ElaineBao self-assigned this Mar 13, 2026
@ElaineBao ElaineBao requested review from a team as code owners March 13, 2026 07:08
@ElaineBao ElaineBao added the component:graph-api Codeowner: @oneapi-src/onednn-graph label Mar 13, 2026
@github-actions github-actions bot added component:tests Codeowner: @oneapi-src/onednn-arch component:examples and removed backport labels Mar 13, 2026
@ElaineBao
Copy link
Contributor Author

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph

@ElaineBao ElaineBao force-pushed the yixin/sdpa_ukernel_train_backport branch from 6a65117 to 70d711f Compare March 13, 2026 09:22
@ElaineBao
Copy link
Contributor Author

make test
disable benchdnn_all
set test_scope=NIGHTLY
enable benchdnn_graph
disable test_device_cpu
enable test_device_gpu
enable arch_gpu_xe-hpc
enable arch_gpu_xe-lpg+
enable arch_gpu_xe2-hpg-bmg
enable arch_gpu_xe2-lpg

@vpirogov
Copy link
Contributor

make test disable benchdnn_all set test_scope=NIGHTLY enable benchdnn_graph disable test_device_cpu enable test_device_gpu enable arch_gpu_xe-hpc enable arch_gpu_xe-lpg+ enable arch_gpu_xe2-hpg-bmg enable arch_gpu_xe2-lpg

This only works on PRs targeting production branches...

@vpirogov
Copy link
Contributor

vpirogov commented Mar 13, 2026

Orphan CI.

@ElaineBao ElaineBao force-pushed the yixin/sdpa_ukernel_train_backport branch from 70d711f to 43dcf3d Compare March 15, 2026 13:43
--reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
--reset --dt=0:bf16+1:bf16+7:bf16+9:bf16+10:bf16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f32.json
--reset --dt=0:bf16+1:bf16+7:bf16+8:bf16+9:bf16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f32.json
--reset --case=complex_fusion/mha/gqa-plain-training-backward-w-dmask-bf16-f32.json
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean we supported dmask in v3.11 but now it gets removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How hard will it be to add it back?

Copy link
Contributor Author

@ElaineBao ElaineBao Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern is already supported, so ideally dmask is supported, but there's no test case for it. I'll add a test case later.

@ElaineBao ElaineBao force-pushed the yixin/sdpa_ukernel_train_backport branch 2 times, most recently from 56c6434 to eeeb25c Compare March 16, 2026 07:43
@ElaineBao ElaineBao force-pushed the yixin/sdpa_ukernel_train_backport branch from eeeb25c to ea318d2 Compare March 17, 2026 16:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

component:examples component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants