src/diffusers/models/transformers/transformer_flux.py (1 change: 0 additions & 1 deletion)

@@ -373,7 +373,6 @@ def forward(
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

-        print(f"{txt_ids.shape=}, {img_ids.shape=}")

Member: I was actually fixing it here: #9057. But okay.

Author: Understood, we can ignore this then.

        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pos_embed(ids)

src/diffusers/pipelines/flux/pipeline_flux.py (50 changes: 46 additions & 4 deletions)

@@ -146,6 +146,11 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
r"""
The Flux pipeline for text-to-image generation.

Note:
This pipeline expects `latents` to be in a packed format. If you're providing
custom latents, make sure to use the `_pack_latents` method to prepare them.
Packed latents should be a 3D tensor of shape (batch_size, num_patches, channels).

Member: I think this note is unnecessary as well, given that _pack_latents() is an inexpensive operation. I think your updates to check_inputs() should do the trick.

Author: Agreed, maybe I can then add this note as part of the latents docstring in __call__.


    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
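
For readers unfamiliar with the packed format discussed in the note above, here is a minimal sketch of the reshaping involved (the standalone pack_latents helper is illustrative, not part of this PR; it assumes the same 2x2 patch folding that the pipeline's _pack_latents performs):

    import torch

    # 2x2 latent patches are folded into the channel dimension:
    # (B, C, H, W) -> (B, (H/2) * (W/2), 4 * C)
    def pack_latents(latents: torch.Tensor) -> torch.Tensor:
        batch_size, channels, height, width = latents.shape
        latents = latents.view(batch_size, channels, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        return latents.reshape(batch_size, (height // 2) * (width // 2), channels * 4)

    unpacked = torch.randn(1, 16, 128, 128)  # VAE latents for a 1024x1024 image
    packed = pack_latents(unpacked)
    print(packed.shape)  # torch.Size([1, 4096, 64])
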

@@ -391,6 +396,7 @@ def check_inputs(
        pooled_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=None,
+        latents=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

@@ -429,6 +435,26 @@ def check_inputs(
        if max_sequence_length is not None and max_sequence_length > 512:
            raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")

+        if latents is not None:
+            if not isinstance(latents, torch.Tensor):
+                raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}")
+
+            if not self._are_latents_packed(latents):

Member: But later in prepare_latents(), we are packing the latents and throwing a warning, no? So I think we can remove this check.

Author: Same, this check can be removed from the check_inputs function.

raise ValueError(f"`latents` should be a 3-dimensional tensor but has {latents.ndim=} dimensions")

batch_size, num_patches, channels = latents.shape

Collaborator: Wouldn't this throw an error if the shape is ndim==4? Perhaps we should check that ndim==3 here and raise an error rather than letting it fail.

+            if channels != self.transformer.config.in_channels:
+                raise ValueError(
+                    f"Number of channels in `latents` ({channels}) does not match the number of channels expected by"
+                    f" the transformer ({self.transformer.config.in_channels=})."
+                )
+
+            if num_patches != (height // self.vae_scale_factor) * (width // self.vae_scale_factor):
+                raise ValueError(
+                    f"Number of patches in `latents` ({num_patches}) does not match the number of patches expected by"
+                    f" the transformer ({(height // self.vae_scale_factor) * (width // self.vae_scale_factor)=})."
+                )
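
As a worked example of the shape arithmetic these checks enforce (a sketch; the concrete numbers assume the default Flux configuration, i.e. vae_scale_factor = 16 and transformer.config.in_channels = 64):

    height, width = 1024, 1024
    vae_scale_factor = 16  # assumed default for Flux
    in_channels = 64       # assumed transformer.config.in_channels

    expected_patches = (height // vae_scale_factor) * (width // vae_scale_factor)
    print(expected_patches)  # 4096

    # A packed `latents` tensor that passes the checks above therefore has shape
    # (batch_size, 4096, 64). A 4D tensor such as (1, 16, 128, 128) is rejected by
    # the ndim check before the tuple unpacking of latents.shape could fail on its own.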

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)

@@ -466,6 +492,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):

        return latents

+    @staticmethod
+    def _are_latents_packed(latents):
+        return latents.ndim == 3
+
    def prepare_latents(
        self,
        batch_size,

@@ -477,15 +507,21 @@ def prepare_latents(
        generator,
        latents=None,
    ):
+        if latents is not None:
+            if latents.ndim == 4:
+                logger.warning(
+                    "Unpacked latents detected. These will be automatically packed. "
+                    "In the future, please provide packed latents to improve performance."
+                )
+                latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

Member: Packing the latents is an inexpensive operation. So I think this warning is unnecessary.

Author (@PDillis, Aug 3, 2024): Agreed, the warning can be removed. Note also that I moved the definitions of height and width for the case where no latents are provided, as otherwise we would be giving the wrong dimensions to self._prepare_latent_image_ids.

+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+            return latents.to(device=device, dtype=dtype), latent_image_ids

        height = 2 * (int(height) // self.vae_scale_factor)
        width = 2 * (int(width) // self.vae_scale_factor)

        shape = (batch_size, num_channels_latents, height, width)

-        if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"

@@ -621,6 +657,7 @@ def __call__(
            pooled_prompt_embeds=pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
+            latents=latents,
        )

        self._guidance_scale = guidance_scale

@@ -668,6 +705,11 @@ def __call__(
            latents,
        )

+        if not self._are_latents_packed(latents):
+            raise ValueError(
+                "Latents are not in the correct packed format. Please use `_pack_latents` to prepare them."
+            )

Member: This can be removed too, since check_inputs() should essentially catch these.

Author: Yes, the check_inputs function should catch these, so this can be removed. Then there is no need for the _are_latents_packed method, so that can also be removed.


        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        image_seq_len = latents.shape[1]
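
Context for the last line shown: latents.shape[1] is the image sequence length only when the latents are packed, which is what the earlier checks guarantee. A rough sketch of how that value feeds the dynamic timestep shift (the calculate_shift helper name and its default constants are assumptions based on the Flux pipeline of this period):

    # The packed sequence length drives the shift of the flow-matching schedule.
    def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        return image_seq_len * m + b

    mu = calculate_shift(4096)  # ~1.16 for a 1024x1024 image with the default packing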