Skip to content

Commit f5597f1

Browse files
author
Haithem Turki
committed
more tweaks
1 parent 4f69408 commit f5597f1

File tree

4 files changed: +18 additions, −12 deletions

pipeline/streaming_switch_training.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ def generate_chunk_with_cache(
234234
return output, denoised_timestep_from, denoised_timestep_to
235235

236236
def _recache_after_switch(self, output, current_start_frame, new_conditional_dict, local_start_frame=None, switch_recache_frames=None):
237+
for block_idx in range(self.num_transformer_blocks):
238+
cache = self.kv_cache1[block_idx]
239+
# update local end index pointer so that we rebuild the cache from the beginning
240+
cache["local_end_index"].fill_(cache["local_end_index"].item() - self.frame_seq_length * self.slice_last_frames)
241+
237242
# reset cross-attention cache
238243
for blk in self.crossattn_cache:
239244
blk["k"].zero_()
@@ -244,19 +249,19 @@ def _recache_after_switch(self, output, current_start_frame, new_conditional_dic
244249
return
245250

246251
if switch_recache_frames is not None:
247-
frames_to_recache = torch.cat([switch_recache_frames, output], dim=1)[:, -21:, ...]
252+
frames_to_recache = torch.cat([switch_recache_frames, output], dim=1)[:, -self.local_attn_size:, ...]
248253
num_recache_frames = frames_to_recache.shape[1]
249254
if DEBUG and (not dist.is_initialized() or dist.get_rank() == 0):
250255
print(f"[SeqTrain-DMDSwitch] Using external switch_recache_frames (previous_frames): {frames_to_recache.shape}")
251256
else:
252257
# Determine how to fetch frames based on whether local_start_frame is provided
253258
if local_start_frame is not None:
254259
# Chunk mode: output is the current chunk's output; use relative coordinates
255-
num_recache_frames = min(local_start_frame, 21)
260+
num_recache_frames = min(local_start_frame, self.local_attn_size)
256261
frames_to_recache = output[:, -num_recache_frames:]
257262
else:
258263
# Full sequence mode: output is the complete sequence; use absolute coordinates
259-
num_recache_frames = min(current_start_frame, 21)
264+
num_recache_frames = min(current_start_frame, self.local_attn_size)
260265
frames_to_recache = output[:, -num_recache_frames:]
261266

262267
batch_size, num_recache_frames, c, h, w = frames_to_recache.shape

pipeline/streaming_training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(self,
4848
self.global_sink = kwargs.get("global_sink", False)
4949

5050
slice_last_frames: int = int(kwargs.get("slice_last_frames", 21))
51+
self.slice_last_frames = slice_last_frames
5152
self.kv_cache_size = (self.local_attn_size + slice_last_frames) * self.frame_seq_length
5253
if DEBUG:
5354
print(f"[KV policy] local_attn_size={self.local_attn_size} slice_last_frames={slice_last_frames} -> kv_frames={self.kv_cache_size}")

pipeline/switch_causal_inference.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,17 @@ def _recache_after_switch(self, output, current_start_frame, new_conditional_dic
6363
local_attn_size=self.local_attn_size
6464
)
6565

66-
context_timestep = torch.ones([batch_size, recompute_frames],
66+
context_timestep = torch.ones([batch_size, num_recache_frames],
6767
device=device, dtype=torch.int64) * self.args.context_noise
6868

6969
with torch.no_grad():
7070
self.generator(
71-
noisy_image_or_video=frames_to_recompute,
71+
noisy_image_or_video=frames_to_recache,
7272
conditional_dict=new_conditional_dict,
7373
timestep=context_timestep,
7474
kv_cache=self.kv_cache1,
7575
crossattn_cache=self.crossattn_cache,
76-
current_start=recompute_start_frame * self.frame_seq_length,
76+
current_start=recache_start_frame * self.frame_seq_length,
7777
block_mask=block_mask,
7878
)
7979

@@ -166,7 +166,7 @@ def inference(
166166
else:
167167
cond_in_use = cond_second if using_second else cond_first
168168

169-
noisy_input = noise[:, current_start_frame - num_input_frames : current_start_frame + current_num_frames - num_input_frames]
169+
noisy_input = noise[:, current_start_frame - (1 if initial_latent is not None else 0) : current_start_frame + current_num_frames - (1 if initial_latent is not None else 0)]
170170

171171
# Spatial denoising loop (same as parent but uses cond_in_use)
172172
for index, current_timestep in enumerate(self.denoising_step_list):

wan/modules/causal_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def qkv_fn(x):
227227

228228
# Compute cache update parameters without modifying kv_cache directly
229229
cache_update_info = None
230-
is_recompute = block_mask is not None
230+
is_recompute = current_end <= kv_cache["global_end_index"].item() and current_start > 0
231231
if self.local_attn_size != -1 and (current_end > kv_cache["global_end_index"].item()) and (
232232
num_new_tokens + kv_cache["local_end_index"].item() > kv_cache_size):
233233
# Calculate the number of new tokens added in this step
@@ -257,8 +257,8 @@ def qkv_fn(x):
257257
temp_v[:, sink_tokens + num_evicted_tokens:sink_tokens + num_evicted_tokens + num_rolled_tokens].clone()
258258

259259
# Insert new key/value into the temporary cache
260-
# Protect sink_tokens only during recomputation; regular forward generation allows writing into the initial sink region
261-
write_start_index = max(local_start_index, sink_tokens) if (is_recompute and kv_cache.get("global_sink", False)) else local_start_index
260+
# Protect sink_tokens only during recaching; regular forward generation allows writing into the initial sink region
261+
write_start_index = max(local_start_index, sink_tokens) if ((block_mask is not None) and kv_cache.get("global_sink", False)) else local_start_index
262262
roped_offset = max(0, write_start_index - local_start_index)
263263
write_len = max(0, local_end_index - write_start_index)
264264
if write_len > 0:
@@ -291,8 +291,8 @@ def qkv_fn(x):
291291
# Construct full k, v for attention computation (without modifying the original cache)
292292
temp_k = kv_cache["k"].clone()
293293
temp_v = kv_cache["v"].clone()
294-
# Protect sink_tokens only during recomputation; regular forward generation allows writing into the initial sink region
295-
write_start_index = max(local_start_index, sink_tokens) if (is_recompute and kv_cache.get("global_sink", False)) else local_start_index
294+
# Protect sink_tokens only during recaching; regular forward generation allows writing into the initial sink region
295+
write_start_index = max(local_start_index, sink_tokens) if ((block_mask is not None) and kv_cache.get("global_sink", False)) else local_start_index
296296
roped_offset = max(0, write_start_index - local_start_index)
297297
write_len = max(0, local_end_index - write_start_index)
298298
if write_len > 0:

Comments (0)