Improve FluxPipeline checks and logging
#9064
base: main
Changes from 2 commits
```diff
@@ -146,6 +146,11 @@ class FluxPipeline(DiffusionPipeline, SD3LoraLoaderMixin):
     r"""
     The Flux pipeline for text-to-image generation.
 
+    Note:
+        This pipeline expects `latents` to be in a packed format. If you're providing
+        custom latents, make sure to use the `_pack_latents` method to prepare them.
+        Packed latents should be a 3D tensor of shape (batch_size, num_patches, channels).
+
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
 
     Args:
```
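For readers unfamiliar with the packed format this note refers to: packing groups the 4D VAE latents into 2x2 spatial patches and flattens them into a token sequence. Below is a minimal, standalone sketch of that transformation (it assumes the 2x2 patching implied by the shape checks later in this diff; it is not the pipeline's own `_pack_latents` quoted verbatim):

```python
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    """Flatten (B, C, H, W) latents into the packed (B, num_patches, C * 4) layout."""
    batch_size, channels, height, width = latents.shape
    # Split the spatial grid into 2x2 patches...
    latents = latents.view(batch_size, channels, height // 2, 2, width // 2, 2)
    # ...move the patch dimensions next to the channel dimension...
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    # ...and flatten: each 2x2 patch becomes one "token" with C * 4 features.
    return latents.reshape(batch_size, (height // 2) * (width // 2), channels * 4)


packed = pack_latents(torch.randn(1, 16, 128, 128))
print(packed.shape)  # torch.Size([1, 4096, 64])
```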
```diff
@@ -391,6 +396,7 @@ def check_inputs(
         pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
+        latents=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
```
```diff
@@ -429,6 +435,26 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
+        if latents is not None:
+            if not isinstance(latents, torch.Tensor):
+                raise ValueError(f"`latents` has to be of type `torch.Tensor` but is {type(latents)}")
+
+            if not _are_latents_packed(latents):
+                raise ValueError(f"`latents` should be a 3-dimensional tensor but has {latents.ndim=} dimensions")
+
+            batch_size, num_patches, channels = latents.shape
```
Collaborator: Wouldn't this throw an error if the shape is […]
```diff
+            if channels != self.transformer.config.in_channels:
+                raise ValueError(
+                    f"Number of channels in `latents` ({channels}) does not match the number of channels expected by"
+                    f" the transformer ({self.transformer.config.in_channels=})."
+                )
+
+            if num_patches != (height // self.vae_scale_factor) * (width // self.vae_scale_factor):
+                raise ValueError(
+                    f"Number of patches in `latents` ({num_patches}) does not match the number of patches expected by"
+                    f" the transformer ({(height // self.vae_scale_factor) * (width // self.vae_scale_factor)=})."
+                )
+
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
         latent_image_ids = torch.zeros(height // 2, width // 2, 3)
```
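To make the two shape checks concrete, here is a small worked example of the values they compare. The numbers assume the default Flux configuration (`transformer.config.in_channels == 64` and `vae_scale_factor == 16`), which is an assumption on my part rather than something stated in the diff:

```python
def expected_packed_shape(batch_size, height, width, vae_scale_factor=16, in_channels=64):
    # check_inputs requires latents.shape == (batch_size, num_patches, in_channels),
    # where num_patches is derived from the requested output resolution.
    num_patches = (height // vae_scale_factor) * (width // vae_scale_factor)
    return (batch_size, num_patches, in_channels)


print(expected_packed_shape(1, 1024, 1024))  # (1, 4096, 64)
print(expected_packed_shape(2, 768, 1360))   # (2, 4080, 64)
```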
```diff
@@ -466,6 +492,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
 
         return latents
 
+    @staticmethod
+    def _are_latents_packed(latents):
+        return latents.ndim == 3
+
     def prepare_latents(
         self,
         batch_size,
```
```diff
@@ -477,15 +507,21 @@ def prepare_latents(
         generator,
         latents=None,
     ):
+        if latents is not None:
+            if latents.ndim == 4:
+                logger.warning(
+                    "Unpacked latents detected. These will be automatically packed. "
+                    "In the future, please provide packed latents to improve performance."
+                )
+                latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
```
Member: Packing the latents is an inexpensive operation. So, I think this warning is unnecessary.

Author: Agreed, the warning can be removed. Note also that I moved the definitions for […]
```diff
+
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+            return latents.to(device=device, dtype=dtype), latent_image_ids
+
         height = 2 * (int(height) // self.vae_scale_factor)
         width = 2 * (int(width) // self.vae_scale_factor)
 
         shape = (batch_size, num_channels_latents, height, width)
 
-        if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
-
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
```
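As a quick illustration of the two inputs this version of `prepare_latents` accepts when `latents` is supplied (the shapes assume a 1024x1024 generation with a 16-channel latent space, an assumption for illustration and not part of the diff):

```python
import torch

# 4D, unpacked: prepare_latents now packs this on the fly (and logs the warning above).
unpacked = torch.randn(1, 16, 128, 128)

# 3D, already packed: returned unchanged apart from .to(device, dtype).
packed = torch.randn(1, 64 * 64, 64)
```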
```diff
@@ -621,6 +657,7 @@ def __call__(
             pooled_prompt_embeds=pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
+            latents=latents,
         )
 
         self._guidance_scale = guidance_scale
```
```diff
@@ -668,6 +705,11 @@ def __call__(
             latents,
         )
 
+        if not self._are_latents_packed(latents):
+            raise ValueError(
+                "Latents are not in the correct packed format. Please use `_pack_latents` to prepare them."
+            )
+
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         image_seq_len = latents.shape[1]
```
I was actually fixing it here #9057. But okay.

Understood, we can ignore this then.
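For completeness, here is a hypothetical end-to-end call that would exercise the new validation path. The checkpoint id, resolution, and latent shape are illustrative assumptions and are not taken from this PR:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# Custom latents in the packed (batch_size, num_patches, channels) layout for a 1024x1024 image.
latents = torch.randn(1, 64 * 64, 64, dtype=torch.bfloat16, device="cuda")

# With this PR, a wrong type, rank, channel count, or patch count is rejected early by
# check_inputs with a descriptive ValueError instead of failing deep inside the transformer.
image = pipe(
    "a photo of a forest clearing",
    height=1024,
    width=1024,
    num_inference_steps=4,
    guidance_scale=0.0,
    latents=latents,
).images[0]
image.save("forest.png")
```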