
Commit 9187a09

Change cosmos and hydit models to use the native RMSNorm. (#7934)
1 parent 3041e5c commit 9187a09

File tree

3 files changed: +8 -11 lines changed

comfy/ldm/cosmos/blocks.py
comfy/ldm/cosmos/model.py
comfy/ldm/hydit/models.py
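
All three files previously imported a shared RMSNorm from comfy/ldm/modules/diffusionmodules/mmdit.py; after this commit they create the layer through the `operations` class set that is already used for Linear and LayerNorm, so layer construction (weight init, device, dtype) is controlled in one place. As a refresher, here is a minimal sketch of the computation an RMSNorm layer of this shape performs; the class below is illustrative, not ComfyUI's implementation:

import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """Minimal RMSNorm sketch, mirroring the (channels, elementwise_affine, eps) calls in the diffs below."""
    def __init__(self, dim, eps=1e-6, elementwise_affine=True, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        # Learnable per-channel scale, present when elementwise_affine=True.
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) if elementwise_affine else None

    def forward(self, x):
        # Normalize by the root mean square over the last (channel) dimension.
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * self.weight if self.weight is not None else x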

comfy/ldm/cosmos/blocks.py

Lines changed: 5 additions & 6 deletions
@@ -23,7 +23,6 @@
 from einops.layers.torch import Rearrange
 from torch import nn

-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 from comfy.ldm.modules.attention import optimized_attention


@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
     return t_out


-def get_normalization(name: str, channels: int, weight_args={}):
+def get_normalization(name: str, channels: int, weight_args={}, operations=None):
     if name == "I":
         return nn.Identity()
     elif name == "R":
-        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
+        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
     else:
         raise ValueError(f"Normalization {name} not found")

@@ -120,15 +119,15 @@ def __init__(

         self.to_q = nn.Sequential(
             operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[0], norm_dim),
+            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_k = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[1], norm_dim),
+            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_v = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[2], norm_dim),
+            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
         )

         self.to_out = nn.Sequential(
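
A hypothetical call to the updated get_normalization, assuming the function from the hunk above and RMSNormSketch from the earlier sketch are in scope; the SimpleNamespace is a stand-in for a comfy.ops class set, not ComfyUI's actual API:

from types import SimpleNamespace
import torch

ops = SimpleNamespace(RMSNorm=RMSNormSketch)  # stand-in for a comfy.ops class set
norm = get_normalization("R", 512, weight_args={"dtype": torch.float32}, operations=ops)
print(norm(torch.randn(2, 8, 512)).shape)  # torch.Size([2, 8, 512])

Note the operations=None default: the "I" (identity) path still works without an operations set, but the "R" path now requires one.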

comfy/ldm/cosmos/model.py

Lines changed: 1 addition & 3 deletions
@@ -27,8 +27,6 @@
 from enum import Enum
 import logging

-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
-
 from .blocks import (
     FinalLayer,
     GeneralDITTransformerBlock,
@@ -195,7 +193,7 @@ def __init__(

         if self.affline_emb_norm:
             logging.debug("Building affine embedding normalization layer")
-            self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
+            self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
         else:
             self.affline_norm = nn.Identity()

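Besides swapping the class, the new call forwards device and dtype, so the affine weight is allocated directly on the target device in the target precision rather than being moved after construction. A hypothetical construction using the sketch class above (device and dtype values are illustrative):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mirrors the device=device, dtype=dtype forwarding in the hunk above.
affline_norm = RMSNormSketch(4096, elementwise_affine=True, eps=1e-6,
                             device=device, dtype=torch.bfloat16)
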
comfy/ldm/hydit/models.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 import torch.nn as nn

 import comfy.ops
-from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint

@@ -51,7 +51,7 @@ def __init__(self,
         if norm_type == "layer":
             norm_layer = operations.LayerNorm
         elif norm_type == "rms":
-            norm_layer = RMSNorm
+            norm_layer = operations.RMSNorm
         else:
             raise ValueError(f"Unknown norm_type: {norm_type}")

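Because operations.LayerNorm and operations.RMSNorm take the same leading arguments, the dispatch can store a class and instantiate it later with the block's width. A hedged sketch of that pattern; make_norm and ops are illustrative names, not part of the hydit code:

import torch.nn as nn
from types import SimpleNamespace

def make_norm(norm_type, dim, ops):
    # Store a class, not an instance; the dimension is supplied at call time.
    if norm_type == "layer":
        norm_layer = ops.LayerNorm
    elif norm_type == "rms":
        norm_layer = ops.RMSNorm
    else:
        raise ValueError(f"Unknown norm_type: {norm_type}")
    return norm_layer(dim, elementwise_affine=True, eps=1e-6)

# e.g. make_norm("rms", 1408, SimpleNamespace(LayerNorm=nn.LayerNorm, RMSNorm=RMSNormSketch))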