There is a discrepancy in outputs when using `flash_attn_with_kvcache` under the following conditions:

1. Full `k_cache`/`v_cache` with `cache_seqlens` vs. sliced `k_cache`/`v_cache` with the same `cache_seqlens`
2. Full `k_cache`/`v_cache` with `cache_seqlens` vs. `k_cache`/`v_cache` sliced to `cache_seqlens`
My understanding is that `cache_seqlens` indicates the length of the KV cache to attend to, so the above cases should give identical outputs, since both attend to the same context length:

- In the first case, the only difference is the length of the KV cache being passed into the function (same `cache_seqlens`); both cache lengths are longer than `cache_seqlens`.
- In the second case, the KV cache is sliced to `cache_seqlens` and no `cache_seqlens` arg is passed to the function.

Both give the same abs max diff.
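The expected equivalence — that attention over the first `cache_seqlens` cached entries should not depend on how much of the buffer sits beyond them — can be checked against a plain NumPy reference. This is a sketch of the assumed semantics of `cache_seqlens` (masking out entries past the valid length), not of the flash_attn kernel itself:

```python
import numpy as np

def attn_ref(q, k, v, seqlen=None):
    # Reference single-query attention; `seqlen` mirrors cache_seqlens by
    # masking out cache entries at positions >= seqlen before the softmax.
    s = (q @ k.T) / np.sqrt(q.shape[-1])
    if seqlen is not None:
        s[..., seqlen:] = -np.inf
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

rng = np.random.default_rng(0)
L, S, D = 8, 32, 16  # valid length, total cache buffer, head dim
q = rng.standard_normal((1, D))
k = rng.standard_normal((S, D))
v = rng.standard_normal((S, D))

# Full buffer + seqlen mask vs. buffer sliced to the valid length:
out_full = attn_ref(q, k, v, seqlen=L)
out_sliced = attn_ref(q, k[:L], v[:L])
assert np.allclose(out_full, out_sliced)
```

In this reference the masked entries contribute exactly zero weight, so the two variants agree to numerical precision — which is why the discrepancy reported above is surprising.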
Moreover, if I pass in increasing slices of the KV cache (starting from `cache_seqlens`) but with the same `cache_seqlens` arg, the discrepancies are the same up to a certain threshold, after which the results agree.

E.g., for a `{k,v}_cache` of shape `[1, 262144, 32, 128]` and `cache_seqlens` 40960: if I pass in increasing slices `{k,v}_cache[:, :ctx_len, :]` where `ctx_len = 40960, ..., 106496, 172032, 262144` and compare against a reference output computed with the full KV cache and `cache_seqlens` 40960, the outputs differ for `ctx_len < 172032`.
Tested on an H100 with: torch 2.5, flash_attn 2.4.2, python 3.10.2
@tridao
Minimal script to reproduce:
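(The original script was not captured; below is a hedged reconstruction from the shapes and parameters described above. The dtype, seed, and single-token query are assumptions; it requires a CUDA GPU with flash_attn installed.)

```python
# Hedged reconstruction of the reproduction described above: shapes and
# cache_seqlens follow the report; dtype, seeding, and the single-token
# query are assumptions. Requires CUDA and the flash_attn package.
import torch
from flash_attn import flash_attn_with_kvcache

torch.manual_seed(0)
B, S, H, D = 1, 262144, 32, 128  # shapes from the report
cache_seqlens = 40960

q = torch.randn(B, 1, H, D, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(B, S, H, D, dtype=torch.bfloat16, device="cuda")
v_cache = torch.randn(B, S, H, D, dtype=torch.bfloat16, device="cuda")
seqlens = torch.full((B,), cache_seqlens, dtype=torch.int32, device="cuda")

# Reference: full cache with an explicit cache_seqlens.
ref = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=seqlens)

# Increasing slices of the cache with the same cache_seqlens arg --
# these should all match ref, but reportedly differ for ctx_len < 172032.
for ctx_len in (40960, 106496, 172032, 262144):
    out = flash_attn_with_kvcache(
        q,
        k_cache[:, :ctx_len].contiguous(),
        v_cache[:, :ctx_len].contiguous(),
        cache_seqlens=seqlens,
    )
    print(ctx_len, (out - ref).abs().max().item())

# Case 2: cache sliced exactly to cache_seqlens, no cache_seqlens arg.
out_sliced = flash_attn_with_kvcache(
    q,
    k_cache[:, :cache_seqlens].contiguous(),
    v_cache[:, :cache_seqlens].contiguous(),
)
print((out_sliced - ref).abs().max().item())
```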