
Commit eac4988

Merge branch 'main' into transformers-v5-pr
2 parents: 4455f14 + ed77a24 · commit eac4988

32 files changed: +3458 -229 lines

scripts/convert_cosmos_to_diffusers.py

Lines changed: 290 additions & 15 deletions
Large diffs are not rendered by default.

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -221,6 +221,7 @@
             "ControlNetModel",
             "ControlNetUnionModel",
             "ControlNetXSAdapter",
+            "CosmosControlNetModel",
             "CosmosTransformer3DModel",
             "DiTTransformer2DModel",
             "EasyAnimateTransformer3DModel",
@@ -485,6 +486,7 @@
             "CogView4Pipeline",
             "ConsisIDPipeline",
             "Cosmos2_5_PredictBasePipeline",
+            "Cosmos2_5_TransferPipeline",
             "Cosmos2TextToImagePipeline",
             "Cosmos2VideoToWorldPipeline",
             "CosmosTextToWorldPipeline",
@@ -992,6 +994,7 @@
             ControlNetModel,
             ControlNetUnionModel,
             ControlNetXSAdapter,
+            CosmosControlNetModel,
             CosmosTransformer3DModel,
             DiTTransformer2DModel,
             EasyAnimateTransformer3DModel,
@@ -1226,6 +1229,7 @@
             CogView4Pipeline,
             ConsisIDPipeline,
             Cosmos2_5_PredictBasePipeline,
+            Cosmos2_5_TransferPipeline,
             Cosmos2TextToImagePipeline,
             Cosmos2VideoToWorldPipeline,
             CosmosTextToWorldPipeline,
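
For orientation (not part of the diff): once a class is listed in `_import_structure` and the type-checking branch, diffusers' lazy-import machinery exposes it from the package root. A minimal, hypothetical usage sketch — the local checkpoint path is a placeholder, presumably produced by scripts/convert_cosmos_to_diffusers.py:

import torch
from diffusers import CosmosControlNetModel

# Hypothetical local directory produced by the conversion script; not a published checkpoint.
controlnet = CosmosControlNetModel.from_pretrained(
    "./cosmos-transfer2.5-controlnet", torch_dtype=torch.bfloat16
)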

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,7 @@
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
     _import_structure["cache_utils"] = ["CacheMixin"]
     _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
+    _import_structure["controlnets.controlnet_cosmos"] = ["CosmosControlNetModel"]
     _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
     _import_structure["controlnets.controlnet_hunyuan"] = [
         "HunyuanDiT2DControlNetModel",
@@ -175,6 +176,7 @@
             ControlNetModel,
             ControlNetUnionModel,
             ControlNetXSAdapter,
+            CosmosControlNetModel,
             FluxControlNetModel,
             FluxMultiControlNetModel,
             HunyuanDiT2DControlNetModel,

src/diffusers/models/controlnets/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@

 if is_torch_available():
     from .controlnet import ControlNetModel, ControlNetOutput
+    from .controlnet_cosmos import CosmosControlNetModel
     from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
     from .controlnet_hunyuan import (
         HunyuanControlNetOutput,
src/diffusers/models/controlnets/controlnet_cosmos.py (new file)

Lines changed: 312 additions & 0 deletions
@@ -0,0 +1,312 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, is_torchvision_available, logging
from ..modeling_utils import ModelMixin
from ..transformers.transformer_cosmos import (
    CosmosEmbedding,
    CosmosLearnablePositionalEmbed,
    CosmosPatchEmbed,
    CosmosRotaryPosEmbed,
    CosmosTransformerBlock,
)


if is_torchvision_available():
    from torchvision import transforms

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class CosmosControlNetOutput(BaseOutput):
    """
    Output of [`CosmosControlNetModel`].

    Args:
        control_block_samples (`list[torch.Tensor]`):
            List of control block activations to be injected into transformer blocks.
    """

    control_block_samples: List[torch.Tensor]


class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    ControlNet for Cosmos Transfer2.5.

    This model duplicates the shared embedding modules from the transformer (patch_embed, time_embed,
    learnable_pos_embed, img_context_proj) to enable proper CPU offloading. The forward() method computes everything
    internally from raw inputs.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embed", "patch_embed_base", "time_embed"]
    _no_split_modules = ["CosmosTransformerBlock"]
    _keep_in_fp32_modules = ["learnable_pos_embed"]

    @register_to_config
    def __init__(
        self,
        n_controlnet_blocks: int = 4,
        in_channels: int = 130,
        latent_channels: int = 18,  # base latent channels (latents + condition_mask) + padding_mask
        model_channels: int = 2048,
        num_attention_heads: int = 32,
        attention_head_dim: int = 128,
        mlp_ratio: float = 4.0,
        text_embed_dim: int = 1024,
        adaln_lora_dim: int = 256,
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        max_size: Tuple[int, int, int] = (128, 240, 240),
        rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
        extra_pos_embed_type: Optional[str] = None,
        img_context_dim_in: Optional[int] = None,
        img_context_dim_out: int = 2048,
        use_crossattn_projection: bool = False,
        crossattn_proj_in_channels: int = 1024,
        encoder_hidden_states_channels: int = 1024,
    ):
        super().__init__()

        self.patch_embed = CosmosPatchEmbed(in_channels, model_channels, patch_size, bias=False)

        self.patch_embed_base = CosmosPatchEmbed(latent_channels, model_channels, patch_size, bias=False)
        self.time_embed = CosmosEmbedding(model_channels, model_channels)

        self.learnable_pos_embed = None
        if extra_pos_embed_type == "learnable":
            self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
                hidden_size=model_channels,
                max_size=max_size,
                patch_size=patch_size,
            )

        self.img_context_proj = None
        if img_context_dim_in is not None and img_context_dim_in > 0:
            self.img_context_proj = nn.Sequential(
                nn.Linear(img_context_dim_in, img_context_dim_out, bias=True),
                nn.GELU(),
            )

        # Cross-attention projection for text embeddings (same as transformer)
        self.crossattn_proj = None
        if use_crossattn_projection:
            self.crossattn_proj = nn.Sequential(
                nn.Linear(crossattn_proj_in_channels, encoder_hidden_states_channels, bias=True),
                nn.GELU(),
            )

        # RoPE for both control and base latents
        self.rope = CosmosRotaryPosEmbed(
            hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
        )

        self.control_blocks = nn.ModuleList(
            [
                CosmosTransformerBlock(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=text_embed_dim,
                    mlp_ratio=mlp_ratio,
                    adaln_lora_dim=adaln_lora_dim,
                    qk_norm="rms_norm",
                    out_bias=False,
                    img_context=img_context_dim_in is not None and img_context_dim_in > 0,
                    before_proj=(block_idx == 0),
                    after_proj=True,
                )
                for block_idx in range(n_controlnet_blocks)
            ]
        )

        self.gradient_checkpointing = False

    def _expand_conditioning_scale(self, conditioning_scale: Union[float, List[float]]) -> List[float]:
        if isinstance(conditioning_scale, list):
            scales = conditioning_scale
        else:
            scales = [conditioning_scale] * len(self.control_blocks)

        if len(scales) < len(self.control_blocks):
            logger.warning(
                "Received %d control scales, but control network defines %d blocks. "
                "Scales will be trimmed or repeated to match.",
                len(scales),
                len(self.control_blocks),
            )
        scales = (scales * len(self.control_blocks))[: len(self.control_blocks)]
        return scales

    def forward(
        self,
        controls_latents: torch.Tensor,
        latents: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: Union[Optional[torch.Tensor], Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
        condition_mask: torch.Tensor,
        conditioning_scale: Union[float, List[float]] = 1.0,
        padding_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        fps: Optional[int] = None,
        return_dict: bool = True,
    ) -> Union[CosmosControlNetOutput, Tuple[List[torch.Tensor]]]:
        """
        Forward pass for the ControlNet.

        Args:
            controls_latents: Control signal latents [B, C, T, H, W]
            latents: Base latents from the noising process [B, C, T, H, W]
            timestep: Diffusion timestep tensor
            encoder_hidden_states: Tuple of (text_context, img_context) or text_context
            condition_mask: Conditioning mask [B, 1, T, H, W]
            conditioning_scale: Scale factor(s) for control outputs
            padding_mask: Padding mask [B, 1, H, W] or None
            attention_mask: Optional attention mask or None
            fps: Frames per second for RoPE or None
            return_dict: Whether to return a CosmosControlNetOutput or a tuple

        Returns:
            CosmosControlNetOutput or tuple of control tensors
        """
        B, C, T, H, W = controls_latents.shape

        # 1. Prepare control latents
        control_hidden_states = controls_latents
        vace_in_channels = self.config.in_channels - 1
        if control_hidden_states.shape[1] < vace_in_channels - 1:
            pad_C = vace_in_channels - 1 - control_hidden_states.shape[1]
            control_hidden_states = torch.cat(
                [
                    control_hidden_states,
                    torch.zeros(
                        (B, pad_C, T, H, W), dtype=control_hidden_states.dtype, device=control_hidden_states.device
                    ),
                ],
                dim=1,
            )

        control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)

        padding_mask_resized = transforms.functional.resize(
            padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
        )
        control_hidden_states = torch.cat(
            [control_hidden_states, padding_mask_resized.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
        )

        # 2. Prepare base latents (same processing as transformer.forward)
        base_hidden_states = latents
        if condition_mask is not None:
            base_hidden_states = torch.cat([base_hidden_states, condition_mask], dim=1)

        base_padding_mask = transforms.functional.resize(
            padding_mask, list(base_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
        )
        base_hidden_states = torch.cat(
            [base_hidden_states, base_padding_mask.unsqueeze(2).repeat(B, 1, T, 1, 1)], dim=1
        )

        # 3. Generate positional embeddings (shared for both)
        image_rotary_emb = self.rope(control_hidden_states, fps=fps)
        extra_pos_emb = self.learnable_pos_embed(control_hidden_states) if self.learnable_pos_embed else None

        # 4. Patchify control latents
        control_hidden_states = self.patch_embed(control_hidden_states)
        control_hidden_states = control_hidden_states.flatten(1, 3)

        # 5. Patchify base latents
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = T // p_t
        post_patch_height = H // p_h
        post_patch_width = W // p_w

        base_hidden_states = self.patch_embed_base(base_hidden_states)
        base_hidden_states = base_hidden_states.flatten(1, 3)

        # 6. Time embeddings
        if timestep.ndim == 1:
            temb, embedded_timestep = self.time_embed(base_hidden_states, timestep)
        elif timestep.ndim == 5:
            batch_size, _, num_frames, _, _ = latents.shape
            assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
                f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
            )
            timestep_flat = timestep.flatten()
            temb, embedded_timestep = self.time_embed(base_hidden_states, timestep_flat)
            temb, embedded_timestep = (
                x.view(batch_size, post_patch_num_frames, 1, 1, -1)
                .expand(-1, -1, post_patch_height, post_patch_width, -1)
                .flatten(1, 3)
                for x in (temb, embedded_timestep)
            )
        else:
            raise ValueError(f"Expected timestep to have shape [B, 1, T, 1, 1] or [T], but got {timestep.shape}")

        # 7. Process encoder hidden states
        if isinstance(encoder_hidden_states, tuple):
            text_context, img_context = encoder_hidden_states
        else:
            text_context = encoder_hidden_states
            img_context = None

        # Apply cross-attention projection to text context
        if self.crossattn_proj is not None:
            text_context = self.crossattn_proj(text_context)

        # Apply cross-attention projection to image context (if provided)
        if img_context is not None and self.img_context_proj is not None:
            img_context = self.img_context_proj(img_context)

        # Combine text and image context into a single tuple
        if self.config.img_context_dim_in is not None and self.config.img_context_dim_in > 0:
            processed_encoder_hidden_states = (text_context, img_context)
        else:
            processed_encoder_hidden_states = text_context

        # 8. Prepare attention mask
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, S]

        # 9. Run control blocks
        scales = self._expand_conditioning_scale(conditioning_scale)
        result = []
        for block_idx, (block, scale) in enumerate(zip(self.control_blocks, scales)):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                control_hidden_states, control_proj = self._gradient_checkpointing_func(
                    block,
                    control_hidden_states,
                    processed_encoder_hidden_states,
                    embedded_timestep,
                    temb,
                    image_rotary_emb,
                    extra_pos_emb,
                    attention_mask,
                    None,  # controlnet_residual
                    base_hidden_states,
                    block_idx,
                )
            else:
                control_hidden_states, control_proj = block(
                    hidden_states=control_hidden_states,
                    encoder_hidden_states=processed_encoder_hidden_states,
                    embedded_timestep=embedded_timestep,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    extra_pos_emb=extra_pos_emb,
                    attention_mask=attention_mask,
                    controlnet_residual=None,
                    latents=base_hidden_states,
                    block_idx=block_idx,
                )
            result.append(control_proj * scale)

        if not return_dict:
            return (result,)

        return CosmosControlNetOutput(control_block_samples=result)
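
The scale handling in `_expand_conditioning_scale` above can be illustrated with a small standalone sketch (mirroring the method, not reusing it): a scalar `conditioning_scale` is broadcast to every control block, while a list shorter than the number of blocks is repeated and then trimmed (the model additionally logs a warning in that case).

def expand_scales(conditioning_scale, num_blocks):
    # Broadcast a scalar, or repeat/trim a list, so every control block gets exactly one scale.
    scales = conditioning_scale if isinstance(conditioning_scale, list) else [conditioning_scale] * num_blocks
    return (scales * num_blocks)[:num_blocks]

print(expand_scales(1.0, 4))         # [1.0, 1.0, 1.0, 1.0]
print(expand_scales([0.5, 1.0], 4))  # [0.5, 1.0, 0.5, 1.0]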

src/diffusers/models/transformers/transformer_chronoedit.py

Lines changed: 6 additions & 3 deletions
@@ -43,7 +43,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
         encoder_hidden_states = hidden_states

     if attn.fused_projections:
-        if attn.cross_attention_dim_head is None:
+        if not attn.is_cross_attention:
             # In self-attention layers, we can fuse the entire QKV projection into a single linear
             query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
         else:
@@ -219,15 +219,18 @@ def __init__(
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
             self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)

-        self.is_cross_attention = cross_attention_dim_head is not None
+        if is_cross_attention is not None:
+            self.is_cross_attention = is_cross_attention
+        else:
+            self.is_cross_attention = cross_attention_dim_head is not None

         self.set_processor(processor)

     def fuse_projections(self):
         if getattr(self, "fused_projections", False):
             return

-        if self.cross_attention_dim_head is None:
+        if not self.is_cross_attention:
             concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
             concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
             out_features, in_features = concatenated_weights.shape
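
The change above switches the fused-projection checks from `cross_attention_dim_head is None` to the explicit `is_cross_attention` flag; the underlying trick is unchanged. A minimal standalone sketch in plain PyTorch (not the WanAttention code itself) shows why only self-attention can fuse Q/K/V into one linear: all three projections read the same input, so their weights and biases can simply be concatenated, whereas cross-attention computes K/V from the encoder states and cannot share the projection.

import torch
import torch.nn as nn

dim = 64
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=True) for _ in range(3))

# Fuse Q/K/V by stacking the three weight matrices and biases into a single projection.
to_qkv = nn.Linear(dim, 3 * dim, bias=True)
to_qkv.weight.data = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
to_qkv.bias.data = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])

x = torch.randn(2, 8, dim)  # (batch, sequence, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-5)
assert torch.allclose(v, to_v(x), atol=1e-5)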
