
flash_attn_with_kvcache discrepancy slicing kv_cache / cache_seqlens #1417

Open · jeromeku opened this issue Jan 1, 2025 · 0 comments
jeromeku commented Jan 1, 2025

@tridao

There is a discrepancy in outputs when using flash_attn_with_kvcache under the following conditions:

  • full k_cache / v_cache with cache_seqlens vs sliced k_cache / v_cache with same cache_seqlens
  • full k_cache / v_cache with cache_seqlens vs k_cache / v_cache sliced to cache_seqlens

My understanding is that cache_seqlens indicates the length of the kv_cache to attend to, so the above cases should give identical outputs, since both attend to the same context length:

  • In the first case, the only difference is the length of the kv_cache passed to the function (same cache_seqlens); both cache lengths are longer than cache_seqlens
  • In the second case, the kv_cache is sliced to cache_seqlens and no cache_seqlens arg is passed to the function

Both give the same abs max diff.
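For concreteness, a minimal sketch of the comparisons (out_ref, out_case1, out_case2 and longer_len are illustrative names, not taken from the repro script below):

# Reference: full-length cache, restricted to the first cache_seqlens tokens
out_ref = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True
)

# Case 1: cache sliced to some longer_len > cache_seqlens, same cache_seqlens
out_case1 = flash_attn_with_kvcache(
    q,
    k_cache[:, :longer_len].contiguous(),
    v_cache[:, :longer_len].contiguous(),
    cache_seqlens=cache_seqlens,
    causal=True,
)

# Case 2: cache sliced exactly to cache_seqlens, no cache_seqlens argument
out_case2 = flash_attn_with_kvcache(
    q,
    k_cache[:, :cache_seqlens].contiguous(),
    v_cache[:, :cache_seqlens].contiguous(),
    causal=True,
)

# Expectation: all three outputs match (up to numerical noise); in practice they do not.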

Moreover, if I pass in increasing slices of the kv_cache (starting from cache_seqlens) but with the same cache_seqlens arg, the discrepancies stay the same up to a certain threshold, after which the results agree.

E.g., for {k,v}_cache of shape [1, 262144, 32, 128] and cache_seqlens 40960: passing increasing slices {k,v}_cache[:, :ctx_len, :] with ctx_len = 40960, ..., 106496, 172032, 262144 and comparing against a reference output computed with the full kv_cache and cache_seqlens 40960, the outputs differ for ctx_len < 172032.

Tested on an H100 with:
torch 2.5
flash_attn 2.4.2
python 3.10.2

Minimal script to reproduce:

import torch
from flash_attn import flash_attn_with_kvcache

torch.manual_seed(0)

dtype = torch.bfloat16
bs = 1
qlen = 1
nheads = 32
d = 128
cache_seqlens = 40960
full_cache_len = 262144

q = torch.randn(bs, qlen, nheads, d, device="cuda", dtype=dtype)
kv_cache = torch.randn(bs, full_cache_len, 2, nheads, d, device="cuda", dtype=dtype)
rotary_interleaved = False

k_cache = kv_cache[:, :, 0].contiguous()
v_cache = kv_cache[:, :, 1].contiguous()

#print(f"Test params: {q.shape=}, {k_cache.shape=}, {v_cache.shape=} {full_cache_len=} {cache_seqlens=} {rotary_interleaved=} {dtype=}")

ref_out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    softmax_scale=None,
    causal=True,
    rotary_interleaved=rotary_interleaved,
    alibi_slopes=None,
)
 

# -------------------------------------------------------------------------------------------------

print("-"*100)
print("Testing at various cache lengths...\n")
cache_lens = [min(cache_seqlens + ctx, full_cache_len) for ctx in [0, 256, 512, 65536, 131072, 262144]]

for cache_len in cache_lens:    
    k_cache_ = k_cache.clone()
    v_cache_ = v_cache.clone()
  
    k_cache_ = k_cache_[:, :cache_len, :].contiguous()
    v_cache_ = v_cache_[:, :cache_len, :].contiguous()
  
    print(f"Test params: {q.shape=}, {k_cache_.shape=}, {v_cache_.shape=} {full_cache_len=} {cache_seqlens=} {rotary_interleaved=} {dtype=}")
    test_out = flash_attn_with_kvcache(
        q,
        k_cache_,
        v_cache_,
        cache_seqlens=cache_seqlens,
        softmax_scale=None,
        causal=True,
        rotary_interleaved=rotary_interleaved,
        alibi_slopes=None,
    )

    if not torch.allclose(ref_out, test_out):
        print(f"Cache len {cache_len}: Failed: {torch.abs(ref_out - test_out).max()}") 
    else:    
        print(f"Cache len {cache_len}: Passed")

    print()
# -------------------------------------------------------------------------------------------------

print("-"*100)
print("Testing with zeroed out cache...\n")

k_cache_ = k_cache.clone()
v_cache_ = v_cache.clone()
k_cache_[:,cache_seqlens:,:] = 0
v_cache_[:,cache_seqlens:,:] = 0

print(f"Test params: {q.shape=}, {k_cache_.shape=}, {v_cache_.shape=} {full_cache_len=} {cache_seqlens=} {rotary_interleaved=} {dtype=}")
test_out = flash_attn_with_kvcache(
    q,
    k_cache_,
    v_cache_,
    cache_seqlens=cache_seqlens,
    softmax_scale=None,
    causal=True,
    rotary_interleaved=rotary_interleaved,
    alibi_slopes=None,
)

if not torch.allclose(ref_out, test_out):
    print(f"Failed: {torch.abs(ref_out - test_out).max()}") 
else:
    print(f"Passed")

# -------------------------------------------------------------------------------------------------

print("-"*100)
print("Testing with random cache...\n")

k_cache_ = k_cache.clone()
v_cache_ = v_cache.clone()
k_cache_[:,cache_seqlens:,:] = torch.randn_like(k_cache_[:,cache_seqlens:,:])
v_cache_[:,cache_seqlens:,:] = torch.randn_like(v_cache_[:,cache_seqlens:,:])

print(f"Test params: {q.shape=}, {k_cache_.shape=}, {v_cache_.shape=} {full_cache_len=} {cache_seqlens=} {rotary_interleaved=} {dtype=}")
test_out = flash_attn_with_kvcache(
    q,
    k_cache_,
    v_cache_,
    cache_seqlens=cache_seqlens,
    softmax_scale=None,
    causal=True,
    rotary_interleaved=rotary_interleaved,
    alibi_slopes=None,
)

if not torch.allclose(ref_out, test_out):
    print(f"Failed: {torch.abs(ref_out - test_out).max()}") 
else:
    print(f"Passed")

# -------------------------------------------------------------------------------------------------

print("-"*100)
print("Testing with no cache_seqlens...\n")
k_cache_ = k_cache.clone()
v_cache_ = v_cache.clone()
k_cache_ = k_cache_[:,:cache_seqlens,:].contiguous()
v_cache_ = v_cache_[:,:cache_seqlens,:].contiguous()
cache_seqlens_ = None

print(f"Test params: {q.shape=}, {k_cache_.shape=}, {v_cache_.shape=} {full_cache_len=} {cache_seqlens_=} {rotary_interleaved=} {dtype=}")
test_out = flash_attn_with_kvcache(
    q,
    k_cache_,
    v_cache_,
    cache_seqlens=cache_seqlens_,
    softmax_scale=None,
    causal=True,
    rotary_interleaved=rotary_interleaved,
    alibi_slopes=None,
)

if not torch.allclose(ref_out, test_out):
    print(f"Failed: {torch.abs(ref_out - test_out).max()}") 
else:
    print(f"Passed")
print()
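As a ground-truth check (not part of the script above), a plain PyTorch attention over just the first cache_seqlens tokens can be compared against ref_out. A minimal sketch, reusing q, k_cache, v_cache, cache_seqlens, and ref_out from the script; naive_attention is a hypothetical helper, not part of flash_attn:

import math

def naive_attention(q, k, v):
    # q: [bs, qlen, nheads, d]; k, v: [bs, seqlen, nheads, d]
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", probs, v.float()).to(q.dtype)

# With qlen == 1, the single query attends to all cache_seqlens cached keys,
# so a full softmax over that slice is the expected result even with causal=True.
naive_out = naive_attention(q, k_cache[:, :cache_seqlens], v_cache[:, :cache_seqlens])
print(f"naive vs ref_out max abs diff: {torch.abs(naive_out - ref_out).max()}")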