-
Notifications
You must be signed in to change notification settings - Fork 337
Open
Description
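In `Sana/diffusion/model/nets/sana_blocks.py` (lines 547–559 at commit `95c9e7c`), the linear-attention KV cache is overwritten with the current chunk's *local* `vk` / `k_sum` **before** the cached history (`cusum_vk` / `cumsum_k_sum`) is added to them.

As a result, the cache never holds a running total. It only holds the most recent chunk's contribution. For example, when chunk 3 is processed, its output is computed from chunk 3 + chunk 2 only, and chunk 1 (and every earlier chunk) is silently dropped.

Current code: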
| # Use internal cache with the same logic as before | |
| if kv_cache is not None: | |
| cusum_vk, cumsum_k_sum = kv_cache[0], kv_cache[1] | |
| if save_kv_cache: | |
| kv_cache[0] = vk.detach().clone() | |
| kv_cache[1] = k_sum.detach().clone() | |
| if cusum_vk is not None and cumsum_k_sum is not None: | |
| # Add accumulated cache from previous chunks | |
| vk = vk + cusum_vk | |
| k_sum = k_sum + cumsum_k_sum |
# ... inside forward ...
# 1. Calculate Local Update
vk = torch.matmul(v, k_rotated.transpose(-1, -2))
k_sum = k.sum(dim=-1, keepdim=True).transpose(-2, -1)
if kv_cache is not None:
cusum_vk, cumsum_k_sum = kv_cache[0], kv_cache[1]
# 2. Add History FIRST (Update the Running Total)
if cusum_vk is not None and cumsum_k_sum is not None:
vk = vk + cusum_vk
k_sum = k_sum + cumsum_k_sum
# 3. Save the NEW TOTAL to Cache
if save_kv_cache:
kv_cache[0] = vk.detach().clone()
kv_cache[1] = k_sum.detach().clone()
# 4. Compute Output
z = 1 / (k_sum @ q + self.eps)
out = torch.matmul(vk, q_rotated)
Reactions are currently unavailable
Metadata
Assignees
Labels
No labels