Large loss of accuracy between flashattention and native #1391

fanfanaaaa opened this issue Dec 17, 2024 · 3 comments
@fanfanaaaa

I am trying to implement FlashAttention locally, similar to zhuzilin/ring-flash-attention#24, but I have found a significant discrepancy between the attention computed through the FlashAttention interface and the native attention computation. Could you please explain why this happens? An error of this size is unacceptable for inference.
This issue has been troubling me for a long time; I am looking forward to your reply.

```python
# Assumes the surrounding module already has the usual imports, e.g.:
#   import math
#   import torch
#   import torch.nn.functional as F
#   from typing import Optional
#   from flash_attn import flash_attn_varlen_func
#   from flash_attn.flash_attn_interface import _flash_attn_forward

def _flashAttention_forward(
    self,
    x: torch.Tensor,
    start_pos: int,
    freqs_cis: torch.Tensor,
    mask: Optional[torch.Tensor],
    softmax_scale=None,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    alibi_slopes=None,
    deterministic=False,
):
    bsz, seqlen, _ = x.shape

    # QKV projections
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

    # Rotary embeddings and KV-cache update
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
    self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
    self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

    keys = self.cache_k[:bsz, : start_pos + seqlen]
    values = self.cache_v[:bsz, : start_pos + seqlen]

    # Expand KV heads for grouped-query attention
    keys = repeat_kv(keys, self.n_rep)      # (bs, cache_len + seqlen, n_local_heads, head_dim)
    values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

    # Native attention
    xq = xq.transpose(1, 2)          # (bs, n_local_heads, seqlen, head_dim)
    keys = keys.transpose(1, 2)      # (bs, n_local_heads, cache_len + seqlen, head_dim)
    values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
    if mask is not None:
        scores = scores + mask       # (bs, n_local_heads, seqlen, cache_len + seqlen)
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    native_output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)

    # FlashAttention, fixed-length interface
    if softmax_scale is None:
        softmax_scale = xq.shape[-1] ** (-0.5)
    flash_attn_output, _, _, _ = _flash_attn_forward(
        xq,
        keys,
        values,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=True,
        window_size_left=-1,
        window_size_right=-1,
        softcap=0.0,
        alibi_slopes=alibi_slopes,
        return_softmax=dropout_p > 0,
    )

    # FlashAttention, varlen interface
    flash_attn_varlen_output = flash_attn_varlen_func(
        q=xq.reshape(-1, self.n_local_heads, self.head_dim),
        k=keys.reshape(-1, self.n_local_heads, self.head_dim),
        v=values.reshape(-1, self.n_local_heads, self.head_dim),
        cu_seqlens_q=torch.tensor([0, seqlen], dtype=torch.int32),
        cu_seqlens_k=torch.tensor([0, seqlen], dtype=torch.int32),
        max_seqlen_q=4096,
        max_seqlen_k=4096,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=window_size,
        alibi_slopes=alibi_slopes,
        softcap=0.0,
    )  # (total_tokens, nheads, head_dim)

    # Flatten all three outputs to (bs, seqlen, n_local_heads * head_dim) and compare
    native_output = native_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    flash_attn_output = flash_attn_output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    flash_attn_varlen_output = flash_attn_varlen_output.contiguous().view(bsz, seqlen, -1)

    print("[flash_attn_output, native_output] diff max: ", (flash_attn_output - native_output).abs().max().item())
    print("[flash_attn_output, native_output] diff mean: ", (flash_attn_output - native_output).abs().mean().item())
    print("[flash_attn_varlen_output, native_output] diff max: ", (flash_attn_varlen_output - native_output).abs().max().item())
    print("[flash_attn_varlen_output, native_output] diff mean: ", (flash_attn_varlen_output - native_output).abs().mean().item())

    # `output` was undefined in the original snippet; assuming the native result is projected and returned
    return self.wo(native_output)
```

The printed result is:

```
[flash_attn_output, native_output] diff max: 0.294921875
[flash_attn_output, native_output] diff mean: 0.01226806640625
[flash_attn_varlen_output, native_output] diff max: 0.326171875
[flash_attn_varlen_output, native_output] diff mean: 0.0130615234375
```
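
For scale, a raw fp16/bf16 difference is hard to judge on its own: flash-attn's own test suite compares the kernel against a float32 reference and only expects the flash error to be of the same order as the error of a plain half-precision implementation. A self-contained sketch of that kind of check (the shapes, dtype, and random inputs below are illustrative, not the model's actual tensors):

```python
import math
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # needs a CUDA GPU; inputs must be fp16 or bf16

torch.manual_seed(0)
bsz, seqlen, nheads, headdim = 1, 512, 8, 64
q = torch.randn(bsz, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

def naive_attn(q, k, v):
    """Plain causal attention in whatever dtype q/k/v come in; layout (bsz, seqlen, nheads, headdim)."""
    qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))        # (bsz, nheads, seqlen, headdim)
    scores = qh @ kh.transpose(-2, -1) / math.sqrt(headdim)
    causal_mask = torch.triu(
        torch.full((seqlen, seqlen), float("-inf"), device=q.device, dtype=scores.dtype),
        diagonal=1,
    )
    probs = F.softmax((scores + causal_mask).float(), dim=-1).to(v.dtype)
    return (probs @ vh).transpose(1, 2)                         # back to (bsz, seqlen, nheads, headdim)

out_flash = flash_attn_func(q, k, v, causal=True)   # flash-attn expects (bsz, seqlen, nheads, headdim)
out_fp16 = naive_attn(q, k, v)                      # half-precision reference
out_fp32 = naive_attn(q.float(), k.float(), v.float())  # float32 reference

print("fp16 reference vs fp32 reference:", (out_fp16.float() - out_fp32).abs().max().item())
print("flash output   vs fp32 reference:", (out_flash.float() - out_fp32).abs().max().item())
```

If the two printed errors come out in the same ballpark, a max diff like 0.29 is most likely half-precision rounding rather than a wrong result; if the flash error is much larger than the fp16 reference's error, something about the call (input layout, masking, or scale) is worth rechecking.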

@asahni04

+1, I have noticed this too and cannot swap FA-2 for FA-3 even for inference-only use.

@fanfanaaaa
Author

> +1, I have noticed this too and cannot swap FA-2 for FA-3 even for inference-only use.

Hi, I am not sure what exactly you mean. Why can't FA-2 be swapped with FA-3? Is it also due to accuracy issues?

@ZiyaoLi

ZiyaoLi commented Dec 27, 2024

@fanfanaaaa Is this issue solved? Did you check whether not providing the mask gives you the correct result?
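
One way to run that check in isolation, outside the model code, is to compare the flash kernel and a plain implementation both with and without causal masking: a gap that only appears in the masked case points at the mask handling rather than the kernel itself. A sketch under the same assumptions as the earlier one (flash_attn_func from the flash-attn package, random fp16 inputs on a CUDA GPU):

```python
import math
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

torch.manual_seed(0)
bsz, seqlen, nheads, headdim = 1, 256, 8, 64
q = torch.randn(bsz, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

def naive_attn(q, k, v, mask=None):
    """Plain attention with an optional additive mask; layout (bsz, seqlen, nheads, headdim)."""
    qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))   # (bsz, nheads, seqlen, headdim)
    scores = qh @ kh.transpose(-2, -1) / math.sqrt(headdim)
    if mask is not None:
        scores = scores + mask
    probs = F.softmax(scores.float(), dim=-1).to(v.dtype)
    return (probs @ vh).transpose(1, 2)

causal_mask = torch.triu(
    torch.full((seqlen, seqlen), float("-inf"), device="cuda", dtype=torch.float16),
    diagonal=1,
)

# Masked comparison: the kernel's causal flag vs an explicit additive causal mask.
masked_diff = (flash_attn_func(q, k, v, causal=True) - naive_attn(q, k, v, causal_mask)).abs().max()
# Unmasked comparison: no causal flag, no additive mask.
unmasked_diff = (flash_attn_func(q, k, v, causal=False) - naive_attn(q, k, v)).abs().max()

print("with masking:   ", masked_diff.item())
print("without masking:", unmasked_diff.item())
```

With identical inputs, the unmasked comparison isolates pure numerical differences between the two implementations.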
