Commit 88e8dd3

debug

1 parent 6e1b557

File tree: 2 files changed (+169, -67 lines)


src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py

Lines changed: 161 additions & 55 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -22,7 +22,6 @@
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
 from ..cache_utils import CacheMixin
-from ..embeddings import get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous
@@ -38,48 +37,92 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
-    def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+# class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
+#     def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+#         super().__init__()
+
+#         self.patch_size = patch_size
+#         self.patch_size_t = patch_size_t
+#         self.rope_dim = rope_dim
+#         self.theta = theta
+
+#     def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
+#         frame_indices = frame_indices.unbind(0)
+#         # This is from the original code. We don't call _forward for each batch index because we know that
+#         # every item in the batch has the same frame indices. However, the frame indices may not always be
+#         # the same for every item in a batch (such as in training). We cannot use the original
+#         # implementation because our `apply_rotary_emb` function broadcasts across the batch dim.
+#         # freqs = [self._forward(f, height, width, device) for f in frame_indices]
+#         # freqs_cos, freqs_sin = zip(*freqs)
+#         # freqs_cos = torch.stack(freqs_cos, dim=0)  # [B, W * H * T, D / 2]
+#         # freqs_sin = torch.stack(freqs_sin, dim=0)  # [B, W * H * T, D / 2]
+#         # return freqs_cos, freqs_sin
+#         return self._forward(frame_indices[0], height, width, device)
+
+#     def _forward(self, frame_indices, height, width, device):
+#         height = height // self.patch_size
+#         width = width // self.patch_size
+#         grid = torch.meshgrid(
+#             frame_indices.to(device=device, dtype=torch.float32),
+#             torch.arange(0, height, device=device, dtype=torch.float32),
+#             torch.arange(0, width, device=device, dtype=torch.float32),
+#             indexing="ij",
+#         )  # 3 * [W, H, T]
+#         grid = torch.stack(grid, dim=0)  # [3, W, H, T]
+
+#         freqs = []
+#         for i in range(3):
+#             freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+#             freqs.append(freq)
+
+#         freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # (W * H * T, D / 2)
+#         freqs_sin = torch.cat([f[1] for f in freqs], dim=1)  # (W * H * T, D / 2)
+
+#         return freqs_cos, freqs_sin
+
+
+class HunyuanVideoRotaryPosEmbed(nn.Module):
+    def __init__(self, rope_dim, theta):
         super().__init__()
-
-        self.patch_size = patch_size
-        self.patch_size_t = patch_size_t
-        self.rope_dim = rope_dim
+        self.DT, self.DY, self.DX = rope_dim
         self.theta = theta
 
-    def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
-        frame_indices = frame_indices.unbind(0)
-        # This is from the original code. We don't call _forward for each batch index because we know that
-        # every item in the batch has the same frame indices. However, the frame indices may not always be
-        # the same for every item in a batch (such as in training). We cannot use the original
-        # implementation because our `apply_rotary_emb` function broadcasts across the batch dim.
-        # freqs = [self._forward(f, height, width, device) for f in frame_indices]
-        # freqs_cos, freqs_sin = zip(*freqs)
-        # freqs_cos = torch.stack(freqs_cos, dim=0)  # [B, W * H * T, D / 2]
-        # freqs_sin = torch.stack(freqs_sin, dim=0)  # [B, W * H * T, D / 2]
-        # return freqs_cos, freqs_sin
-        return self._forward(frame_indices[0], height, width, device)
-
-    def _forward(self, frame_indices, height, width, device):
-        height = height // self.patch_size
-        width = width // self.patch_size
-        grid = torch.meshgrid(
+    @torch.no_grad()
+    def get_frequency(self, dim, pos):
+        T, H, W = pos.shape
+        freqs = 1.0 / (
+            self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim)
+        )
+        freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
+        return freqs.cos(), freqs.sin()
+
+    @torch.no_grad()
+    def forward_inner(self, frame_indices, height, width, device):
+        # TODO(aryan)
+        height = height // 2
+        width = width // 2
+        GT, GY, GX = torch.meshgrid(
             frame_indices.to(device=device, dtype=torch.float32),
             torch.arange(0, height, device=device, dtype=torch.float32),
             torch.arange(0, width, device=device, dtype=torch.float32),
             indexing="ij",
-        )  # 3 * [W, H, T]
-        grid = torch.stack(grid, dim=0)  # [3, W, H, T]
+        )
 
-        freqs = []
-        for i in range(3):
-            freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
-            freqs.append(freq)
+        FCT, FST = self.get_frequency(self.DT, GT)
+        FCY, FSY = self.get_frequency(self.DY, GY)
+        FCX, FSX = self.get_frequency(self.DX, GX)
 
-        freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # (W * H * T, D / 2)
-        freqs_sin = torch.cat([f[1] for f in freqs], dim=1)  # (W * H * T, D / 2)
+        result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
 
-        return freqs_cos, freqs_sin
+        return result.to(device)
+
+    @torch.no_grad()
+    def forward(self, frame_indices, height, width, device):
+        return self.forward_inner(frame_indices[0], height, width, device).unsqueeze(0)
+        # frame_indices = frame_indices.unbind(0)
+        # results = [self.forward_inner(f, height, width, device) for f in frame_indices]
+        # results = torch.stack(results, dim=0)
+        # return results
 
 
 class FramepackClipVisionProjection(nn.Module):
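Note on the new rope math: `get_frequency` builds the standard RoPE inverse-frequency table per axis, takes the outer product with the flattened (T, H, W) position grid, and duplicates each frequency row so the tables line up with interleaved channel pairs. A minimal standalone sketch (the grid sizes are made up; the axis split assumes HunyuanVideo's `rope_axes_dim = (16, 56, 56)`):

import torch

theta = 256.0
DT, DY, DX = 16, 56, 56  # assumed axis split; DT + DY + DX = attention head dim
T, H, W = 3, 4, 4        # illustrative frame count and half-resolution spatial grid

def get_frequency(dim, pos):
    # one inverse frequency per rotated channel pair: 1 / theta^(2i / dim)
    inv = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # outer product with flattened positions, restore the grid, then duplicate
    # each row so cos/sin line up with interleaved (even, odd) channels
    freqs = torch.outer(inv, pos.reshape(-1)).unflatten(-1, pos.shape).repeat_interleave(2, dim=0)
    return freqs.cos(), freqs.sin()

GT, GY, GX = torch.meshgrid(
    torch.arange(T, dtype=torch.float32),
    torch.arange(H, dtype=torch.float32),
    torch.arange(W, dtype=torch.float32),
    indexing="ij",
)
FCT, FST = get_frequency(DT, GT)
FCY, FSY = get_frequency(DY, GY)
FCX, FSX = get_frequency(DX, GX)

# all cos tables first, then all sin tables, along the channel dim
result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
assert result.shape == (2 * (DT + DY + DX), T, H, W)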
@@ -173,7 +216,8 @@ def __init__(
         )
 
         # 2. RoPE
-        self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+        # self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+        self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
 
         # 3. Dual stream transformer blocks
         self.transformer_blocks = nn.ModuleList(
@@ -280,10 +324,14 @@ def forward(
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
 
-        for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i]] = True
-        # [B, 1, 1, N], for broadcasting across attention heads
-        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+        if batch_size == 1:
+            encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
+            attention_mask = None
+        else:
+            for i in range(batch_size):
+                attention_mask[i, : effective_sequence_length[i]] = True
+            # [B, 1, 1, N], for broadcasting across attention heads
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for block in self.transformer_blocks:
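Note on the batch_size == 1 fast path above: with a single sequence there is no ragged padding across batch items, so trimming the padded text tokens up front is equivalent to masking them out in attention, and `attention_mask` can stay `None`. A toy shape check (sizes are made up):

import torch

encoder_hidden_states = torch.randn(1, 256, 16)  # [B=1, max_text_len, D]
encoder_attention_mask = torch.zeros(1, 256, dtype=torch.bool)
encoder_attention_mask[0, :77] = True            # only the first 77 tokens are real

effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
assert encoder_hidden_states.shape == (1, 77, 16)  # padding dropped, no mask needed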
@@ -345,8 +393,7 @@ def _pack_history_states(
         image_rotary_emb = self.rope(
             frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
         )
-        image_rotary_emb = list(image_rotary_emb)  # convert tuple to list for in-place modification
-        pph, ppw = height // self.config.patch_size, width // self.config.patch_size
+        image_rotary_emb = image_rotary_emb.flatten(2).transpose(1, 2)
 
         latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
             latents_clean, latents_history_2x, latents_history_4x
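The `flatten(2).transpose(1, 2)` that replaces the old tuple bookkeeping converts the rope module's dense [1, C, T, H, W] output into a token-major [1, T*H*W, C] layout, so the history rope tables in the next hunk can simply be concatenated along the sequence dim. Shape-only sketch (illustrative sizes):

import torch

image_rotary_emb = torch.randn(1, 256, 3, 4, 4)  # [1, C, T, H, W]
image_rotary_emb = image_rotary_emb.flatten(2).transpose(1, 2)
assert image_rotary_emb.shape == (1, 3 * 4 * 4, 256)  # one row of C rope channels per token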
@@ -358,34 +405,93 @@ def _pack_history_states(
             image_rotary_emb_clean = self.rope(
                 frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device
             )
-            image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
-            image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
+            image_rotary_emb_clean = image_rotary_emb_clean.flatten(2).transpose(1, 2)
+            image_rotary_emb = torch.cat([image_rotary_emb_clean, image_rotary_emb], dim=1)
 
         if latents_history_2x is not None and indices_latents_history_2x is not None:
             hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
 
             image_rotary_emb_history_2x = self.rope(
                 frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device
             )
-            image_rotary_emb_history_2x = self._pad_rotary_emb(
-                image_rotary_emb_history_2x, indices_latents_history_2x.size(1), pph, ppw, (2, 2, 2)
-            )
-            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
-            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
+            image_rotary_emb_history_2x = _pad_for_3d_conv(image_rotary_emb_history_2x, (2, 2, 2))
+            image_rotary_emb_history_2x = _center_down_sample_3d(image_rotary_emb_history_2x, (2, 2, 2))
+            image_rotary_emb_history_2x = image_rotary_emb_history_2x.flatten(2).transpose(1, 2)
+            image_rotary_emb = torch.cat([image_rotary_emb_history_2x, image_rotary_emb], dim=1)
 
         if latents_history_4x is not None and indices_latents_history_4x is not None:
             hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
 
             image_rotary_emb_history_4x = self.rope(
                 frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device
             )
-            image_rotary_emb_history_4x = self._pad_rotary_emb(
-                image_rotary_emb_history_4x, indices_latents_history_4x.size(1), pph, ppw, (4, 4, 4)
-            )
-            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
-            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
-
-        return hidden_states, image_rotary_emb
+            image_rotary_emb_history_4x = _pad_for_3d_conv(image_rotary_emb_history_4x, (4, 4, 4))
+            image_rotary_emb_history_4x = _center_down_sample_3d(image_rotary_emb_history_4x, (4, 4, 4))
+            image_rotary_emb_history_4x = image_rotary_emb_history_4x.flatten(2).transpose(1, 2)
+            image_rotary_emb = torch.cat([image_rotary_emb_history_4x, image_rotary_emb], dim=1)
+
+        return hidden_states, image_rotary_emb.squeeze(0).chunk(2, dim=-1)
+
+    # def _pack_history_states(
+    #     self,
+    #     hidden_states: torch.Tensor,
+    #     indices_latents: torch.Tensor,
+    #     latents_clean: Optional[torch.Tensor] = None,
+    #     latents_history_2x: Optional[torch.Tensor] = None,
+    #     latents_history_4x: Optional[torch.Tensor] = None,
+    #     indices_latents_clean: Optional[torch.Tensor] = None,
+    #     indices_latents_history_2x: Optional[torch.Tensor] = None,
+    #     indices_latents_history_4x: Optional[torch.Tensor] = None,
+    # ):
+    #     batch_size, num_channels, num_frames, height, width = hidden_states.shape
+    #     if indices_latents is None:
+    #         indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
+
+    #     hidden_states = self.x_embedder(hidden_states)
+    #     image_rotary_emb = self.rope(
+    #         frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
+    #     )
+    #     image_rotary_emb = list(image_rotary_emb)  # convert tuple to list for in-place modification
+    #     pph, ppw = height // self.config.patch_size, width // self.config.patch_size
+
+    #     latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
+    #         latents_clean, latents_history_2x, latents_history_4x
+    #     )
+
+    #     if latents_clean is not None:
+    #         hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
+
+    #         image_rotary_emb_clean = self.rope(
+    #             frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device
+    #         )
+    #         image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
+    #         image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
+
+    #     if latents_history_2x is not None and indices_latents_history_2x is not None:
+    #         hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
+
+    #         image_rotary_emb_history_2x = self.rope(
+    #             frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device
+    #         )
+    #         image_rotary_emb_history_2x = self._pad_rotary_emb(
+    #             image_rotary_emb_history_2x, indices_latents_history_2x.size(1), pph, ppw, (2, 2, 2)
+    #         )
+    #         image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
+    #         image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
+
+    #     if latents_history_4x is not None and indices_latents_history_4x is not None:
+    #         hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
+
+    #         image_rotary_emb_history_4x = self.rope(
+    #             frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device
+    #         )
+    #         image_rotary_emb_history_4x = self._pad_rotary_emb(
+    #             image_rotary_emb_history_4x, indices_latents_history_4x.size(1), pph, ppw, (4, 4, 4)
+    #         )
+    #         image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
+    #         image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
+
+    #     return hidden_states, image_rotary_emb
 
     def _pad_rotary_emb(
         self,
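The new return value `image_rotary_emb.squeeze(0).chunk(2, dim=-1)` recovers the (freqs_cos, freqs_sin) pair the attention blocks consume: because `get_frequency` concatenates all cos tables before all sin tables, the first half of the channel dim is cosine and the second half sine. Shape-only sketch (illustrative sizes):

import torch

image_rotary_emb = torch.randn(1, 48, 256)  # [1, num_tokens, C]
freqs_cos, freqs_sin = image_rotary_emb.squeeze(0).chunk(2, dim=-1)
assert freqs_cos.shape == freqs_sin.shape == (48, 128)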

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py

Lines changed: 8 additions & 12 deletions
@@ -551,7 +551,6 @@ def __call__(
         guidance_scale: float = 6.0,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.Tensor] = None,
         image_latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
@@ -614,10 +613,8 @@ def __call__(
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
-            latents (`torch.Tensor`, *optional*):
-                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor is generated by sampling using the supplied random `generator`.
+            image_latents (`torch.Tensor`, *optional*):
+                Pre-encoded image latents.
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
@@ -767,10 +764,9 @@ def __call__(
         guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0
 
         # 7. Denoising loop
-        for i in range(num_latent_sections):
-            current_latent_padding = latent_paddings[i]
-            is_last_section = current_latent_padding == 0
-            latent_padding_size = current_latent_padding * latent_window_size
+        for k in range(num_latent_sections):
+            is_last_section = latent_paddings[k] == 0
+            latent_padding_size = latent_paddings[k] * latent_window_size
 
             indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes])).unsqueeze(0)
             (
@@ -799,12 +795,12 @@ def __call__(
                 dtype=torch.float32,
                 device=device,
                 generator=generator,
-                latents=latents,
+                latents=None,
            )
 
             sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
             image_seq_len = (
-                latents.shape[1] * latents.shape[2] * latents.shape[3] / self.transformer.config.patch_size**2
+                latents.shape[2] * latents.shape[3] * latents.shape[4] / self.transformer.config.patch_size**2
             )
             exp_max = 7.0
             mu = calculate_shift(
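The image_seq_len fix reflects that framepack latents are 5D [B, C, T, H, W]: the token count fed to calculate_shift should be T * H * W / patch_size**2, whereas the old shape[1] * shape[2] * shape[3] counted channels instead of frames. A quick arithmetic check with made-up sizes:

# hypothetical latent sizes: [B, C, T, H, W] = [1, 16, 9, 60, 104], patch_size = 2
B, C, T, H, W = 1, 16, 9, 60, 104
patch_size = 2
image_seq_len = T * H * W / patch_size**2  # shape[2] * shape[3] * shape[4] / p**2
old_seq_len = C * T * H / patch_size**2    # shape[1] * shape[2] * shape[3] / p**2
assert image_seq_len == 14040.0 and old_seq_len != image_seq_len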
@@ -887,7 +883,7 @@ def __call__(
 
             if XLA_AVAILABLE:
                 xm.mark_step()
-
+
             if is_last_section:
                 latents = torch.cat([image_latents, latents], dim=2)
