@@ -280,14 +280,14 @@ def forward(
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length

-        # if batch_size == 1:
-        #     encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
-        #     attention_mask = None
-        # else:
-        for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i]] = True
-        # [B, 1, 1, N], for broadcasting across attention heads
-        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+        if batch_size == 1:
+            encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
+            attention_mask = None
+        else:
+            for i in range(batch_size):
+                attention_mask[i, : effective_sequence_length[i]] = True
+            # [B, 1, 1, N], for broadcasting across attention heads
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for block in self.transformer_blocks:
@@ -311,7 +311,6 @@ def forward(
                     hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
                 )

-        # 5. Output projection
         hidden_states = hidden_states[:, -original_context_length:]
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
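
Below is a minimal, self-contained sketch of the masking logic this hunk reinstates: when `batch_size == 1`, the text padding is sliced off and no attention mask is needed; otherwise each row of the mask marks the first `effective_sequence_length[i]` positions as attendable and is reshaped to `[B, 1, 1, N]` for broadcasting. The concrete shapes, the hand-written `encoder_attention_mask`, and the `max_text_tokens` / `total_length` names are illustrative assumptions, not taken from the model code.

```python
import torch

# Illustrative sizes (assumed, not from the model).
batch_size = 2
latent_sequence_length = 4   # number of latent (video) tokens per sample
max_text_tokens = 6          # padded text-condition length
encoder_hidden_states = torch.randn(batch_size, max_text_tokens, 8)

# 1 marks real text tokens, 0 marks padding (per sample).
encoder_attention_mask = torch.tensor(
    [[1, 1, 1, 0, 0, 0],
     [1, 1, 1, 1, 1, 0]]
)

effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length

if batch_size == 1:
    # Fast path: drop the text padding entirely and skip masking.
    encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
    attention_mask = None
else:
    total_length = latent_sequence_length + max_text_tokens
    attention_mask = torch.zeros(batch_size, total_length, dtype=torch.bool)
    for i in range(batch_size):
        # Valid tokens (latents + real text) sit at the front of each row.
        attention_mask[i, : effective_sequence_length[i]] = True
    # [B, 1, 1, N], so the mask broadcasts across heads and query positions.
    attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

print(None if attention_mask is None else attention_mask.shape)  # torch.Size([2, 1, 1, 10])
```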