@@ -234,6 +234,11 @@ def generate_chunk_with_cache(
234234 return output , denoised_timestep_from , denoised_timestep_to
235235
236236 def _recache_after_switch (self , output , current_start_frame , new_conditional_dict , local_start_frame = None , switch_recache_frames = None ):
237+ for block_idx in range (self .num_transformer_blocks ):
238+ cache = self .kv_cache1 [block_idx ]
239+ # update local end index pointer so that we rebuild the cache from the beginning
240+ cache ["local_end_index" ].fill_ (cache ["local_end_index" ].item () - self .frame_seq_length * self .slice_last_frames )
241+
237242 # reset cross-attention cache
238243 for blk in self .crossattn_cache :
239244 blk ["k" ].zero_ ()
@@ -244,19 +249,19 @@ def _recache_after_switch(self, output, current_start_frame, new_conditional_dic
244249 return
245250
246251 if switch_recache_frames is not None :
247- frames_to_recache = torch .cat ([switch_recache_frames , output ], dim = 1 )[:, - 21 :, ...]
252+ frames_to_recache = torch .cat ([switch_recache_frames , output ], dim = 1 )[:, - self . local_attn_size :, ...]
248253 num_recache_frames = frames_to_recache .shape [1 ]
249254 if DEBUG and (not dist .is_initialized () or dist .get_rank () == 0 ):
250255 print (f"[SeqTrain-DMDSwitch] Using external switch_recache_frames (previous_frames): { frames_to_recache .shape } " )
251256 else :
252257 # Determine how to fetch frames based on whether local_start_frame is provided
253258 if local_start_frame is not None :
254259 # Chunk mode: output is the current chunk's output; use relative coordinates
255- num_recache_frames = min (local_start_frame , 21 )
260+ num_recache_frames = min (local_start_frame , self . local_attn_size )
256261 frames_to_recache = output [:, - num_recache_frames :]
257262 else :
258263 # Full sequence mode: output is the complete sequence; use absolute coordinates
259- num_recache_frames = min (current_start_frame , 21 )
264+ num_recache_frames = min (current_start_frame , self . local_attn_size )
260265 frames_to_recache = output [:, - num_recache_frames :]
261266
262267 batch_size , num_recache_frames , c , h , w = frames_to_recache .shape
0 commit comments