# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
from ..cache_utils import CacheMixin
-from ..embeddings import get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
logger = get_logger(__name__)  # pylint: disable=invalid-name


-class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
-    def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+# class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
+#     def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+#         super().__init__()
+
+#         self.patch_size = patch_size
+#         self.patch_size_t = patch_size_t
+#         self.rope_dim = rope_dim
+#         self.theta = theta
+
+#     def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
+#         frame_indices = frame_indices.unbind(0)
+#         # This is from the original code. We don't call _forward for each batch index because we know that
+#         # each batch has the same frame indices. However, the frame indices may not always be the same for
+#         # every item in a batch (such as in training). We cannot use the original implementation because
+#         # our `apply_rotary_emb` function broadcasts across the batch dim.
+#         # freqs = [self._forward(f, height, width, device) for f in frame_indices]
+#         # freqs_cos, freqs_sin = zip(*freqs)
+#         # freqs_cos = torch.stack(freqs_cos, dim=0)  # [B, W * H * T, D / 2]
+#         # freqs_sin = torch.stack(freqs_sin, dim=0)  # [B, W * H * T, D / 2]
+#         # return freqs_cos, freqs_sin
+#         return self._forward(frame_indices[0], height, width, device)
+
+#     def _forward(self, frame_indices, height, width, device):
+#         height = height // self.patch_size
+#         width = width // self.patch_size
+#         grid = torch.meshgrid(
+#             frame_indices.to(device=device, dtype=torch.float32),
+#             torch.arange(0, height, device=device, dtype=torch.float32),
+#             torch.arange(0, width, device=device, dtype=torch.float32),
+#             indexing="ij",
+#         )  # 3 * [W, H, T]
+#         grid = torch.stack(grid, dim=0)  # [3, W, H, T]
+
+#         freqs = []
+#         for i in range(3):
+#             freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+#             freqs.append(freq)
+
+#         freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # (W * H * T, D / 2)
+#         freqs_sin = torch.cat([f[1] for f in freqs], dim=1)  # (W * H * T, D / 2)
+
+#         return freqs_cos, freqs_sin
+
+
+class HunyuanVideoRotaryPosEmbed(nn.Module):
+    def __init__(self, rope_dim, theta):
        super().__init__()
-
-        self.patch_size = patch_size
-        self.patch_size_t = patch_size_t
-        self.rope_dim = rope_dim
+        self.DT, self.DY, self.DX = rope_dim
        self.theta = theta

-    def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
-        frame_indices = frame_indices.unbind(0)
-        # This is from the original code. We don't call _forward for each batch index because we know that
-        # each batch has the same frame indices. However, the frame indices may not always be the same for
-        # every item in a batch (such as in training). We cannot use the original implementation because
-        # our `apply_rotary_emb` function broadcasts across the batch dim.
-        # freqs = [self._forward(f, height, width, device) for f in frame_indices]
-        # freqs_cos, freqs_sin = zip(*freqs)
-        # freqs_cos = torch.stack(freqs_cos, dim=0)  # [B, W * H * T, D / 2]
-        # freqs_sin = torch.stack(freqs_sin, dim=0)  # [B, W * H * T, D / 2]
-        # return freqs_cos, freqs_sin
-        return self._forward(frame_indices[0], height, width, device)
-
-    def _forward(self, frame_indices, height, width, device):
-        height = height // self.patch_size
-        width = width // self.patch_size
-        grid = torch.meshgrid(
+    @torch.no_grad()
+    def get_frequency(self, dim, pos):
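+        # Standard 1D RoPE table for one axis: `dim // 2` inverse-frequency bands (theta^(-2i/dim))
+        # are combined with every position in `pos`, then each band is duplicated with
+        # repeat_interleave so consecutive channel pairs share a frequency. The returned cos/sin
+        # tables each have shape [dim, T, H, W].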
+        T, H, W = pos.shape
+        freqs = 1.0 / (
+            self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim)
+        )
+        freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
+        return freqs.cos(), freqs.sin()
+
+    @torch.no_grad()
+    def forward_inner(self, frame_indices, height, width, device):
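+        # Build a (frame, row, column) grid, compute per-axis frequency tables, and stack them
+        # channel-wise as [cos_t | cos_y | cos_x | sin_t | sin_y | sin_x], giving a tensor of shape
+        # [2 * (DT + DY + DX), T, height // 2, width // 2]. The halving below appears to hard-code
+        # the 2x2 spatial patch size (see the TODO).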
+        # TODO(aryan)
+        height = height // 2
+        width = width // 2
+        GT, GY, GX = torch.meshgrid(
            frame_indices.to(device=device, dtype=torch.float32),
            torch.arange(0, height, device=device, dtype=torch.float32),
            torch.arange(0, width, device=device, dtype=torch.float32),
            indexing="ij",
-        )  # 3 * [W, H, T]
-        grid = torch.stack(grid, dim=0)  # [3, W, H, T]
+        )

-        freqs = []
-        for i in range(3):
-            freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
-            freqs.append(freq)
+        FCT, FST = self.get_frequency(self.DT, GT)
+        FCY, FSY = self.get_frequency(self.DY, GY)
+        FCX, FSX = self.get_frequency(self.DX, GX)

-        freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # (W * H * T, D / 2)
-        freqs_sin = torch.cat([f[1] for f in freqs], dim=1)  # (W * H * T, D / 2)
+        result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)

-        return freqs_cos, freqs_sin
+        return result.to(device)
+
+    @torch.no_grad()
+    def forward(self, frame_indices, height, width, device):
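+        # RoPE is computed once for the first batch item and given a leading batch dim, on the
+        # assumption that every item in the batch shares the same frame indices; the per-item
+        # path is kept below (commented out) in case that assumption is relaxed.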
+        return self.forward_inner(frame_indices[0], height, width, device).unsqueeze(0)
+        # frame_indices = frame_indices.unbind(0)
+        # results = [self.forward_inner(f, height, width, device) for f in frame_indices]
+        # results = torch.stack(results, dim=0)
+        # return results


class FramepackClipVisionProjection(nn.Module):
@@ -173,7 +216,8 @@ def __init__(
        )

        # 2. RoPE
-        self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+        # self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+        self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)

        # 3. Dual stream transformer blocks
        self.transformer_blocks = nn.ModuleList(
@@ -280,10 +324,14 @@ def forward(
        effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
        effective_sequence_length = latent_sequence_length + effective_condition_sequence_length

-        for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i]] = True
-        # [B, 1, 1, N], for broadcasting across attention heads
-        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+        if batch_size == 1:
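+            # Single-sample fast path: trim the padded text tokens instead of masking them, so
+            # attention can run without an explicit mask.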
+            encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
+            attention_mask = None
+        else:
+            for i in range(batch_size):
+                attention_mask[i, : effective_sequence_length[i]] = True
+            # [B, 1, 1, N], for broadcasting across attention heads
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.transformer_blocks:
@@ -345,8 +393,7 @@ def _pack_history_states(
        image_rotary_emb = self.rope(
            frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
        )
-        image_rotary_emb = list(image_rotary_emb)  # convert tuple to list for in-place modification
-        pph, ppw = height // self.config.patch_size, width // self.config.patch_size
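+        # [1, C, T, H', W'] -> [1, T * H' * W', C]: flatten the frequency grid into a per-token
+        # sequence so it can be concatenated along the token dimension below.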
+        image_rotary_emb = image_rotary_emb.flatten(2).transpose(1, 2)

        latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
            latents_clean, latents_history_2x, latents_history_4x
@@ -358,34 +405,93 @@ def _pack_history_states(
            image_rotary_emb_clean = self.rope(
                frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device
            )
-            image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
-            image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
+            image_rotary_emb_clean = image_rotary_emb_clean.flatten(2).transpose(1, 2)
+            image_rotary_emb = torch.cat([image_rotary_emb_clean, image_rotary_emb], dim=1)

        if latents_history_2x is not None and indices_latents_history_2x is not None:
            hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)

            image_rotary_emb_history_2x = self.rope(
                frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device
            )
-            image_rotary_emb_history_2x = self._pad_rotary_emb(
-                image_rotary_emb_history_2x, indices_latents_history_2x.size(1), pph, ppw, (2, 2, 2)
-            )
-            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
-            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
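+            # The 2x history latents cover a coarser grid, so the RoPE grid is padded and
+            # center-downsampled by the same (2, 2, 2) factor before being flattened and prepended
+            # along the token dimension.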
+            image_rotary_emb_history_2x = _pad_for_3d_conv(image_rotary_emb_history_2x, (2, 2, 2))
+            image_rotary_emb_history_2x = _center_down_sample_3d(image_rotary_emb_history_2x, (2, 2, 2))
+            image_rotary_emb_history_2x = image_rotary_emb_history_2x.flatten(2).transpose(1, 2)
+            image_rotary_emb = torch.cat([image_rotary_emb_history_2x, image_rotary_emb], dim=1)

        if latents_history_4x is not None and indices_latents_history_4x is not None:
            hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)

            image_rotary_emb_history_4x = self.rope(
                frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device
            )
-            image_rotary_emb_history_4x = self._pad_rotary_emb(
-                image_rotary_emb_history_4x, indices_latents_history_4x.size(1), pph, ppw, (4, 4, 4)
-            )
-            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
-            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
-
-        return hidden_states, image_rotary_emb
+            image_rotary_emb_history_4x = _pad_for_3d_conv(image_rotary_emb_history_4x, (4, 4, 4))
+            image_rotary_emb_history_4x = _center_down_sample_3d(image_rotary_emb_history_4x, (4, 4, 4))
+            image_rotary_emb_history_4x = image_rotary_emb_history_4x.flatten(2).transpose(1, 2)
+            image_rotary_emb = torch.cat([image_rotary_emb_history_4x, image_rotary_emb], dim=1)
+
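+        # Drop the singleton batch dim and split the channel axis into its cos half
+        # ([cos_t | cos_y | cos_x]) and sin half ([sin_t | sin_y | sin_x]), the (cos, sin) pair
+        # consumed by the attention blocks.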
+        return hidden_states, image_rotary_emb.squeeze(0).chunk(2, dim=-1)
+
+    # def _pack_history_states(
+    #     self,
+    #     hidden_states: torch.Tensor,
+    #     indices_latents: torch.Tensor,
+    #     latents_clean: Optional[torch.Tensor] = None,
+    #     latents_history_2x: Optional[torch.Tensor] = None,
+    #     latents_history_4x: Optional[torch.Tensor] = None,
+    #     indices_latents_clean: Optional[torch.Tensor] = None,
+    #     indices_latents_history_2x: Optional[torch.Tensor] = None,
+    #     indices_latents_history_4x: Optional[torch.Tensor] = None,
+    # ):
+    #     batch_size, num_channels, num_frames, height, width = hidden_states.shape
+    #     if indices_latents is None:
+    #         indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
+
+    #     hidden_states = self.x_embedder(hidden_states)
+    #     image_rotary_emb = self.rope(
+    #         frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
+    #     )
+    #     image_rotary_emb = list(image_rotary_emb)  # convert tuple to list for in-place modification
+    #     pph, ppw = height // self.config.patch_size, width // self.config.patch_size
+
+    #     latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
+    #         latents_clean, latents_history_2x, latents_history_4x
+    #     )
+
+    #     if latents_clean is not None:
+    #         hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
+
+    #         image_rotary_emb_clean = self.rope(
+    #             frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device
+    #         )
+    #         image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
+    #         image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
+
+    #     if latents_history_2x is not None and indices_latents_history_2x is not None:
+    #         hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
+
+    #         image_rotary_emb_history_2x = self.rope(
+    #             frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device
+    #         )
+    #         image_rotary_emb_history_2x = self._pad_rotary_emb(
+    #             image_rotary_emb_history_2x, indices_latents_history_2x.size(1), pph, ppw, (2, 2, 2)
+    #         )
+    #         image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
+    #         image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
+
+    #     if latents_history_4x is not None and indices_latents_history_4x is not None:
+    #         hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
+
+    #         image_rotary_emb_history_4x = self.rope(
+    #             frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device
+    #         )
+    #         image_rotary_emb_history_4x = self._pad_rotary_emb(
+    #             image_rotary_emb_history_4x, indices_latents_history_4x.size(1), pph, ppw, (4, 4, 4)
+    #         )
+    #         image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
+    #         image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
+
+    #     return hidden_states, image_rotary_emb

    def _pad_rotary_emb(
        self,