diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ed5f01a0250d..47cf6f09f77d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -353,6 +353,8 @@ title: Flux2Transformer2DModel - local: api/models/flux_transformer title: FluxTransformer2DModel + - local: api/models/glm_image_transformer2d + title: GlmImageTransformer2DModel - local: api/models/hidream_image_transformer title: HiDreamImageTransformer2DModel - local: api/models/hunyuan_transformer2d @@ -547,6 +549,8 @@ title: Flux2 - local: api/pipelines/control_flux_inpaint title: FluxControlInpaint + - local: api/pipelines/glm_image + title: GLM-Image - local: api/pipelines/hidream title: HiDream-I1 - local: api/pipelines/hunyuandit diff --git a/docs/source/en/api/models/glm_image_transformer2d.md b/docs/source/en/api/models/glm_image_transformer2d.md new file mode 100644 index 000000000000..8a8b07456046 --- /dev/null +++ b/docs/source/en/api/models/glm_image_transformer2d.md @@ -0,0 +1,18 @@ + + +# GlmImageTransformer2DModel + +A Diffusion Transformer model for 2D data from [GLM-Image](https://huggingface.co/zai-org/GLM-Image). + +## GlmImageTransformer2DModel + +[[autodoc]] GlmImageTransformer2DModel diff --git a/docs/source/en/api/pipelines/glm_image.md b/docs/source/en/api/pipelines/glm_image.md new file mode 100644 index 000000000000..98e0145b6807 --- /dev/null +++ b/docs/source/en/api/pipelines/glm_image.md @@ -0,0 +1,95 @@ + + +# GLM-Image + +## Overview + +GLM-Image is an image generation model that adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained detail. In general image generation quality it is on par with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios. + +Model architecture: a hybrid autoregressive + diffusion decoder design. + ++ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands it to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs. The AR model is implemented by the `GlmImageForConditionalGeneration` class in the Transformers library (see the component sketch below). ++ Diffusion decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images. + +Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality. + ++ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness. ++ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering. + +GLM-Image supports both text-to-image and image-to-image generation within a single model: + ++ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios. ++ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects.
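+
+The two stages are exposed as separate components on the pipeline object. The snippet below is only an illustrative sketch for inspecting them; the component names follow the `register_modules` call in `GlmImagePipeline`:
+
+```python
+import torch
+from diffusers import GlmImagePipeline
+
+pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
+
+# 9B autoregressive generator (a Transformers model that turns the prompt into prior image tokens)
+print(type(pipe.vision_language_encoder).__name__)  # GlmImageForConditionalGeneration
+# 7B single-stream DiT decoder that denoises the image latents
+print(type(pipe.transformer).__name__)  # GlmImageTransformer2DModel
+# Glyph text encoder used for accurate text rendering inside images
+print(type(pipe.text_encoder).__name__)  # T5EncoderModel
+```
+
+During generation, the AR component first produces prior tokens (see `GlmImagePipeline.generate_prior_tokens`), which the DiT then consumes through its `prior_token_id` input while denoising the latents.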
+ +This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image). + +## Usage examples + +### Text to Image Generation + +```python +import torch +from diffusers import GlmImagePipeline + +pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16, device_map="cuda") +prompt = "A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy." +image = pipe( + prompt=prompt, + height=32 * 32, + width=36 * 32, + num_inference_steps=30, + guidance_scale=1.5, + generator=torch.Generator(device="cuda").manual_seed(42), +).images[0] + +image.save("output_t2i.png") +``` + +### Image to Image Generation + +```python +import torch +from diffusers import GlmImagePipeline +from PIL import Image + +pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16, device_map="cuda") +image_path = "cond.jpg" +prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator." +image = Image.open(image_path).convert("RGB") +image = pipe( + prompt=prompt, + image=[image], # multiple condition images can be passed for multi-image-to-image generation, e.g. [image, image1] + height=33 * 32, + width=32 * 32, + num_inference_steps=30, + guidance_scale=1.5, + generator=torch.Generator(device="cuda").manual_seed(42), +).images[0] + +image.save("output_i2i.png") +``` + ++ Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. We do not recommend setting `do_sample=False`, as this may lead to incorrect or degenerate outputs from the AR model.
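+
+### Reducing memory usage
+
+The full pipeline loads the 9B AR generator, the 7B DiT decoder, the glyph text encoder, and the VAE, which may not fit on a single consumer GPU. The sketch below is a suggestion rather than part of the original model card: it relies on the standard `DiffusionPipeline.enable_model_cpu_offload()` API, which this pipeline supports through its `model_cpu_offload_seq`:
+
+```python
+import torch
+from diffusers import GlmImagePipeline
+
+pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
+# Keep every sub-model on CPU and move it to the GPU only while it is running.
+pipe.enable_model_cpu_offload()
+
+image = pipe(
+    prompt="A photo of an astronaut riding a horse on mars",
+    height=32 * 32,
+    width=32 * 32,
+    num_inference_steps=30,
+    guidance_scale=1.5,
+).images[0]
+image.save("output_offload.png")
+```
+
+This trades throughput for memory; for maximum speed, load the pipeline directly on the GPU as in the examples above.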
+ +## GlmImagePipeline + +[[autodoc]] GlmImagePipeline + - all + - __call__ + +## GlmImagePipelineOutput + +[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c749bad4be47..2f96348a4af0 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -225,6 +225,7 @@ "FluxControlNetModel", "FluxMultiControlNetModel", "FluxTransformer2DModel", + "GlmImageTransformer2DModel", "HiDreamImageTransformer2DModel", "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", @@ -490,6 +491,7 @@ "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", + "GlmImagePipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -977,6 +979,7 @@ FluxControlNetModel, FluxMultiControlNetModel, FluxTransformer2DModel, + GlmImageTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, @@ -1212,6 +1215,7 @@ FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, + GlmImagePipeline, HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 81730b7516be..4d1db36a7352 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -98,6 +98,7 @@ _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] + _import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] @@ -208,6 +209,7 @@ EasyAnimateTransformer3DModel, Flux2Transformer2DModel, FluxTransformer2DModel, + GlmImageTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index f0c65202d080..d9d1b27a1e40 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -27,6 +27,7 @@ from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_flux import FluxTransformer2DModel from .transformer_flux2 import Flux2Transformer2DModel + from .transformer_glm_image import GlmImageTransformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py new file mode 100644 index 000000000000..f4b5b819ede5 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -0,0 +1,614 @@ +# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..attention_processor import Attention +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import LayerNorm, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class GlmImageCombinedTimestepSizeEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): + super().__init__() + + self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim) + self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward( + self, + timestep: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + hidden_dtype: torch.dtype, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + + crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1) + target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1) + + # (B, 2 * condition_dim) + condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1) + + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim) + + conditioning = timesteps_emb + condition_emb + conditioning = F.silu(conditioning) + + return conditioning + + +class GlmImageImageProjector(nn.Module): + def __init__( + self, + in_channels: int = 16, + hidden_size: int = 2560, + patch_size: int = 2, + ): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, channel, height, width = hidden_states.shape + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + + hidden_states = hidden_states.reshape( + batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size + ) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + hidden_states = self.proj(hidden_states) 
+ + return hidden_states + + +class GlmImageAdaLayerNormZero(nn.Module): + def __init__(self, embedding_dim: int, dim: int) -> None: + super().__init__() + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = hidden_states.dtype + norm_hidden_states = self.norm(hidden_states).to(dtype=dtype) + norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype) + + emb = self.linear(temb) + ( + shift_msa, + c_shift_msa, + scale_msa, + c_scale_msa, + gate_msa, + c_gate_msa, + shift_mlp, + c_shift_mlp, + scale_mlp, + c_scale_mlp, + gate_mlp, + c_gate_mlp, + ) = emb.chunk(12, dim=1) + + hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1) + + return ( + hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) + + +class GlmImageLayerKVCache: + """KV cache for GlmImage model.""" + + def __init__(self): + self.k_cache = None + self.v_cache = None + self.mode: Optional[str] = None # "write", "read", "skip" + + def store(self, k: torch.Tensor, v: torch.Tensor): + if self.k_cache is None: + self.k_cache = k + self.v_cache = v + else: + self.k_cache = torch.cat([self.k_cache, k], dim=2) + self.v_cache = torch.cat([self.v_cache, v], dim=2) + + def get(self): + return self.k_cache, self.v_cache + + def clear(self): + self.k_cache = None + self.v_cache = None + self.mode = None + + +class GlmImageKVCache: + """Container for all layers' KV caches.""" + + def __init__(self, num_layers: int): + self.num_layers = num_layers + self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)] + + def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache: + return self.caches[layer_idx] + + def set_mode(self, mode: Optional[str]): + if mode is not None and mode not in ["write", "read", "skip"]: + raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'") + for cache in self.caches: + cache.mode = mode + + def clear(self): + for cache in self.caches: + cache.clear() + + +class GlmImageAttnProcessor: + """ + Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + + The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size, + text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token. + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + kv_cache: Optional[GlmImageLayerKVCache] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + dtype = encoder_hidden_states.dtype + + batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape + batch_size, image_seq_length, embed_dim = hidden_states.shape + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query).to(dtype=dtype) + if attn.norm_k is not None: + key = attn.norm_k(key).to(dtype=dtype) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + query[:, text_seq_length:, :, :] = apply_rotary_emb( + query[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2 + ) + key[:, text_seq_length:, :, :] = apply_rotary_emb( + key[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2 + ) + + if kv_cache is not None: + if kv_cache.mode == "write": + kv_cache.store(key, value) + elif kv_cache.mode == "read": + k_cache, v_cache = kv_cache.get() + key = torch.cat([k_cache, key], dim=1) if k_cache is not None else key + value = torch.cat([v_cache, value], dim=1) if v_cache is not None else value + elif kv_cache.mode == "skip": + pass + + # 4. Attention + if attention_mask is not None: + text_attn_mask = attention_mask + assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)" + text_attn_mask = text_attn_mask.float().to(query.device) + mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device) + mix_attn_mask[:, :text_seq_length] = text_attn_mask + mix_attn_mask = mix_attn_mask.unsqueeze(2) + attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2) + attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # 5. Output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +@maybe_allow_in_graph +class GlmImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int = 2560, + num_attention_heads: int = 64, + attention_head_dim: int = 40, + time_embed_dim: int = 512, + ) -> None: + super().__init__() + + # 1. 
Attention + self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + out_dim=dim, + bias=True, + qk_norm="layer_norm", + elementwise_affine=False, + eps=1e-5, + processor=GlmImageAttnProcessor(), + ) + + # 2. Feedforward + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] + ] = None, + attention_mask: Optional[Dict[str, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + kv_cache: Optional[GlmImageLayerKVCache] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Timestep conditioning + ( + norm_hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + norm_encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1(hidden_states, encoder_hidden_states, temb) + + # 2. Attention + attention_kwargs = attention_kwargs or {} + + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + kv_cache=kv_cache, + **attention_kwargs, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) + + # 3. 
Feedforward + norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * ( + 1 + c_scale_mlp.unsqueeze(1) + ) + c_shift_mlp.unsqueeze(1) + + ff_output = self.ff(norm_hidden_states) + ff_output_context = self.ff(norm_encoder_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) + + return hidden_states, encoder_hidden_states + + +class GlmImageRotaryPosEmbed(nn.Module): + def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, num_channels, height, width = hidden_states.shape + height, width = height // self.patch_size, width // self.patch_size + + dim_h, dim_w = self.dim // 2, self.dim // 2 + h_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h) + ) + w_inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w) + ) + h_seq = torch.arange(height) + w_seq = torch.arange(width) + freqs_h = torch.outer(h_seq, h_inv_freq) + freqs_w = torch.outer(w_seq, w_inv_freq) + + # Create position matrices for height and width + # [height, 1, dim//4] and [1, width, dim//4] + freqs_h = freqs_h.unsqueeze(1) + freqs_w = freqs_w.unsqueeze(0) + # Broadcast freqs_h and freqs_w to [height, width, dim//4] + freqs_h = freqs_h.expand(height, width, -1) + freqs_w = freqs_w.expand(height, width, -1) + + # Concatenate along last dimension to get [height, width, dim//2] + freqs = torch.cat([freqs_h, freqs_w], dim=-1) + freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim] + freqs = freqs.reshape(height * width, -1) + return (freqs.cos(), freqs.sin()) + + +class GlmImageAdaLayerNormContinuous(nn.Module): + """ + GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the + Linear on conditioning embedding. + """ + + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine: bool = True, + eps: float = 1e-5, + bias: bool = True, + norm_type: str = "layer_norm", + ): + super().__init__() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + # *** NO SiLU here *** + emb = self.linear(conditioning_embedding.to(x.dtype)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): + r""" + Args: + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. 
+ attention_head_dim (`int`, defaults to `40`): + The number of channels in each head. + num_attention_heads (`int`, defaults to `64`): + The number of heads to use for multi-head attention. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_embed_dim (`int`, defaults to `1472`): + Input dimension of text embeddings from the text encoder. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + condition_dim (`int`, defaults to `256`): + The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, + crop_coords). + pos_embed_max_size (`int`, defaults to `128`): + The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added + to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 + means that the maximum supported height and width for image generation is `128 * vae_scale_factor * + patch_size => 128 * 8 * 2 => 2048`. + sample_size (`int`, defaults to `128`): + The base resolution of input latents. If height/width is not provided during generation, this value is used + to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024` + """ + + _supports_gradient_checkpointing = True + _no_split_modules = [ + "GlmImageTransformerBlock", + "GlmImageImageProjector", + "GlmImageImageProjector", + ] + _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"] + _skip_keys = ["kv_caches"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: int = 16, + num_layers: int = 30, + attention_head_dim: int = 40, + num_attention_heads: int = 64, + text_embed_dim: int = 1472, + time_embed_dim: int = 512, + condition_dim: int = 256, + prior_vq_quantizer_codebook_size: int = 16384, + ): + super().__init__() + + # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords + # Each of these are sincos embeddings of shape 2 * condition_dim + pooled_projection_dim = 2 * 2 * condition_dim + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels + + # 1. RoPE + self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) + + # 2. Patch & Text-timestep embedding + self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size) + self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") + self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) + self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") + + self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=pooled_projection_dim, + timesteps_dim=time_embed_dim, + ) + + # 3. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + for _ in range(num_layers) + ] + ) + + # 4. 
Output projection + self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + prior_token_id: torch.Tensor, + prior_token_drop: torch.Tensor, + timestep: torch.LongTensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + attention_mask: Optional[torch.Tensor] = None, + kv_caches: Optional[GlmImageKVCache] = None, + image_rotary_emb: Optional[ + Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] + ] = None, + ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: + batch_size, num_channels, height, width = hidden_states.shape + + # 1. RoPE + if image_rotary_emb is None: + image_rotary_emb = self.rope(hidden_states) + + # 2. Patch & Timestep embeddings + p = self.config.patch_size + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = self.image_projector(hidden_states) + encoder_hidden_states = self.glyph_projector(encoder_hidden_states) + prior_embedding = self.prior_token_embedding(prior_token_id) + prior_embedding[prior_token_drop] *= 0.0 + prior_hidden_states = self.prior_projector(prior_embedding) + + hidden_states = hidden_states + prior_hidden_states + + temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype) + + # 3. Transformer blocks + for idx, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, + kv_caches[idx] if kv_caches is not None else None, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask, + attention_kwargs, + kv_cache=kv_caches[idx] if kv_caches is not None else None, + ) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # 5. 
Unpatchify + hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) + + # Rearrange tensor from (B, H_p, W_p, C, p, p) to (B, C, H_p * p, W_p * p) + output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index c14910250b54..6f583385de7a 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -52,6 +52,7 @@ FluxKontextPipeline, FluxPipeline, ) +from .glm_image import GlmImagePipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -167,6 +168,7 @@ ("chroma", ChromaPipeline), ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), + ("glm_image", GlmImagePipeline), ("cogview4-control", CogView4ControlPipeline), ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), diff --git a/src/diffusers/pipelines/glm_image/__init__.py b/src/diffusers/pipelines/glm_image/__init__.py new file mode 100644 index 000000000000..9df31b0b1734 --- /dev/null +++ b/src/diffusers/pipelines/glm_image/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_glm_image"] = ["GlmImagePipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_glm_image import GlmImagePipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py new file mode 100644 index 000000000000..1c706208d3de --- /dev/null +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -0,0 +1,804 @@ +# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from math import sqrt +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import ByT5Tokenizer, GlmImageForConditionalGeneration, GlmImageProcessor, T5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, GlmImageTransformer2DModel +from ...models.transformers.transformer_glm_image import GlmImageKVCache +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import GlmImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import GlmImagePipeline + + >>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + """ + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + + if timesteps is not None and sigmas is not None: + if not accepts_timesteps and not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep or sigma schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is not None and sigmas is None: + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif timesteps is None and sigmas is not None: + if not accepts_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class GlmImagePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using GLM-Image. + + This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion + transformer) model for image decoding. + + Args: + tokenizer (`PreTrainedTokenizer`): + Tokenizer for the text encoder. + processor (`AutoProcessor`): + Processor for the AR model to handle chat templates and tokenization. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder for glyph embeddings. + vision_language_encoder ([`GlmImageForConditionalGeneration`]): + The AR model that generates image tokens from text prompts. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + transformer ([`GlmImageTransformer2DModel`]): + A text conditioned transformer to denoise the encoded image latents (DiT). + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 
+ """ + + _optional_components = [] + model_cpu_offload_seq = "vision_language_encoder->text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + tokenizer: ByT5Tokenizer, + processor: GlmImageProcessor, + text_encoder: T5EncoderModel, + vision_language_encoder: GlmImageForConditionalGeneration, + vae: AutoencoderKL, + transformer: GlmImageTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + processor=processor, + text_encoder=text_encoder, + vision_language_encoder=vision_language_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") + and self.transformer is not None + and hasattr(self.transformer.config, "sample_size") + else 128 + ) + + @staticmethod + def _build_image_grid_thw( + token_h: int, + token_w: int, + prev_token_h: int, + prev_token_w: int, + existing_grid: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + if existing_grid is None or existing_grid.numel() == 0: + return torch.tensor( + [ + [1, token_h, token_w], + [1, prev_token_h, prev_token_w], + ], + device=device, + ) + else: + return torch.cat([existing_grid.to(device), torch.tensor([[1, token_h, token_w]], device=device)], dim=0) + + @staticmethod + def _calculate_ar_generation_params( + token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool + ) -> Tuple[int, int]: + large_image_tokens = token_h * token_w + small_image_tokens = prev_token_h * prev_token_w + + if is_text_to_image: + max_new_tokens = small_image_tokens + large_image_tokens + 1 + large_image_start_offset = small_image_tokens + else: + max_new_tokens = large_image_tokens + 1 + large_image_start_offset = 0 + + return max_new_tokens, large_image_start_offset + + @staticmethod + def _extract_large_image_tokens( + outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int + ) -> torch.Tensor: + generated_tokens = outputs[0][input_length:] + large_image_start = large_image_start_offset + large_image_end = large_image_start + large_image_tokens + return generated_tokens[large_image_start:large_image_end] + + @staticmethod + def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor: + token_ids = token_ids.view(1, 1, token_h, token_w) + token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to( + dtype=torch.long + ) + token_ids = token_ids.view(1, -1) + return token_ids + + @staticmethod + def _build_prompt_with_shape( + prompt: str, + height: int, + width: int, + is_text_to_image: bool, + factor: int = 32, + ) -> Tuple[str, int, int, int, int]: + token_h = height // factor + token_w = width // factor + ratio = token_h / token_w + prev_token_h = int(sqrt(ratio) * (factor // 2)) + prev_token_w = int(sqrt(1 / ratio) * (factor // 2)) + + if is_text_to_image: + expanded_prompt = f"{prompt}{token_h} {token_w}{prev_token_h} {prev_token_w}" + else: + expanded_prompt = f"{prompt}{token_h} {token_w}" + + return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w + + def generate_prior_tokens( + self, + prompt: 
str, + height: int, + width: int, + image: Optional[List[PIL.Image.Image]] = None, + factor: int = 32, + ) -> Tuple[torch.Tensor, int, int]: + device = self.vision_language_encoder.device + height = (height // factor) * factor + width = (width // factor) * factor + is_text_to_image = image is None or len(image) == 0 + expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape( + prompt, height, width, is_text_to_image + ) + content = [] + if image is not None: + for img in image: + content.append({"type": "image", "image": img}) + content.append({"type": "text", "text": expanded_prompt}) + messages = [{"role": "user", "content": content}] + inputs = self.processor.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to(device) + existing_grid = inputs.get("image_grid_thw") + inputs["image_grid_thw"] = self._build_image_grid_thw( + token_h, + token_w, + prev_h, + prev_w, + existing_grid=existing_grid if not is_text_to_image else None, + device=device, + ) + + max_new_tokens, large_image_offset = self._calculate_ar_generation_params( + token_h, token_w, prev_h, prev_w, is_text_to_image + ) + large_image_tokens = token_h * token_w + input_length = inputs["input_ids"].shape[-1] + + prior_token_image_ids = None + if image is not None: + prior_token_image_embed = self.vision_language_encoder.get_image_features( + inputs["pixel_values"], existing_grid + ) + prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0) + prior_token_image_ids = self.vision_language_encoder.get_image_tokens( + prior_token_image_embed, existing_grid + ) + outputs = self.vision_language_encoder.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=True, + ) + + prior_token_ids_d32 = self._extract_large_image_tokens( + outputs, input_length, large_image_offset, large_image_tokens + ) + prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w) + + return prior_token_ids, prior_token_image_ids + + def get_glyph_texts(self, prompt): + prompt = prompt[0] if isinstance(prompt, list) else prompt + ocr_texts = ( + re.findall(r"'([^']*)'", prompt) + + re.findall(r"“([^“”]*)”", prompt) + + re.findall(r'"([^"]*)"', prompt) + + re.findall(r"「([^「」]*)」", prompt) + ) + return ocr_texts + + def _get_glyph_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 2048, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + glyph_texts = self.get_glyph_texts(prompt) + input_ids = self.tokenizer( + glyph_texts if len(glyph_texts) > 0 else [""], + max_length=max_sequence_length, + truncation=True, + ).input_ids + input_ids = [ + [self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids + ] + max_length = max(len(input_ids_) for input_ids_ in input_ids) + attention_mask = torch.tensor( + [[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device + ) + input_ids = torch.tensor( + [input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids], + device=device, + ) + outputs = self.text_encoder(input_ids, attention_mask=attention_mask) + glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0) + + return glyph_embeds.to(device=device, dtype=dtype) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + 
do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 2048, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `2048`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_embeds = None + if do_classifier_free_guidance: + negative_prompt = "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype) + + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * self.transformer.config.patch_size) != 0 + or width is not None + and width % (self.transformer.config.patch_size) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. 
Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Optional[ + Union[ + torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray] + ] + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.5, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 2048, + ) -> Union[GlmImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. Must contain shape info in the format 'H + W' where H and W are token dimensions (d32). Example: "A beautiful sunset36 24" + generates a 1152x768 image. + image: Optional condition images for image-to-image generation. + height (`int`, *optional*): + The height in pixels. If not provided, derived from prompt shape info. + width (`int`, *optional*): + The width in pixels. If not provided, derived from prompt shape info. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps for DiT. + guidance_scale (`float`, *optional*, defaults to `1.5`): + Guidance scale for classifier-free guidance. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. 
+ generator (`torch.Generator`, *optional*): + Random generator for reproducibility. + output_type (`str`, *optional*, defaults to `"pil"`): + Output format: "pil", "np", or "latent". + + Examples: + + Returns: + [`GlmImagePipelineOutput`] or `tuple`: Generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if batch_size != 1: + raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}") + + device = self._execution_device + + prior_token_id, prior_token_image_ids = self.generate_prior_tokens( + prompt=prompt[0] if isinstance(prompt, list) else prompt, + image=image, + height=height, + width=width, + ) + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + dtype=self.dtype, + ) + + # 4. process images + if image is not None: + preprocessed_condition_images = [] + for img in image: + image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2] + multiple_of = self.vae_scale_factor * self.transformer.config.patch_size + image_height = (image_height // multiple_of) * multiple_of + image_width = (image_width // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width) + preprocessed_condition_images.append(img) + height = height or image_height + width = width or image_width + image = preprocessed_condition_images + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. Prepare latents and (optional) image kv cache + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_channels_latents=latent_channels, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers) + + if image is not None: + kv_caches.set_mode("write") + latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1) + latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1) + + latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype) + latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype) + + for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids): + condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype) + condition_latent = retrieve_latents( + self.vae.encode(condition_image), generator=generator, sample_mode="argmax" + ) + condition_latent = (condition_latent - latents_mean) / latents_std + + # Do not remove. 
+                # It would be used to run the reference image through a
+                # forward pass at timestep 0 and keep the KV cache.
+                _ = self.transformer(
+                    hidden_states=condition_latent,
+                    encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
+                    prior_token_id=condition_image_prior_token_id,
+                    prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
+                    timestep=torch.zeros((1,), device=device),
+                    target_size=torch.tensor([condition_image.shape[-2:]], device=device),
+                    crop_coords=torch.zeros((1, 2), device=device),
+                    attention_kwargs=attention_kwargs,
+                    kv_caches=kv_caches,
+                )
+
+        # 6. Prepare additional timestep conditions
+        target_size = (height, width)
+        target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
+        crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
+
+        target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
+        crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
+
+        # Prepare timesteps
+        image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
+            self.transformer.config.patch_size**2
+        )
+        timesteps = (
+            np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1]
+            if timesteps is None
+            else np.array(timesteps)
+        )
+        timesteps = timesteps.astype(np.int64).astype(np.float32)
+        sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("base_shift", 0.25),
+            self.scheduler.config.get("max_shift", 0.75),
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
+        )
+        self._num_timesteps = len(timesteps)
+
+        # 7.
Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + prior_token_drop_cond = torch.full_like(prior_token_id, False, dtype=torch.bool) + prior_token_drop_uncond = torch.full_like(prior_token_id, True, dtype=torch.bool) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + + timestep = t.expand(latents.shape[0]) - 1 + + if image is not None: + kv_caches.set_mode("read") + + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + prior_token_id=prior_token_id, + prior_token_drop=prior_token_drop_cond, + timestep=timestep, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + kv_caches=kv_caches, + )[0].float() + + # perform guidance + if self.do_classifier_free_guidance: + if image is not None: + kv_caches.set_mode("skip") + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + prior_token_id=prior_token_id, + prior_token_drop=prior_token_drop_uncond, + timestep=timestep, + target_size=target_size, + crop_coords=crops_coords_top_left, + attention_kwargs=attention_kwargs, + return_dict=False, + kv_caches=kv_caches, + )[0].float() + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + kv_caches.clear() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.vae.config.latent_channels, 1, 1) + .to(latents.device, latents.dtype) + ) + latents = latents * latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return GlmImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/glm_image/pipeline_output.py b/src/diffusers/pipelines/glm_image/pipeline_output.py new file mode 100644 index 000000000000..aec5a5454ea8 --- /dev/null +++ b/src/diffusers/pipelines/glm_image/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class GlmImagePipelineOutput(BaseOutput): + """ + Output 
class for GLM-Image pipelines.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index bb94c94da360..35feca0e346d 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -997,6 +997,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
+class GlmImageTransformer2DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class HunyuanDiT2DControlNetModel(metaclass=DummyObject):
     _backends = ["torch"]
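The denoising loop above passes `self`, the step index, the current sigma, and the requested tensors to `callback_on_step_end`, then reads `latents` / `prompt_embeds` back from the returned dict. A minimal usage sketch of that hook, assuming the `zai-org/GLM-Image` checkpoint from the docs above; the callback name and prompt are illustrative placeholders, not part of this PR:

```python
import torch
from diffusers.pipelines.glm_image import GlmImagePipeline

pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16, device_map="cuda")


# Hypothetical callback: the pipeline invokes it as (pipe, step_index, sigma, callback_kwargs)
# and expects a dict back; a returned "latents" entry would replace the working latents.
def log_step(pipe, step_index, sigma, callback_kwargs):
    print(f"step {step_index}: sigma={float(sigma):.4f}, latents shape={tuple(callback_kwargs['latents'].shape)}")
    return {}  # nothing overridden, denoising continues with the current latents


image = pipe(
    prompt="A watercolor painting of a lighthouse at dawn",  # placeholder prompt
    height=32 * 32,
    width=32 * 32,
    num_inference_steps=30,
    guidance_scale=1.5,  # > 1 enables classifier-free guidance in this pipeline
    callback_on_step_end=log_step,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
image.save("output_callback.png")
```

Note that `callback_on_step_end_tensor_inputs` is validated against `self._callback_tensor_inputs`, so only tensors registered there (at minimum `latents`) can be requested.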