Potential Incorrect KV Cache Update Logic? #345

@lukezhuz

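In the snippet below, the KV cache is overwritten with the current chunk's `vk` and `k_sum` before the history accumulated from previous chunks is added:

```python
# Use internal cache with the same logic as before
if kv_cache is not None:
    cusum_vk, cumsum_k_sum = kv_cache[0], kv_cache[1]
    if save_kv_cache:
        kv_cache[0] = vk.detach().clone()
        kv_cache[1] = k_sum.detach().clone()
    if cusum_vk is not None and cumsum_k_sum is not None:
        # Add accumulated cache from previous chunks
        vk = vk + cusum_vk
        k_sum = k_sum + cumsum_k_sum
```

If the same `kv_cache` is passed to every chunk, the cache only ever holds the most recent chunk's contribution. From the third chunk onward, everything before the previous chunk is therefore dropped. Shouldn't we update the KV cache after adding the accumulated `vk` and `k_sum`? What I mean is that the corrected code may look like this: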

```python
# ... inside forward ...

# 1. Calculate Local Update
vk = torch.matmul(v, k_rotated.transpose(-1, -2))
k_sum = k.sum(dim=-1, keepdim=True).transpose(-2, -1)

if kv_cache is not None:
    cusum_vk, cumsum_k_sum = kv_cache[0], kv_cache[1]

    # 2. Add History FIRST (Update the Running Total)
    if cusum_vk is not None and cumsum_k_sum is not None:
        vk = vk + cusum_vk
        k_sum = k_sum + cumsum_k_sum

    # 3. Save the NEW TOTAL to Cache
    if save_kv_cache:
        kv_cache[0] = vk.detach().clone()
        kv_cache[1] = k_sum.detach().clone()

# 4. Compute Output
z = 1 / (k_sum @ q + self.eps)
out = torch.matmul(vk, q_rotated)
```
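
To illustrate why the order matters, here is a toy model (not the repository's code). Each chunk's local state, standing in for `vk` / `k_sum`, is a single float, and `kv_cache` is a one-slot list. It assumes the same cache object is passed to every chunk and that nothing else accumulates it. The function names are hypothetical.

```python
# Toy model of the two cache-update orders (hypothetical names, floats instead of tensors).

def run_save_then_add(local_states):
    """Current order: cache the local state first, then add the cached history."""
    kv_cache = [None]
    seen = []
    for local in local_states:
        cached = kv_cache[0]
        kv_cache[0] = local  # caches only this chunk's contribution
        total = local if cached is None else local + cached
        seen.append(total)
    return seen


def run_add_then_save(local_states):
    """Proposed order: add the cached history first, then cache the running total."""
    kv_cache = [None]
    seen = []
    for local in local_states:
        cached = kv_cache[0]
        total = local if cached is None else local + cached
        kv_cache[0] = total  # caches the running total over all chunks so far
        seen.append(total)
    return seen


chunks = [1.0, 10.0, 100.0]
print(run_save_then_add(chunks))  # [1.0, 11.0, 110.0]
print(run_add_then_save(chunks))  # [1.0, 11.0, 111.0]
```

With the proposed order, the last value equals `sum(chunks)`, the state you would get by processing all chunks at once, because `vk` and `k_sum` are sums over tokens. The current order only combines the last two chunks.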