diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index aa421a53727b..45556c538ab8 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -18,6 +18,7 @@ from dataclasses import dataclass, field, fields from typing import Any, Dict, List, Literal, Optional, Type, Union +import PIL.Image import torch from ..configuration_utils import ConfigMixin, FrozenDict @@ -342,6 +343,185 @@ class InputParam: def __repr__(self): return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" + @classmethod + def template(cls, name: str) -> Optional["InputParam"]: + """Get template for name if exists, otherwise None.""" + if hasattr(cls, name) and callable(getattr(cls, name)): + return getattr(cls, name)() + return None + + # ====================================================== + # InputParam templates + # ====================================================== + + @classmethod + def prompt(cls) -> "InputParam": + return cls( + name="prompt", type_hint=str, required=True, description="The prompt or prompts to guide image generation." + ) + + @classmethod + def negative_prompt(cls) -> "InputParam": + return cls( + name="negative_prompt", + type_hint=str, + default=None, + description="The prompt or prompts not to guide the image generation.", + ) + + @classmethod + def max_sequence_length(cls, default: int = 512) -> "InputParam": + return cls( + name="max_sequence_length", + type_hint=int, + default=default, + description="Maximum sequence length for prompt encoding.", + ) + + @classmethod + def height(cls, default: Optional[int] = None) -> "InputParam": + return cls( + name="height", type_hint=int, default=default, description="The height in pixels of the generated image." + ) + + @classmethod + def width(cls, default: Optional[int] = None) -> "InputParam": + return cls( + name="width", type_hint=int, default=default, description="The width in pixels of the generated image." + ) + + @classmethod + def num_inference_steps(cls, default: int = 50) -> "InputParam": + return cls( + name="num_inference_steps", type_hint=int, default=default, description="The number of denoising steps." + ) + + @classmethod + def num_images_per_prompt(cls, default: int = 1) -> "InputParam": + return cls( + name="num_images_per_prompt", + type_hint=int, + default=default, + description="The number of images to generate per prompt.", + ) + + @classmethod + def generator(cls) -> "InputParam": + return cls( + name="generator", + type_hint=torch.Generator, + default=None, + description="Torch generator for deterministic generation.", + ) + + @classmethod + def sigmas(cls) -> "InputParam": + return cls( + name="sigmas", type_hint=List[float], default=None, description="Custom sigmas for the denoising process." + ) + + @classmethod + def strength(cls, default: float = 0.9) -> "InputParam": + return cls(name="strength", type_hint=float, default=default, description="Strength for img2img/inpainting.") + + # images + @classmethod + def image(cls) -> "InputParam": + return cls( + name="image", + type_hint=PIL.Image.Image, + required=True, + description="Input image for img2img, editing, or conditioning.", + ) + + @classmethod + def mask_image(cls) -> "InputParam": + return cls( + name="mask_image", type_hint=PIL.Image.Image, required=True, description="Mask image for inpainting." + ) + + @classmethod + def control_image(cls) -> "InputParam": + return cls( + name="control_image", + type_hint=PIL.Image.Image, + required=True, + description="Control image for ControlNet conditioning.", + ) + + @classmethod + def padding_mask_crop(cls) -> "InputParam": + return cls( + name="padding_mask_crop", + type_hint=int, + default=None, + description="Padding for mask cropping in inpainting.", + ) + + @classmethod + def latents(cls) -> "InputParam": + return cls( + name="latents", + type_hint=torch.Tensor, + default=None, + description="Pre-generated noisy latents for image generation.", + ) + + @classmethod + def timesteps(cls) -> "InputParam": + return cls( + name="timesteps", type_hint=torch.Tensor, default=None, description="Timesteps for the denoising process." + ) + + @classmethod + def output_type(cls) -> "InputParam": + return cls(name="output_type", type_hint=str, default="pil", description="Output format: 'pil', 'np', 'pt''.") + + @classmethod + def attention_kwargs(cls) -> "InputParam": + return cls( + name="attention_kwargs", + type_hint=Dict[str, Any], + default=None, + description="Additional kwargs for attention processors.", + ) + + @classmethod + def denoiser_input_fields(cls) -> "InputParam": + return cls( + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + ) + + # ControlNet + @classmethod + def control_guidance_start(cls, default: float = 0.0) -> "InputParam": + return cls( + name="control_guidance_start", + type_hint=float, + default=default, + description="When to start applying ControlNet.", + ) + + @classmethod + def control_guidance_end(cls, default: float = 1.0) -> "InputParam": + return cls( + name="control_guidance_end", + type_hint=float, + default=default, + description="When to stop applying ControlNet.", + ) + + @classmethod + def controlnet_conditioning_scale(cls, default: float = 1.0) -> "InputParam": + return cls( + name="controlnet_conditioning_scale", + type_hint=float, + default=default, + description="Scale for ControlNet conditioning.", + ) + @dataclass class OutputParam: @@ -357,6 +537,25 @@ def __repr__(self): f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>" ) + @classmethod + def template(cls, name: str) -> Optional["OutputParam"]: + """Get template for name if exists, otherwise None.""" + if hasattr(cls, name) and callable(getattr(cls, name)): + return getattr(cls, name)() + return None + + # ====================================================== + # OutputParam templates + # ====================================================== + + @classmethod + def images(cls) -> "OutputParam": + return cls(name="images", type_hint=List[PIL.Image.Image], description="Generated images.") + + @classmethod + def latents(cls) -> "OutputParam": + return cls(name="latents", type_hint=torch.Tensor, description="Denoised latents.") + def format_inputs_short(inputs): """ @@ -509,10 +708,12 @@ def wrap_text(text, indent, max_length): desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description) wrapped_desc = wrap_text(desc, desc_indent, max_line_length) param_str += f"\n{desc_indent}{wrapped_desc}" + else: + param_str += f"\n{desc_indent}TODO: Add description." formatted_params.append(param_str) - return "\n\n".join(formatted_params) + return "\n".join(formatted_params) def format_input_params(input_params, indent_level=4, max_line_length=115): @@ -582,7 +783,7 @@ def format_components(components, indent_level=4, max_line_length=115, add_empty loading_field_values = [] for field_name in component.loading_fields(): field_value = getattr(component, field_name) - if field_value is not None: + if field_value: loading_field_values.append(f"{field_name}={field_value}") # Add loading field information if available diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index 0c66d6ea3303..cb808b1d3807 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -134,11 +134,11 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents"), - InputParam(name="height"), - InputParam(name="width"), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="generator"), + InputParam.latents(), + InputParam.height(), + InputParam.width(), + InputParam.num_images_per_prompt(), + InputParam.generator(), InputParam( name="batch_size", required=True, @@ -225,12 +225,14 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents"), - InputParam(name="height"), - InputParam(name="width"), - InputParam(name="layers", default=4), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="generator"), + InputParam.latents(), + InputParam.height(), + InputParam.width(), + InputParam( + name="layers", type_hint=int, default=4, description="Number of layers to extract from the image" + ), + InputParam.num_images_per_prompt(), + InputParam.generator(), InputParam( name="batch_size", required=True, @@ -466,8 +468,8 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_inference_steps", default=50), - InputParam(name="sigmas"), + InputParam.num_inference_steps(), + InputParam.sigmas(), InputParam( name="latents", required=True, @@ -532,8 +534,8 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("num_inference_steps", default=50, type_hint=int), - InputParam("sigmas", type_hint=List[float]), + InputParam.num_inference_steps(), + InputParam.sigmas(), InputParam("image_latents", required=True, type_hint=torch.Tensor), ] @@ -590,15 +592,15 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_inference_steps", default=50), - InputParam(name="sigmas"), + InputParam.num_inference_steps(), + InputParam.sigmas(), InputParam( name="latents", required=True, type_hint=torch.Tensor, description="The latents to use for the denoising process, used to calculate the image sequence length.", ), - InputParam(name="strength", default=0.9), + InputParam.strength(0.9), ] @property @@ -886,7 +888,7 @@ def description(self) -> str: def inputs(self) -> List[InputParam]: return [ InputParam(name="batch_size", required=True), - InputParam(name="layers", required=True), + InputParam(name="layers", default=4, description="Number of layers to extract from the image"), InputParam(name="height", required=True), InputParam(name="width", required=True), InputParam(name="prompt_embeds_mask"), @@ -971,9 +973,9 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam("control_guidance_start", default=0.0), - InputParam("control_guidance_end", default=1.0), - InputParam("controlnet_conditioning_scale", default=1.0), + InputParam.control_guidance_start(), + InputParam.control_guidance_end(), + InputParam.controlnet_conditioning_scale(), InputParam("control_image_latents", required=True), InputParam( "timesteps", diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py index 24a88ebfca3c..8207e99b69ae 100644 --- a/src/diffusers/modular_pipelines/qwenimage/decoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import List -import numpy as np -import PIL import torch from ...configuration_utils import FrozenDict @@ -91,7 +89,7 @@ def inputs(self) -> List[InputParam]: InputParam("latents", required=True, type_hint=torch.Tensor), InputParam("height", required=True, type_hint=int), InputParam("width", required=True, type_hint=int), - InputParam("layers", required=True, type_hint=int), + InputParam("layers", default=4, description="Number of layers to extract from the image"), ] @torch.no_grad() @@ -140,13 +138,7 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[str]: - return [ - OutputParam( - "images", - type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], - description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array", - ) - ] + return [OutputParam.images()] @torch.no_grad() def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: @@ -198,14 +190,19 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor), - InputParam("output_type", default="pil", type_hint=str), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to decode, can be generated in the denoise step", + ), + InputParam.output_type(), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]), + OutputParam.images(), ] @torch.no_grad() @@ -273,12 +270,7 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam("images", required=True, description="the generated image from decoders step"), - InputParam( - name="output_type", - default="pil", - type_hint=str, - description="The type of the output images, can be 'pil', 'np', 'pt'", - ), + InputParam.output_type(), ] @staticmethod @@ -323,12 +315,7 @@ def expected_components(self) -> List[ComponentSpec]: def inputs(self) -> List[InputParam]: return [ InputParam("images", required=True, description="the generated image from decoders step"), - InputParam( - name="output_type", - default="pil", - type_hint=str, - description="The type of the output images, can be 'pil', 'np', 'pt'", - ), + InputParam.output_type(), InputParam("mask_overlay_kwargs"), ] diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py index eb1e5a341c68..472945b2269a 100644 --- a/src/diffusers/modular_pipelines/qwenimage/denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py @@ -218,7 +218,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("attention_kwargs"), + InputParam.attention_kwargs(), InputParam( "latents", required=True, @@ -231,10 +231,7 @@ def inputs(self) -> List[InputParam]: type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), - InputParam( - kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", - ), + InputParam.denoiser_input_fields(), InputParam( "img_shapes", required=True, @@ -322,7 +319,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("attention_kwargs"), + InputParam.attention_kwargs(), InputParam( "latents", required=True, @@ -335,10 +332,7 @@ def inputs(self) -> List[InputParam]: type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), - InputParam( - kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", - ), + InputParam.denoiser_input_fields(), InputParam( "img_shapes", required=True, @@ -424,7 +418,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."), + OutputParam.latents(), ] @torch.no_grad() diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 4b66dd32e521..8d7b1905423d 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -301,8 +301,12 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam( - name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize" + InputParam.template(self._image_input_name) + or InputParam( + name=self._image_input_name, + required=True, + type_hint=torch.Tensor, + description="Input image for conditioning", ), ] @@ -381,7 +385,8 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam( + InputParam.template(self._image_input_name) + or InputParam( name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize" ), InputParam( @@ -484,7 +489,8 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam( + InputParam.template(self._image_input_name) + or InputParam( name=self._image_input_name, required=True, type_hint=torch.Tensor, @@ -564,7 +570,9 @@ def expected_configs(self) -> List[ConfigSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", type_hint=str, description="The prompt to encode"), + InputParam( + name="prompt", type_hint=str, description="The prompt to encode" + ), # it is not required for qwenimage-layered, unlike other pipelines InputParam( name="resized_image", required=True, @@ -647,11 +655,9 @@ def expected_configs(self) -> List[ConfigSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), - InputParam( - name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024 - ), + InputParam.prompt(), + InputParam.negative_prompt(), + InputParam.max_sequence_length(1024), ] @property @@ -772,8 +778,8 @@ def expected_configs(self) -> List[ConfigSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), + InputParam.prompt(), + InputParam.negative_prompt(), InputParam( name="resized_image", required=True, @@ -895,8 +901,8 @@ def expected_configs(self) -> List[ConfigSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"), - InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"), + InputParam.prompt(), + InputParam.negative_prompt(), InputParam( name="resized_cond_image", required=True, @@ -1010,11 +1016,11 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("mask_image", required=True), - InputParam("image", required=True), - InputParam("height"), - InputParam("width"), - InputParam("padding_mask_crop"), + InputParam.mask_image(), + InputParam.image(), + InputParam.height(), + InputParam.width(), + InputParam.padding_mask_crop(), ] @property @@ -1082,9 +1088,14 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("mask_image", required=True), - InputParam("resized_image", required=True), - InputParam("padding_mask_crop"), + InputParam.mask_image(), + InputParam( + "resized_image", + required=True, + type_hint=PIL.Image.Image, + description="The resized image. should be generated using a resize step", + ), + InputParam.padding_mask_crop(), ] @property @@ -1140,9 +1151,9 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam("image", required=True), - InputParam("height"), - InputParam("width"), + InputParam.image(), + InputParam.height(), + InputParam.width(), ] @property @@ -1312,7 +1323,11 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: - return [InputParam(self._image_input_name, required=True), InputParam("generator")] + return [ + InputParam.template(self._image_input_name) + or InputParam(name=self._image_input_name, required=True, description="The image tensor to encode"), + InputParam.generator(), + ] @property def intermediate_outputs(self) -> List[OutputParam]: @@ -1383,10 +1398,10 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam("control_image", required=True), - InputParam("height"), - InputParam("width"), - InputParam("generator"), + InputParam.control_image(), + InputParam.height(), + InputParam.width(), + InputParam.generator(), ] return inputs diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py index 4a1cf3700c57..e28493ecc369 100644 --- a/src/diffusers/modular_pipelines/qwenimage/inputs.py +++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py @@ -129,7 +129,7 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="num_images_per_prompt", default=1), + InputParam.num_images_per_prompt(), InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"), InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"), InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"), @@ -269,17 +269,17 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam(name="num_images_per_prompt", default=1), + InputParam.num_images_per_prompt(), InputParam(name="batch_size", required=True), - InputParam(name="height"), - InputParam(name="width"), + InputParam.height(), + InputParam.width(), ] for image_latent_input_name in self._image_latent_inputs: - inputs.append(InputParam(name=image_latent_input_name)) + inputs.append(InputParam.template(image_latent_input_name) or InputParam(name=image_latent_input_name)) for input_name in self._additional_batch_inputs: - inputs.append(InputParam(name=input_name)) + inputs.append(InputParam.template(input_name) or InputParam(name=input_name)) return inputs @@ -398,17 +398,17 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam(name="num_images_per_prompt", default=1), + InputParam.num_images_per_prompt(), InputParam(name="batch_size", required=True), - InputParam(name="height"), - InputParam(name="width"), + InputParam.height(), + InputParam.width(), ] for image_latent_input_name in self._image_latent_inputs: - inputs.append(InputParam(name=image_latent_input_name)) + inputs.append(InputParam.template(image_latent_input_name) or InputParam(name=image_latent_input_name)) for input_name in self._additional_batch_inputs: - inputs.append(InputParam(name=input_name)) + inputs.append(InputParam.template(input_name) or InputParam(name=input_name)) return inputs @@ -544,15 +544,15 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: inputs = [ - InputParam(name="num_images_per_prompt", default=1), + InputParam.num_images_per_prompt(), InputParam(name="batch_size", required=True), ] for image_latent_input_name in self._image_latent_inputs: - inputs.append(InputParam(name=image_latent_input_name)) + inputs.append(InputParam.template(image_latent_input_name) or InputParam(name=image_latent_input_name)) for input_name in self._additional_batch_inputs: - inputs.append(InputParam(name=input_name)) + inputs.append(InputParam.template(input_name) or InputParam(name=input_name)) return inputs @@ -638,9 +638,9 @@ def inputs(self) -> List[InputParam]: return [ InputParam(name="control_image_latents", required=True), InputParam(name="batch_size", required=True), - InputParam(name="num_images_per_prompt", default=1), - InputParam(name="height"), - InputParam(name="width"), + InputParam.num_images_per_prompt(), + InputParam.height(), + InputParam.width(), ] @torch.no_grad() diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py index ebe0bbbd75ba..645c01f66ee5 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - -import PIL.Image -import torch from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks @@ -59,11 +55,108 @@ # ==================== -# 1. VAE ENCODER +# 1. TEXT ENCODER +# ==================== + + +# auto_docstring +class QwenImageAutoTextEncoderStep(AutoPipelineBlocks): + """ + Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block. + + Components: + + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use + + tokenizer (`Qwen2Tokenizer`): The tokenizer to use + + guider (`ClassifierFreeGuidance`) + + Configs: + + prompt_template_encode (default: <|im_start|>system + Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the + objects and background:<|im_end|> <|im_start|>user {}<|im_end|> <|im_start|>assistant ) + + prompt_template_encode_start_idx (default: 34) + + tokenizer_max_length (default: 1024) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings + prompt_embeds_mask (`Tensor`): + The encoder attention mask + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask + """ + + model_name = "qwenimage" + block_classes = [QwenImageTextEncoderStep()] + block_names = ["text_encoder"] + block_trigger_inputs = ["prompt"] + + @property + def description(self) -> str: + return "Text encoder step that encodes the text prompt into a text embedding. This is an auto pipeline block." + " - `QwenImageTextEncoderStep` (text_encoder) is used when `prompt` is provided." + " - if `prompt` is not provided, step will be skipped." + + +# ==================== +# 2. VAE ENCODER # ==================== +# auto_docstring class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for inpainting tasks. It: + - Resizes the image to the target size, based on `height` and `width`. + - Processes and updates `image` and `mask_image`. + - Creates `image_latents`. + + Components: + + image_mask_processor (`InpaintProcessor`) + + vae (`AutoencoderKLQwenImage`) + + Inputs: + mask_image (`Image`): + Mask image for inpainting. + image (`Image`): + Input image for img2img, editing, or conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`None`): + TODO: Add description. + processed_mask_image (`None`): + TODO: Add description. + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latents representing the reference image(s). Single tensor or list depending on input. + """ + model_name = "qwenimage" block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()] block_names = ["preprocess", "encode"] @@ -78,7 +171,34 @@ def description(self) -> str: ) +# auto_docstring class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that preprocess andencode the image inputs into their latent representations. + + Components: + + image_processor (`VaeImageProcessor`) + + vae (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image(s). Single tensor or list depending on input. + """ + model_name = "qwenimage" block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()] @@ -89,7 +209,6 @@ def description(self) -> str: return "Vae encoder step that preprocess andencode the image inputs into their latent representations." -# Auto VAE encoder class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep] block_names = ["inpaint", "img2img"] @@ -107,7 +226,37 @@ def description(self): # optional controlnet vae encoder +# auto_docstring class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + This is an auto pipeline block. + - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided. + - if `control_image` is not provided, step will be skipped. + + Components: + + vae (`AutoencoderKLQwenImage`) + + controlnet (`QwenImageControlNetModel`) + + control_image_processor (`VaeImageProcessor`) + + Inputs: + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + control_image_latents (`Tensor`): + The latents representing the control image + """ + block_classes = [QwenImageControlNetVaeEncoderStep] block_names = ["controlnet"] block_trigger_inputs = ["control_image"] @@ -123,12 +272,49 @@ def description(self): # ==================== -# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) +# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise) # ==================== # assemble input steps +# auto_docstring class QwenImageImg2ImgInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the img2img denoising step. It: + + Components: + + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + """ + model_name = "qwenimage" block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])] block_names = ["text_inputs", "additional_inputs"] @@ -140,7 +326,46 @@ def description(self): " - update height/width based `image_latents`, patchify `image_latents`." +# auto_docstring class QwenImageInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the inpainting denoising step. It: + + Components: + + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + processed_mask_image (`None`, *optional*): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + """ + model_name = "qwenimage" block_classes = [ QwenImageTextInputsStep(), @@ -158,7 +383,43 @@ def description(self): # assemble prepare latents steps +# auto_docstring class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the pachified latents `mask` based on the processedmask image. + + Components: + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + The image latents to use for the denoising process. Can be generated in vae encoder and packed in input + step. + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`None`): + TODO: Add description. + width (`None`): + TODO: Add description. + dtype (`None`): + TODO: Add description. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + model_name = "qwenimage" block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] block_names = ["add_noise_to_latents", "create_mask_latents"] @@ -176,7 +437,55 @@ def description(self) -> str: # Qwen Image (text2image) +# auto_docstring class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + guider (`ClassifierFreeGuidance`) + + transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageTextInputsStep(), @@ -199,9 +508,69 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + @property + def outputs(self): + return [ + OutputParam.latents(), + ] + # Qwen Image (inpainting) +# auto_docstring class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + guider (`ClassifierFreeGuidance`) + + transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + processed_mask_image (`None`, *optional*): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageInpaintInputStep(), @@ -226,9 +595,67 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." + @property + def outputs(self): + return [ + OutputParam.latents(), + ] + # Qwen Image (image2image) +# auto_docstring class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + guider (`ClassifierFreeGuidance`) + + transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageImg2ImgInputStep(), @@ -253,9 +680,74 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + @property + def outputs(self): + return [ + OutputParam.latents(), + ] + # Qwen Image (text2image) with controlnet +# auto_docstring class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): + """ + step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs + (timesteps, latents, rope inputs etc.). + + Components: + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + controlnet (`QwenImageControlNetModel`) + + guider (`ClassifierFreeGuidance`) + + transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + control_image_latents (`None`): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + **denoiser_input_fields (`None`, *optional*): + All conditional model inputs for the denoiser. It should contain prompt_embeds/negative_prompt_embeds, + txt_seq_lens/negative_txt_seq_lens. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageTextInputsStep(), @@ -282,9 +774,80 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)." + @property + def outputs(self): + return [ + OutputParam.latents(), + ] + # Qwen Image (inpainting) with controlnet +# auto_docstring class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint + task. + + Components: + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + controlnet (`QwenImageControlNetModel`) + + guider (`ClassifierFreeGuidance`) + + transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + processed_mask_image (`None`, *optional*): + TODO: Add description. + control_image_latents (`None`): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + **denoiser_input_fields (`None`, *optional*): + All conditional model inputs for the denoiser. It should contain prompt_embeds/negative_prompt_embeds, + txt_seq_lens/negative_txt_seq_lens. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageInpaintInputStep(), @@ -313,9 +876,78 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task." + @property + def outputs(self): + return [ + OutputParam.latents(), + ] + # Qwen Image (image2image) with controlnet +# auto_docstring class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): + """ + Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img + task. + + Components: + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + controlnet (`QwenImageControlNetModel`) + + guider (`ClassifierFreeGuidance`) + + transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + control_image_latents (`None`): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + **denoiser_input_fields (`None`, *optional*): + All conditional model inputs for the denoiser. It should contain prompt_embeds/negative_prompt_embeds, + txt_seq_lens/negative_txt_seq_lens. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage" block_classes = [ QwenImageImg2ImgInputStep(), @@ -344,6 +976,12 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task." + @property + def outputs(self): + return [ + OutputParam.latents(), + ] + # Auto denoise step for QwenImage class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks): @@ -402,19 +1040,38 @@ def description(self): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.latents(), ] # ==================== -# 3. DECODE +# 4. DECODE # ==================== # standard decode step works for most tasks except for inpaint +# auto_docstring class QwenImageDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + + vae (`AutoencoderKLQwenImage`) + + image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The latents to decode, can be generated in the denoise step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt''. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage" block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -425,7 +1082,31 @@ def description(self): # Inpaint decode step +# auto_docstring class QwenImageInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask + overally to the original image. + + Components: + + vae (`AutoencoderKLQwenImage`) + + image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The latents to decode, can be generated in the denoise step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt''. + mask_overlay_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage" block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -452,11 +1133,11 @@ def description(self): # ==================== -# 4. AUTO BLOCKS & PRESETS +# 5. AUTO BLOCKS & PRESETS # ==================== AUTO_BLOCKS = InsertableDict( [ - ("text_encoder", QwenImageTextEncoderStep()), + ("text_encoder", QwenImageAutoTextEncoderStep()), ("vae_encoder", QwenImageAutoVaeEncoderStep()), ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()), ("denoise", QwenImageAutoCoreDenoiseStep()), @@ -465,7 +1146,114 @@ def description(self): ) +# auto_docstring class QwenImageAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage. + - for image-to-image generation, you need to provide `image` + - for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`. + - to run the controlnet workflow, you need to provide `control_image` + - for text-to-image generation, all you need to provide is `prompt` + + Components: + + text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use + + tokenizer (`Qwen2Tokenizer`): The tokenizer to use + + guider (`ClassifierFreeGuidance`) + + image_mask_processor (`InpaintProcessor`) + + vae (`AutoencoderKLQwenImage`) + + image_processor (`VaeImageProcessor`) + + controlnet (`QwenImageControlNetModel`) + + control_image_processor (`VaeImageProcessor`) + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + transformer (`QwenImageTransformer2DModel`) + + Configs: + + prompt_template_encode (default: <|im_start|>system + Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the + objects and background:<|im_end|> <|im_start|>user {}<|im_end|> <|im_start|>assistant ) + + prompt_template_encode_start_idx (default: 34) + + tokenizer_max_length (default: 1024) + + Inputs: + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + mask_image (`Image`, *optional*): + Mask image for inpainting. + image (`Image`, *optional*): + Input image for img2img, editing, or conditioning. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + control_image (`Image`, *optional*): + Control image for ControlNet conditioning. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + image_latents (`None`, *optional*): + TODO: Add description. + processed_mask_image (`None`, *optional*): + TODO: Add description. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + control_image_latents (`None`, *optional*): + TODO: Add description. + control_guidance_start (`float`, *optional*, defaults to 0.0): + When to start applying ControlNet. + control_guidance_end (`float`, *optional*, defaults to 1.0): + When to stop applying ControlNet. + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + Scale for ControlNet conditioning. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt''. + mask_overlay_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage" block_classes = AUTO_BLOCKS.values() @@ -476,7 +1264,7 @@ def description(self): return ( "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n" + "- for image-to-image generation, you need to provide `image`\n" - + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n" + "- to run the controlnet workflow, you need to provide `control_image`\n" + "- for text-to-image generation, all you need to provide is `prompt`" ) @@ -484,5 +1272,5 @@ def description(self): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]), + OutputParam.images(), ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py index 2683e64080bf..0bfbb921c9c4 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional - -import PIL.Image -import torch +from typing import Optional from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks @@ -59,8 +56,51 @@ # ==================== +# auto_docstring class QwenImageEditVLEncoderStep(SequentialPipelineBlocks): - """VL encoder that takes both image and text prompts.""" + """ + QwenImage-Edit VL encoder step that encode the image and text prompts together. + + Components: + + image_resize_processor (`VaeImageProcessor`) + + text_encoder (`Qwen2_5_VLForConditionalGeneration`) + + processor (`Qwen2VLProcessor`) + + guider (`ClassifierFreeGuidance`) + + Configs: + + prompt_template_encode (default: <|im_start|>system + Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how + the user's text instruction should alter or modify the image. Generate a new image that meets the user's + requirements while maintaining consistency with the original input where appropriate.<|im_end|> <|im_start|>user + <|vision_start|><|image_pad|><|vision_end|>{}<|im_end|> <|im_start|>assistant ) + + prompt_template_encode_start_idx (default: 64) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + + Outputs: + resized_image (`List`): + The resized images + prompt_embeds (`Tensor`): + The prompt embeddings + prompt_embeds_mask (`Tensor`): + The encoder attention mask + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask + """ model_name = "qwenimage-edit" block_classes = [ @@ -80,7 +120,34 @@ def description(self) -> str: # Edit VAE encoder +# auto_docstring class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + + image_resize_processor (`VaeImageProcessor`) + + image_processor (`VaeImageProcessor`) + + vae (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image(s). Single tensor or list depending on input. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditResizeStep(), @@ -95,7 +162,45 @@ def description(self) -> str: # Edit Inpaint VAE encoder +# auto_docstring class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks): + """ + This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It: + - resize the image for target area (1024 * 1024) while maintaining the aspect ratio. + - process the resized image and mask image. + - create image latents. + + Components: + + image_resize_processor (`VaeImageProcessor`) + + image_mask_processor (`InpaintProcessor`) + + vae (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + mask_image (`Image`): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`None`): + TODO: Add description. + processed_mask_image (`None`): + TODO: Add description. + mask_overlay_kwargs (`Dict`): + The kwargs for the postprocess step to apply the mask overlay + image_latents (`Tensor`): + The latents representing the reference image(s). Single tensor or list depending on input. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditResizeStep(), @@ -137,7 +242,46 @@ def description(self): # assemble input steps +# auto_docstring class QwenImageEditInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageTextInputsStep(), @@ -154,7 +298,48 @@ def description(self): ) +# auto_docstring class QwenImageEditInpaintInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the edit inpaint denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + processed_mask_image (`None`, *optional*): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageTextInputsStep(), @@ -174,7 +359,43 @@ def description(self): # assemble prepare latents steps +# auto_docstring class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks): + """ + This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It: + - Add noise to the image latents to create the latents input for the denoiser. + - Create the patchified latents `mask` based on the processed mask image. + + Components: + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + pachifier (`QwenImagePachifier`) + + Inputs: + latents (`Tensor`): + The initial random noised, can be generated in prepare latent step. + image_latents (`Tensor`): + The image latents to use for the denoising process. Can be generated in vae encoder and packed in input + step. + timesteps (`Tensor`): + The timesteps to use for the denoising process. Can be generated in set_timesteps step. + processed_mask_image (`Tensor`): + The processed mask to use for the inpainting process. + height (`None`): + TODO: Add description. + width (`None`): + TODO: Add description. + dtype (`None`): + TODO: Add description. + + Outputs: + initial_noise (`Tensor`): + The initial random noised used for inpainting denoising. + mask (`Tensor`): + The mask to use for the inpainting process. + """ + model_name = "qwenimage-edit" block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()] block_names = ["add_noise_to_latents", "create_mask_latents"] @@ -189,7 +410,56 @@ def description(self) -> str: # Qwen Image Edit (image2image) core denoise step +# auto_docstring class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit (img2img) task. + + Components: + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + guider (`ClassifierFreeGuidance`) + + transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditInputStep(), @@ -212,9 +482,68 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Core denoising workflow for QwenImage-Edit edit (img2img) task." + @property + def outputs(self): + return [ + OutputParam.latents(), + ] + # Qwen Image Edit (inpainting) core denoise step +# auto_docstring class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit edit inpaint task. + + Components: + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + guider (`ClassifierFreeGuidance`) + + transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + processed_mask_image (`None`, *optional*): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit" block_classes = [ QwenImageEditInpaintInputStep(), @@ -239,6 +568,12 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks): def description(self): return "Core denoising workflow for QwenImage-Edit edit inpaint task." + @property + def outputs(self): + return [ + OutputParam.latents(), + ] + # Auto core denoise step for QwenImage Edit class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks): @@ -267,6 +602,12 @@ def description(self): "Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit." ) + @property + def outputs(self): + return [ + OutputParam.latents(), + ] + # ==================== # 4. DECODE @@ -274,7 +615,28 @@ def description(self): # Decode step (standard) +# auto_docstring class QwenImageEditDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image. + + Components: + + vae (`AutoencoderKLQwenImage`) + + image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The latents to decode, can be generated in the denoise step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt''. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-edit" block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -285,7 +647,31 @@ def description(self): # Inpaint decode step +# auto_docstring class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask + overlay to the original image. + + Components: + + vae (`AutoencoderKLQwenImage`) + + image_mask_processor (`InpaintProcessor`) + + Inputs: + latents (`Tensor`): + The latents to decode, can be generated in the denoise step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt''. + mask_overlay_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-edit" block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -313,9 +699,7 @@ def description(self): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.latents(), ] @@ -333,7 +717,91 @@ def outputs(self): ) +# auto_docstring class QwenImageEditAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit. + - for edit (img2img) generation, you need to provide `image` + - for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide + `padding_mask_crop` + + Components: + + image_resize_processor (`VaeImageProcessor`) + + text_encoder (`Qwen2_5_VLForConditionalGeneration`) + + processor (`Qwen2VLProcessor`) + + guider (`ClassifierFreeGuidance`) + + image_mask_processor (`InpaintProcessor`) + + vae (`AutoencoderKLQwenImage`) + + image_processor (`VaeImageProcessor`) + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + transformer (`QwenImageTransformer2DModel`) + + Configs: + + prompt_template_encode (default: <|im_start|>system + Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how + the user's text instruction should alter or modify the image. Generate a new image that meets the user's + requirements while maintaining consistency with the original input where appropriate.<|im_end|> <|im_start|>user + <|vision_start|><|image_pad|><|vision_end|>{}<|im_end|> <|im_start|>assistant ) + + prompt_template_encode_start_idx (default: 64) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + mask_image (`Image`, *optional*): + Mask image for inpainting. + padding_mask_crop (`int`, *optional*): + Padding for mask cropping in inpainting. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + image_latents (`None`): + TODO: Add description. + processed_mask_image (`None`, *optional*): + TODO: Add description. + latents (`Tensor`): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + strength (`float`, *optional*, defaults to 0.9): + Strength for img2img/inpainting. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt''. + mask_overlay_kwargs (`None`, *optional*): + TODO: Add description. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-edit" block_classes = EDIT_AUTO_BLOCKS.values() block_names = EDIT_AUTO_BLOCKS.keys() @@ -349,5 +817,5 @@ def description(self): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"), + OutputParam.images(), ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py index 99c5b109bf38..8dab6fbcf95d 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_edit_plus.py @@ -12,10 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - -import PIL.Image -import torch from ...utils import logging from ..modular_pipeline import SequentialPipelineBlocks @@ -53,8 +49,53 @@ # ==================== +# auto_docstring class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks): - """VL encoder that takes both image and text prompts. Uses 384x384 target area.""" + """ + QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together. + + Components: + + image_resize_processor (`VaeImageProcessor`) + + text_encoder (`Qwen2_5_VLForConditionalGeneration`) + + processor (`Qwen2VLProcessor`) + + guider (`ClassifierFreeGuidance`) + + Configs: + + prompt_template_encode (default: <|im_start|>system + Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how + the user's text instruction should alter or modify the image. Generate a new image that meets the user's + requirements while maintaining consistency with the original input where appropriate.<|im_end|> <|im_start|>user + {}<|im_end|> <|im_start|>assistant ) + + img_template_encode (default: Picture {}: <|vision_start|><|image_pad|><|vision_end|>) + + prompt_template_encode_start_idx (default: 64) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + + Outputs: + resized_cond_image (`List`): + The resized images + prompt_embeds (`Tensor`): + The prompt embeddings + prompt_embeds_mask (`Tensor`): + The encoder attention mask + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask + """ model_name = "qwenimage-edit-plus" block_classes = [ @@ -73,8 +114,34 @@ def description(self) -> str: # ==================== +# auto_docstring class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): - """VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area.""" + """ + VAE encoder step that encodes image inputs into latent representations. + Each image is resized independently based on its own aspect ratio to 1024x1024 target area. + + Components: + + image_resize_processor (`VaeImageProcessor`) + + image_processor (`VaeImageProcessor`) + + vae (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image(s). Single tensor or list depending on input. + """ model_name = "qwenimage-edit-plus" block_classes = [ @@ -98,7 +165,48 @@ def description(self) -> str: # assemble input steps +# auto_docstring class QwenImageEditPlusInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the Edit Plus denoising step. It: + - Standardizes text embeddings batch size. + - Processes list of image latents: patchifies, concatenates along dim=1, expands batch. + - Outputs lists of image_height/image_width for RoPE calculation. + - Defaults height/width from last image in the list. + + Components: + + pachifier (`QwenImagePachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + image_height (`List`): + The image heights calculated from the image latents dimension + image_width (`List`): + The image widths calculated from the image latents dimension + """ + model_name = "qwenimage-edit-plus" block_classes = [ QwenImageTextInputsStep(), @@ -118,7 +226,56 @@ def description(self): # Qwen Image Edit Plus (image2image) core denoise step +# auto_docstring class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Edit Plus edit (img2img) task. + + Components: + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + guider (`ClassifierFreeGuidance`) + + transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-edit-plus" block_classes = [ QwenImageEditPlusInputStep(), @@ -144,9 +301,7 @@ def description(self): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.latents(), ] @@ -155,7 +310,28 @@ def outputs(self): # ==================== +# auto_docstring class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks): + """ + Decode step that decodes the latents to images and postprocesses the generated image. + + Components: + + vae (`AutoencoderKLQwenImage`) + + image_processor (`VaeImageProcessor`) + + Inputs: + latents (`Tensor`): + The latents to decode, can be generated in the denoise step + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt''. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-edit-plus" block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()] block_names = ["decode", "postprocess"] @@ -179,7 +355,79 @@ def description(self): ) +# auto_docstring class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus. + - `image` is required input (can be single image or list of images). + - Each image is resized independently based on its own aspect ratio. + - VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area. + + Components: + + image_resize_processor (`VaeImageProcessor`) + + text_encoder (`Qwen2_5_VLForConditionalGeneration`) + + processor (`Qwen2VLProcessor`) + + guider (`ClassifierFreeGuidance`) + + image_processor (`VaeImageProcessor`) + + vae (`AutoencoderKLQwenImage`) + + pachifier (`QwenImagePachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + transformer (`QwenImageTransformer2DModel`) + + Configs: + + prompt_template_encode (default: <|im_start|>system + Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how + the user's text instruction should alter or modify the image. Generate a new image that meets the user's + requirements while maintaining consistency with the original input where appropriate.<|im_end|> <|im_start|>user + {}<|im_end|> <|im_start|>assistant ) + + img_template_encode (default: Picture {}: <|vision_start|><|image_pad|><|vision_end|>) + + prompt_template_encode_start_idx (default: 64) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt''. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-edit-plus" block_classes = EDIT_PLUS_AUTO_BLOCKS.values() block_names = EDIT_PLUS_AUTO_BLOCKS.keys() @@ -196,5 +444,5 @@ def description(self): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"), + OutputParam.images(), ] diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py index 63ee36df5112..544b1abfc3ed 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks_qwenimage_layered.py @@ -13,11 +13,6 @@ # limitations under the License. -from typing import List - -import PIL.Image -import torch - from ...utils import logging from ..modular_pipeline import SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict, OutputParam @@ -55,8 +50,89 @@ # ==================== +# auto_docstring class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks): - """Text encoder that takes text prompt, will generate a prompt based on image if not provided.""" + """ + QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not + provided. + + Components: + + image_resize_processor (`VaeImageProcessor`) + + text_encoder (`Qwen2_5_VLForConditionalGeneration`) + + processor (`Qwen2VLProcessor`) + + tokenizer (`Qwen2Tokenizer`): The tokenizer to use + + guider (`ClassifierFreeGuidance`) + + Configs: + + image_caption_prompt_en (default: <|im_start|>system + You are a helpful assistant.<|im_end|> <|im_start|>user # Image Annotator You are a professional image annotator. + Please write an image caption based on the input image: + 1. Write the caption using natural, descriptive language without structured formats or rich text. + 2. Enrich caption details by including: + - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on + - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, + attachment relations, action relations, comparative relations, causal relations, and so on + - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on + - Identify the text clearly visible in the image, without translation or explanation, and highlight it in the + caption with quotation marks + 3. Maintain authenticity and accuracy: + - Avoid generalizations + - Describe all visible information in the image, while do not add information not explicitly shown in the image + <|vision_start|><|image_pad|><|vision_end|><|im_end|> <|im_start|>assistant ) + + image_caption_prompt_cn (default: <|im_start|>system + You are a helpful assistant.<|im_end|> <|im_start|>user # 图像标注器 你是一个专业的图像标注器。请基于输入图像,撰写图注: + 1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。 + 2. 通过加入以下内容,丰富图注细节: + - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等 + - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等 + - 环境细节:例如天气、光照、颜色、纹理、气氛等 + - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调 + 3. 保持真实性与准确性: + - 不要使用笼统的描述 + - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容 + <|vision_start|><|image_pad|><|vision_end|><|im_end|> <|im_start|>assistant ) + + prompt_template_encode (default: <|im_start|>system + Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the + objects and background:<|im_end|> <|im_start|>user {}<|im_end|> <|im_start|>assistant ) + + prompt_template_encode_start_idx (default: 34) + + tokenizer_max_length (default: 1024) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt to encode + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + + Outputs: + resized_image (`List`): + The resized images + prompt_embeds (`Tensor`): + The prompt embeddings + prompt_embeds_mask (`Tensor`): + The encoder attention mask + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask + """ model_name = "qwenimage-layered" block_classes = [ @@ -77,7 +153,36 @@ def description(self) -> str: # Edit VAE encoder +# auto_docstring class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks): + """ + Vae encoder step that encode the image inputs into their latent representations. + + Components: + + image_resize_processor (`VaeImageProcessor`) + + image_processor (`VaeImageProcessor`) + + vae (`AutoencoderKLQwenImage`) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + + Outputs: + resized_image (`List`): + The resized images + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image(s). Single tensor or list depending on input. + """ + model_name = "qwenimage-layered" block_classes = [ QwenImageLayeredResizeStep(), @@ -98,7 +203,46 @@ def description(self) -> str: # assemble input steps +# auto_docstring class QwenImageLayeredInputStep(SequentialPipelineBlocks): + """ + Input step that prepares the inputs for the layered denoising step. It: + - make sure the text embeddings have consistent batch size as well as the additional inputs. + - update height/width based `image_latents`, patchify `image_latents`. + + Components: + + pachifier (`QwenImageLayeredPachifier`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + + Outputs: + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt + dtype (`dtype`): + Data type of model tensor inputs (determined by `prompt_embeds`) + image_height (`int`): + The image height calculated from the image latents dimension + image_width (`int`): + The image width calculated from the image latents dimension + height (`int`): + The height of the image output + width (`int`): + The width of the image output + """ + model_name = "qwenimage-layered" block_classes = [ QwenImageTextInputsStep(), @@ -116,7 +260,54 @@ def description(self): # Qwen Image Layered (image2image) core denoise step +# auto_docstring class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks): + """ + Core denoising workflow for QwenImage-Layered img2img task. + + Components: + + pachifier (`QwenImageLayeredPachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + guider (`ClassifierFreeGuidance`) + + transformer (`QwenImageTransformer2DModel`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`None`): + TODO: Add description. + prompt_embeds_mask (`None`): + TODO: Add description. + negative_prompt_embeds (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`None`, *optional*): + TODO: Add description. + image_latents (`None`, *optional*): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "qwenimage-layered" block_classes = [ QwenImageLayeredInputStep(), @@ -142,9 +333,7 @@ def description(self): @property def outputs(self): return [ - OutputParam( - name="latents", type_hint=torch.Tensor, description="The latents generated by the denoising step" - ), + OutputParam.latents(), ] @@ -162,7 +351,109 @@ def outputs(self): ) +# auto_docstring class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks): + """ + Auto Modular pipeline for layered denoising tasks using QwenImage-Layered. + + Components: + + image_resize_processor (`VaeImageProcessor`) + + text_encoder (`Qwen2_5_VLForConditionalGeneration`) + + processor (`Qwen2VLProcessor`) + + tokenizer (`Qwen2Tokenizer`): The tokenizer to use + + guider (`ClassifierFreeGuidance`) + + image_processor (`VaeImageProcessor`) + + vae (`AutoencoderKLQwenImage`) + + pachifier (`QwenImageLayeredPachifier`) + + scheduler (`FlowMatchEulerDiscreteScheduler`) + + transformer (`QwenImageTransformer2DModel`) + + Configs: + + image_caption_prompt_en (default: <|im_start|>system + You are a helpful assistant.<|im_end|> <|im_start|>user # Image Annotator You are a professional image annotator. + Please write an image caption based on the input image: + 1. Write the caption using natural, descriptive language without structured formats or rich text. + 2. Enrich caption details by including: + - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on + - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, + attachment relations, action relations, comparative relations, causal relations, and so on + - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on + - Identify the text clearly visible in the image, without translation or explanation, and highlight it in the + caption with quotation marks + 3. Maintain authenticity and accuracy: + - Avoid generalizations + - Describe all visible information in the image, while do not add information not explicitly shown in the image + <|vision_start|><|image_pad|><|vision_end|><|im_end|> <|im_start|>assistant ) + + image_caption_prompt_cn (default: <|im_start|>system + You are a helpful assistant.<|im_end|> <|im_start|>user # 图像标注器 你是一个专业的图像标注器。请基于输入图像,撰写图注: + 1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。 + 2. 通过加入以下内容,丰富图注细节: + - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等 + - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等 + - 环境细节:例如天气、光照、颜色、纹理、气氛等 + - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调 + 3. 保持真实性与准确性: + - 不要使用笼统的描述 + - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容 + <|vision_start|><|image_pad|><|vision_end|><|im_end|> <|im_start|>assistant ) + + prompt_template_encode (default: <|im_start|>system + Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the + objects and background:<|im_end|> <|im_start|>user {}<|im_end|> <|im_start|>assistant ) + + prompt_template_encode_start_idx (default: 34) + + tokenizer_max_length (default: 1024) + + Inputs: + image (`Image`): + Input image for img2img, editing, or conditioning. + resolution (`int`, *optional*, defaults to 640): + The target area to resize the image to, can be 1024 or 640 + prompt (`str`, *optional*): + The prompt to encode + use_en_prompt (`bool`, *optional*, defaults to False): + Whether to use English prompt template + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 1024): + Maximum sequence length for prompt encoding. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + layers (`int`, *optional*, defaults to 4): + Number of layers to extract from the image + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`List`, *optional*): + Custom sigmas for the denoising process. + attention_kwargs (`Dict`, *optional*): + Additional kwargs for attention processors. + **denoiser_input_fields (`Tensor`, *optional*): + conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt''. + + Outputs: + images (`List`): + Generated images. + """ + model_name = "qwenimage-layered" block_classes = LAYERED_AUTO_BLOCKS.values() block_names = LAYERED_AUTO_BLOCKS.keys() @@ -174,5 +465,5 @@ def description(self): @property def outputs(self): return [ - OutputParam(name="images", type_hint=List[List[PIL.Image.Image]], description="The generated images"), + OutputParam.images(), ] diff --git a/src/diffusers/modular_pipelines/z_image/denoise.py b/src/diffusers/modular_pipelines/z_image/denoise.py index 3d5a00a9df50..a165fb513f3c 100644 --- a/src/diffusers/modular_pipelines/z_image/denoise.py +++ b/src/diffusers/modular_pipelines/z_image/denoise.py @@ -129,10 +129,7 @@ def inputs(self) -> List[Tuple[str, Any]]: type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), - InputParam( - kwargs_type="denoiser_input_fields", - description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", - ), + InputParam.denoiser_input_fields(), ] guider_input_names = [] uncond_guider_input_names = [] diff --git a/utils/modular_auto_docstring.py b/utils/modular_auto_docstring.py new file mode 100644 index 000000000000..7bb2c87e81da --- /dev/null +++ b/utils/modular_auto_docstring.py @@ -0,0 +1,300 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Auto Docstring Generator for Modular Pipeline Blocks + +This script scans Python files for classes that have `# auto_docstring` comment above them +and inserts/updates the docstring from the class's `doc` property. + +Run from the root of the repo: + python utils/modular_auto_docstring.py [path] [--fix_and_overwrite] + +Examples: + # Check for auto_docstring markers (will error if found without proper docstring) + python utils/modular_auto_docstring.py + + # Check specific directory + python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/ + + # Fix and overwrite the docstrings + python utils/modular_auto_docstring.py --fix_and_overwrite + +Usage in code: + # auto_docstring + class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): + # docstring will be automatically inserted here + + @property + def doc(self): + return "Your docstring content..." +""" + +import argparse +import ast +import glob +import importlib +import os +import re +import sys + + +# All paths are set with the intent you should run this script from the root of the repo +DIFFUSERS_PATH = "src/diffusers" +REPO_PATH = "." + +# Pattern to match the auto_docstring comment +AUTO_DOCSTRING_PATTERN = re.compile(r"^\s*#\s*auto_docstring\s*$") + + +def setup_diffusers_import(): + """Setup import path to use the local diffusers module.""" + src_path = os.path.join(REPO_PATH, "src") + if src_path not in sys.path: + sys.path.insert(0, src_path) + + +def get_module_from_filepath(filepath: str) -> str: + """Convert a filepath to a module name.""" + filepath = os.path.normpath(filepath) + + if filepath.startswith("src" + os.sep): + filepath = filepath[4:] + + if filepath.endswith(".py"): + filepath = filepath[:-3] + + module_name = filepath.replace(os.sep, ".") + return module_name + + +def load_module(filepath: str): + """Load a module from filepath.""" + setup_diffusers_import() + module_name = get_module_from_filepath(filepath) + + try: + module = importlib.import_module(module_name) + return module + except Exception as e: + print(f"Warning: Could not import module {module_name}: {e}") + return None + + +def get_doc_from_class(module, class_name: str) -> str: + """Get the doc property from an instantiated class.""" + if module is None: + return None + + cls = getattr(module, class_name, None) + if cls is None: + return None + + try: + instance = cls() + if hasattr(instance, "doc"): + return instance.doc + except Exception as e: + print(f"Warning: Could not instantiate {class_name}: {e}") + + return None + + +def find_auto_docstring_classes(filepath: str) -> list: + """ + Find all classes in a file that have # auto_docstring comment above them. + + Returns list of (class_name, class_line_number, has_existing_docstring, docstring_end_line) + """ + with open(filepath, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + # Parse AST to find class locations and their docstrings + content = "".join(lines) + try: + tree = ast.parse(content) + except SyntaxError as e: + print(f"Syntax error in {filepath}: {e}") + return [] + + # Build a map of class_name -> (class_line, has_docstring, docstring_end_line) + class_info = {} + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + has_docstring = False + docstring_end_line = node.lineno # default to class line + + if node.body and isinstance(node.body[0], ast.Expr): + first_stmt = node.body[0] + if isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str): + has_docstring = True + docstring_end_line = first_stmt.end_lineno or first_stmt.lineno + + class_info[node.name] = (node.lineno, has_docstring, docstring_end_line) + + # Now scan for # auto_docstring comments + classes_to_update = [] + + for i, line in enumerate(lines): + if AUTO_DOCSTRING_PATTERN.match(line): + # Found the marker, look for class definition on next non-empty, non-comment line + j = i + 1 + while j < len(lines): + next_line = lines[j].strip() + if next_line and not next_line.startswith("#"): + break + j += 1 + + if j < len(lines) and lines[j].strip().startswith("class "): + # Extract class name + match = re.match(r"class\s+(\w+)", lines[j].strip()) + if match: + class_name = match.group(1) + if class_name in class_info: + class_line, has_docstring, docstring_end_line = class_info[class_name] + classes_to_update.append((class_name, class_line, has_docstring, docstring_end_line)) + + return classes_to_update + + +def strip_class_name_line(doc: str, class_name: str) -> str: + """Remove the 'class ClassName' line from the doc if present.""" + lines = doc.strip().split("\n") + if lines and lines[0].strip() == f"class {class_name}": + # Remove the class line and any blank line following it + lines = lines[1:] + while lines and not lines[0].strip(): + lines = lines[1:] + return "\n".join(lines) + + +def format_docstring(doc: str, indent: str = " ") -> str: + """Format a doc string as a properly indented docstring.""" + lines = doc.strip().split("\n") + + if len(lines) == 1: + return f'{indent}"""{lines[0]}"""\n' + else: + result = [f'{indent}"""\n'] + for line in lines: + if line.strip(): + result.append(f"{indent}{line}\n") + else: + result.append("\n") + result.append(f'{indent}"""\n') + return "".join(result) + + +def process_file(filepath: str, overwrite: bool = False) -> list: + """ + Process a file and find/insert docstrings for # auto_docstring marked classes. + + Returns list of classes that need updating. + """ + classes_to_update = find_auto_docstring_classes(filepath) + + if not classes_to_update: + return [] + + if not overwrite: + # Just return the list of classes that need updating + return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update] + + # Load the module to get doc properties + module = load_module(filepath) + + with open(filepath, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + # Process in reverse order to maintain line numbers + updated = False + for class_name, class_line, has_docstring, docstring_end_line in reversed(classes_to_update): + doc = get_doc_from_class(module, class_name) + + if doc is None: + print(f"Warning: Could not get doc for {class_name} in {filepath}") + continue + + # Remove the "class ClassName" line since it's redundant in a docstring + doc = strip_class_name_line(doc, class_name) + + # Format the new docstring with 4-space indent + new_docstring = format_docstring(doc, " ") + + if has_docstring: + # Replace existing docstring (line after class definition to docstring_end_line) + # class_line is 1-indexed, we want to replace from class_line+1 to docstring_end_line + lines = lines[:class_line] + [new_docstring] + lines[docstring_end_line:] + else: + # Insert new docstring right after class definition line + # class_line is 1-indexed, so lines[class_line-1] is the class line + # Insert at position class_line (which is right after the class line) + lines = lines[:class_line] + [new_docstring] + lines[class_line:] + + updated = True + print(f"Updated docstring for {class_name} in {filepath}") + + if updated: + with open(filepath, "w", encoding="utf-8", newline="\n") as f: + f.writelines(lines) + + return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update] + + +def check_auto_docstrings(path: str = None, overwrite: bool = False): + """ + Check all files for # auto_docstring markers and optionally fix them. + """ + if path is None: + path = DIFFUSERS_PATH + + if os.path.isfile(path): + all_files = [path] + else: + all_files = glob.glob(os.path.join(path, "**/*.py"), recursive=True) + + all_markers = [] + + for filepath in all_files: + markers = process_file(filepath, overwrite) + all_markers.extend(markers) + + if not overwrite and len(all_markers) > 0: + message = "\n".join([f"- {f}: {cls} at line {line}" for f, cls, line in all_markers]) + raise ValueError( + f"Found the following # auto_docstring markers that need docstrings:\n{message}\n\n" + f"Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them." + ) + + if overwrite and len(all_markers) > 0: + print(f"\nUpdated {len(all_markers)} docstring(s).") + elif len(all_markers) == 0: + print("No # auto_docstring markers found.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Check and fix # auto_docstring markers in modular pipeline blocks", + ) + parser.add_argument("path", nargs="?", default=None, help="File or directory to process (default: src/diffusers)") + parser.add_argument( + "--fix_and_overwrite", + action="store_true", + help="Whether to fix the docstrings by inserting them from doc property.", + ) + + args = parser.parse_args() + + check_auto_docstrings(args.path, args.fix_and_overwrite)