Skip to content

Commit 2296291

Browse files
committed
Set attention_mask to dtype=torch.bool for ChromaInpaintPipeline.
1 parent a24f566 commit 2296291

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def _prepare_attention_mask(
761761

762762
# Extend the prompt attention mask to account for image tokens in the final sequence
763763
attention_mask = torch.cat(
764-
[attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
764+
[attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
765765
dim=1,
766766
)
767767
attention_mask = attention_mask.to(dtype)

0 commit comments

Comments (0)