Skip to content

Commit 68a7fa6

Browse files
committed
batch_size=1 optimization
1 parent 65c24f6 commit 68a7fa6

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -280,14 +280,14 @@ def forward(
280280
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
281281
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
282282

283-
# if batch_size == 1:
284-
# encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
285-
# attention_mask = None
286-
# else:
287-
for i in range(batch_size):
288-
attention_mask[i, : effective_sequence_length[i]] = True
289-
# [B, 1, 1, N], for broadcasting across attention heads
290-
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
283+
if batch_size == 1:
284+
encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
285+
attention_mask = None
286+
else:
287+
for i in range(batch_size):
288+
attention_mask[i, : effective_sequence_length[i]] = True
289+
# [B, 1, 1, N], for broadcasting across attention heads
290+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
291291

292292
if torch.is_grad_enabled() and self.gradient_checkpointing:
293293
for block in self.transformer_blocks:
@@ -311,7 +311,6 @@ def forward(
311311
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
312312
)
313313

314-
# 5. Output projection
315314
hidden_states = hidden_states[:, -original_context_length:]
316315
hidden_states = self.norm_out(hidden_states, temb)
317316
hidden_states = self.proj_out(hidden_states)

0 commit comments

Comments
 (0)