Skip to content

Commit

Permalink
Fix CPU SDP result not matching the FlashAttention result.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Jan 1, 2024
1 parent 3d5713a commit becc88d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
float* const ssp = saved_softmax ? saved_softmax->data.f32 : 0;
const float scale = cmd.info.scaled_dot_product_attention.scale;
const int is_causal = cmd.info.scaled_dot_product_attention.is_causal;
+ const int h_h_k_ratio = qdim[2] / kdim[2];
+ assert(kdim[2] == vdim[2]);
+ assert(qdim[2] >= kdim[2]);
+ assert(qdim[2] % kdim[2] == 0);
for (i[0] = 0; i[0] < qdim[0]; i[0]++)
{
const float* const qp0 = qp + i[0] * qstride[0];
Expand All @@ -120,8 +124,8 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
for (i[1] = 0; i[1] < qdim[2]; i[1]++)
{
const float* const qp1 = qp0 + i[1] * qstride[2];
- const float* const kp1 = kp0 + (i[1] % kdim[2]) * kstride[2];
- const float* const vp1 = vp0 + (i[1] % vdim[2]) * vstride[2];
+ const float* const kp1 = kp0 + (i[1] / h_h_k_ratio) * kstride[2];
+ const float* const vp1 = vp0 + (i[1] / h_h_k_ratio) * vstride[2];
const float* const amp1 = amp && amdim[1] > 1 ? amp0 + i[1] * amstride[1] : amp0;
float* const cp1 = cp0 + i[1] * cstride[2];
float* const ssp1 = ssp0 ? ssp0 + i[1] * ssstride[1] : 0;
Expand Down
14 changes: 7 additions & 7 deletions test/int/nnc/cublas.tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -2618,13 +2618,13 @@ TEST_CASE("scaled dot product attention with flash_attn")
#define num_trials (num_long_trials + num_short_trials)

for (int trial = 0; trial < num_trials; ++trial) {
- int B_candidates[num_trials] = { 32, 12, 16, 1, 2, 15 };
- int R_candidates[num_trials] = { 160, 256, 128, 77, 77, 512 };
- int C_candidates[num_trials] = { 128, 128, 128, 128, 128, 128 };
- int Hq_candidates[num_trials] = { 8, 8, 8, 8, 8, 8 };
- int Hk_candidates[num_trials] = { 8, 8, 8, 8, 2, 4 };
- int D_candidates[num_trials] = { 64, 40, 160, 224, 224, 64 };
- int is_causal_candidates[num_trials] = { 1, 0, 1, 1, 0, 0 };
+ int B_candidates[num_trials] = { 32, 12, 16, 1, 2, 1 };
+ int R_candidates[num_trials] = { 160, 256, 128, 77, 77, 5 };
+ int C_candidates[num_trials] = { 128, 128, 128, 128, 128, 5 };
+ int Hq_candidates[num_trials] = { 8, 8, 8, 8, 8, 32 };
+ int Hk_candidates[num_trials] = { 8, 8, 8, 8, 2, 8 };
+ int D_candidates[num_trials] = { 64, 40, 160, 224, 224, 128 };
+ int is_causal_candidates[num_trials] = { 1, 0, 1, 1, 0, 1 };

int B = B_candidates[trial];
int R = R_candidates[trial];
Expand Down

0 comments on commit becc88d

Please sign in to comment.