When using `RingFlashAttnFunc`, I get an error saying that the backward pass is expected to return 14 gradients (because the corresponding forward pass takes 14 arguments besides `ctx`), but it actually returns only 13 values. To fix this, one additional `None` should be added to the values returned by `RingFlashAttnFunc.backward`.
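For context, `torch.autograd.Function` expects `backward` to return one value per `forward` input (excluding `ctx`), with `None` for inputs that need no gradient, such as flags or Python scalars. The sketch below does not use the real `RingFlashAttnFunc` signature. It uses a hypothetical `ScaledAdd` function with four inputs and a deliberately broken variant, `ScaledAddMissingNone`, that triggers the same kind of count mismatch.

```python
import torch


class ScaledAdd(torch.autograd.Function):
    """Hypothetical example: computes x + alpha * y.

    forward takes four inputs besides ctx (x, y, alpha, deterministic),
    so backward must return exactly four values, one per input.
    """

    @staticmethod
    def forward(ctx, x, y, alpha, deterministic):
        # `alpha` is a Python float and `deterministic` a flag; neither needs a gradient.
        ctx.alpha = alpha
        return x + alpha * y

    @staticmethod
    def backward(ctx, grad_out):
        # Gradients for x and y, then one None per non-differentiable input.
        return grad_out, ctx.alpha * grad_out, None, None


class ScaledAddMissingNone(ScaledAdd):
    @staticmethod
    def backward(ctx, grad_out):
        # Bug: three values for four forward inputs (the same class of error as above).
        return grad_out, ctx.alpha * grad_out, None


x = torch.ones(3, requires_grad=True)
y = torch.ones(3, requires_grad=True)
ScaledAdd.apply(x, y, 2.0, False).sum().backward()
grad_x, grad_y = x.grad.clone(), y.grad.clone()
print(grad_x, grad_y)  # d/dx is 1 and d/dy is alpha = 2 for every element

try:
    ScaledAddMissingNone.apply(x, y, 2.0, False).sum().backward()
    count_error = None
except RuntimeError as err:
    # Autograd rejects a backward that returns fewer values than forward had inputs.
    count_error = str(err)
print(count_error)
```

Padding the return tuple with explicit `None` values, one for each non-tensor argument, keeps the correspondence between `forward` inputs and `backward` outputs easy to audit when new arguments are added to `forward`.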