[rls-v3.11] graph: re-enable sdpa training bwd #4825
ElaineBao wants to merge 7 commits into syurkevi/fused_sdpa_training_backport311
Conversation
--reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
--reset --dt=0:bf16+1:bf16+7:bf16+9:bf16+10:bf16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f32.json
--reset --dt=0:bf16+1:bf16+7:bf16+8:bf16+9:bf16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f32.json
--reset --case=complex_fusion/mha/gqa-plain-training-backward-w-dmask-bf16-f32.json
Does that mean we supported dmask in v3.11 but now it gets removed?
How hard will it be to add it back?
The pattern is already supported, so dmask should work, but there's no test case for it yet. I'll add one later.
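For reference, each of the case lines above is a benchdnn batch-file entry. Assuming the standard benchdnn graph driver, one such entry can presumably be run standalone as below (the binary path and `--graph` driver prefix are assumptions, not taken from this PR; the flags themselves come from the test file):

```shell
# Hypothetical standalone invocation of one of the batch-file entries above.
# --dt rewrites the data type of the listed logical-tensor ids to bf16;
# --case points at the serialized graph JSON.
./benchdnn --graph --reset \
    --dt=1:bf16+2:bf16+3:bf16+4:bf16+6:bf16+104:bf16 \
    --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
```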
Description
Implementation of Proposal 2.C in RFC:
Main branch PR:
Currently there are still some correctness issues (with the microkernel implementation), but they do not appear to be computation errors; the large-partition path passes.
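For context, the SDPA training backward pattern this PR re-enables computes the gradients of scaled dot-product attention. A minimal NumPy sketch of the underlying math (illustrative only, not oneDNN code; single-head 2D shapes, no mask or dropout for brevity):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def sdpa_fwd(q, k, v, scale):
    """Forward: O = softmax(scale * Q K^T) V; P is saved for backward."""
    p = softmax(scale * (q @ k.T))
    return p @ v, p

def sdpa_bwd(q, k, v, p, d_out, scale):
    """Backward pass given the saved softmax output P and upstream grad dO."""
    dv = p.T @ d_out                                  # dV = P^T dO
    dp = d_out @ v.T                                  # dP = dO V^T
    ds = p * (dp - (dp * p).sum(-1, keepdims=True))   # softmax backward
    dq = scale * (ds @ k)                             # dQ = scale * dS K
    dk = scale * (ds.T @ q)                           # dK = scale * dS^T Q
    return dq, dk, dv
```

The fused pattern in the graph API covers the same chain of matmuls and the softmax backward step in one partition; the sketch is just the reference math.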