Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 199 additions & 0 deletions src/diffusers/modular_pipelines/modular_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Literal, Optional, Type, Union

import PIL.Image
import torch

from ..configuration_utils import ConfigMixin, FrozenDict
Expand Down Expand Up @@ -342,6 +343,185 @@ class InputParam:
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"

@classmethod
def template(cls, name: str) -> Optional["InputParam"]:
"""Get template for name if exists, otherwise None."""
if hasattr(cls, name) and callable(getattr(cls, name)):
return getattr(cls, name)()
return None

# ======================================================
# InputParam templates
# ======================================================

@classmethod
def prompt(cls) -> "InputParam":
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

our pipeline parameter are pretty consistent across different pipelines, e.g. you always have prompt, height, width, num_inference_steps, etc. I made template for these common ones, so that it is easier to define

before you need

InputParam(
    name="prompt",
    type_hint=str,
    required=True,
    description="The prompt or prompts to guide image generation."
)

now you do

InputParam.prompt()
InputParam.height(default=1024)
InputParam.num_inference_steps(default=28)
InputParam.generator()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit apprehensive about introducing dedicated class methods for common parameters in this way. I think the class can become quite large as common inputs expand.

I would prefer to keep current syntax (IMO this ensures InputParams are defined in a consistent way) and use post init on the dataclass to automatically add a description. e.g

# centralised descriptions would live somewhere like constants.py 
# can be used for modular + non-modular
INPUT_PARAM_TEMPLATES = {
      "prompt": {"type_hint": str, "required": True, "description": "The prompt or prompts to guide image generation."},
      "height": {"type_hint": int, "description": "The height in pixels of the generated image."},
      "width": {"type_hint": int, "description": "The width in pixels of the generated image."},
      "generator": {"type_hint": torch.Generator, "description": "Torch generator for deterministic generation."},
      # ...
  }

  @dataclass
  class InputParam:
      name: str = None
      type_hint: Any = None
      required: bool = False
      default: Any = None
      description: str = None

  def __post_init__(self):
      if not self.name or self.name not in INPUT_PARAM_TEMPLATES:
          return

      template = INPUT_PARAM_TEMPLATES[self.name]
      if self.type_hint is None:
          self.type_hint = template.get("type_hint")
      if self.description is None:
          self.description = template.get("description")

If we feel that methods for these inputs are necessary, one way to address it without adding individual methods to the InputParam is to use a metaclass. It would result in the InputParam object being less crowded.

  class InputParamMeta(type):
      def __getattr__(cls, name: str):
          if name in INPUT_PARAM_TEMPLATES:
              def factory(**overrides):
                  return cls(name=name, **overrides)
              return factory
          raise AttributeError(f"No template named '{name}'")


  @dataclass
  class InputParam(metaclass=InputParamMeta):

return cls(
name="prompt", type_hint=str, required=True, description="The prompt or prompts to guide image generation."
)

@classmethod
def negative_prompt(cls) -> "InputParam":
return cls(
name="negative_prompt",
type_hint=str,
default=None,
description="The prompt or prompts not to guide the image generation.",
)

@classmethod
def max_sequence_length(cls, default: int = 512) -> "InputParam":
return cls(
name="max_sequence_length",
type_hint=int,
default=default,
description="Maximum sequence length for prompt encoding.",
)

@classmethod
def height(cls, default: Optional[int] = None) -> "InputParam":
return cls(
name="height", type_hint=int, default=default, description="The height in pixels of the generated image."
)

@classmethod
def width(cls, default: Optional[int] = None) -> "InputParam":
return cls(
name="width", type_hint=int, default=default, description="The width in pixels of the generated image."
)

@classmethod
def num_inference_steps(cls, default: int = 50) -> "InputParam":
return cls(
name="num_inference_steps", type_hint=int, default=default, description="The number of denoising steps."
)

@classmethod
def num_images_per_prompt(cls, default: int = 1) -> "InputParam":
return cls(
name="num_images_per_prompt",
type_hint=int,
default=default,
description="The number of images to generate per prompt.",
)

@classmethod
def generator(cls) -> "InputParam":
return cls(
name="generator",
type_hint=torch.Generator,
default=None,
description="Torch generator for deterministic generation.",
)

@classmethod
def sigmas(cls) -> "InputParam":
return cls(
name="sigmas", type_hint=List[float], default=None, description="Custom sigmas for the denoising process."
)

@classmethod
def strength(cls, default: float = 0.9) -> "InputParam":
return cls(name="strength", type_hint=float, default=default, description="Strength for img2img/inpainting.")

# images
@classmethod
def image(cls) -> "InputParam":
return cls(
name="image",
type_hint=PIL.Image.Image,
required=True,
description="Input image for img2img, editing, or conditioning.",
)

@classmethod
def mask_image(cls) -> "InputParam":
return cls(
name="mask_image", type_hint=PIL.Image.Image, required=True, description="Mask image for inpainting."
)

@classmethod
def control_image(cls) -> "InputParam":
return cls(
name="control_image",
type_hint=PIL.Image.Image,
required=True,
description="Control image for ControlNet conditioning.",
)

@classmethod
def padding_mask_crop(cls) -> "InputParam":
return cls(
name="padding_mask_crop",
type_hint=int,
default=None,
description="Padding for mask cropping in inpainting.",
)

@classmethod
def latents(cls) -> "InputParam":
return cls(
name="latents",
type_hint=torch.Tensor,
default=None,
description="Pre-generated noisy latents for image generation.",
)

@classmethod
def timesteps(cls) -> "InputParam":
return cls(
name="timesteps", type_hint=torch.Tensor, default=None, description="Timesteps for the denoising process."
)

@classmethod
def output_type(cls) -> "InputParam":
return cls(name="output_type", type_hint=str, default="pil", description="Output format: 'pil', 'np', 'pt''.")

@classmethod
def attention_kwargs(cls) -> "InputParam":
return cls(
name="attention_kwargs",
type_hint=Dict[str, Any],
default=None,
description="Additional kwargs for attention processors.",
)

@classmethod
def denoiser_input_fields(cls) -> "InputParam":
return cls(
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
)

# ControlNet
@classmethod
def control_guidance_start(cls, default: float = 0.0) -> "InputParam":
return cls(
name="control_guidance_start",
type_hint=float,
default=default,
description="When to start applying ControlNet.",
)

@classmethod
def control_guidance_end(cls, default: float = 1.0) -> "InputParam":
return cls(
name="control_guidance_end",
type_hint=float,
default=default,
description="When to stop applying ControlNet.",
)

@classmethod
def controlnet_conditioning_scale(cls, default: float = 1.0) -> "InputParam":
return cls(
name="controlnet_conditioning_scale",
type_hint=float,
default=default,
description="Scale for ControlNet conditioning.",
)


@dataclass
class OutputParam:
Expand All @@ -357,6 +537,25 @@ def __repr__(self):
f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
)

@classmethod
def template(cls, name: str) -> Optional["OutputParam"]:
"""Get template for name if exists, otherwise None."""
if hasattr(cls, name) and callable(getattr(cls, name)):
return getattr(cls, name)()
return None

# ======================================================
# OutputParam templates
# ======================================================

@classmethod
def images(cls) -> "OutputParam":
return cls(name="images", type_hint=List[PIL.Image.Image], description="Generated images.")

@classmethod
def latents(cls) -> "OutputParam":
return cls(name="latents", type_hint=torch.Tensor, description="Denoised latents.")


def format_inputs_short(inputs):
"""
Expand Down
46 changes: 24 additions & 22 deletions src/diffusers/modular_pipelines/qwenimage/before_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ def expected_components(self) -> List[ComponentSpec]:
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents"),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="generator"),
InputParam.latents(),
InputParam.height(),
InputParam.width(),
InputParam.num_images_per_prompt(),
InputParam.generator(),
InputParam(
name="batch_size",
required=True,
Expand Down Expand Up @@ -225,12 +225,14 @@ def expected_components(self) -> List[ComponentSpec]:
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents"),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="layers", default=4),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="generator"),
InputParam.latents(),
InputParam.height(),
InputParam.width(),
InputParam(
name="layers", type_hint=int, default=4, description="Number of layers to extract from the image"
),
InputParam.num_images_per_prompt(),
InputParam.generator(),
InputParam(
name="batch_size",
required=True,
Expand Down Expand Up @@ -466,8 +468,8 @@ def expected_components(self) -> List[ComponentSpec]:
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_inference_steps", default=50),
InputParam(name="sigmas"),
InputParam.num_inference_steps(),
InputParam.sigmas(),
InputParam(
name="latents",
required=True,
Expand Down Expand Up @@ -532,8 +534,8 @@ def expected_components(self) -> List[ComponentSpec]:
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50, type_hint=int),
InputParam("sigmas", type_hint=List[float]),
InputParam.num_inference_steps(),
InputParam.sigmas(),
InputParam("image_latents", required=True, type_hint=torch.Tensor),
]

Expand Down Expand Up @@ -590,15 +592,15 @@ def expected_components(self) -> List[ComponentSpec]:
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="num_inference_steps", default=50),
InputParam(name="sigmas"),
InputParam.num_inference_steps(),
InputParam.sigmas(),
InputParam(
name="latents",
required=True,
type_hint=torch.Tensor,
description="The latents to use for the denoising process, used to calculate the image sequence length.",
),
InputParam(name="strength", default=0.9),
InputParam.strength(0.9),
]

@property
Expand Down Expand Up @@ -886,7 +888,7 @@ def description(self) -> str:
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="layers", required=True),
InputParam(name="layers", default=4, description="Number of layers to extract from the image"),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
Expand Down Expand Up @@ -971,9 +973,9 @@ def description(self) -> str:
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("control_guidance_start", default=0.0),
InputParam("control_guidance_end", default=1.0),
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam.control_guidance_start(),
InputParam.control_guidance_end(),
InputParam.controlnet_conditioning_scale(),
InputParam("control_image_latents", required=True),
InputParam(
"timesteps",
Expand Down
Loading
Loading