Commit 80a44b9

Change lumina to native RMSNorm. (#7935)

1 parent 9187a09 · commit 80a44b9

File tree: comfy/ldm/lumina/model.py (1 file changed, +9 −9 lines)

comfy/ldm/lumina/model.py

Lines changed: 9 additions & 9 deletions
@@ -8,7 +8,7 @@
 import torch.nn.functional as F
 import comfy.ldm.common_dit
 
-from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
 
@@ -64,8 +64,8 @@ def __init__(
         )
 
         if qk_norm:
-            self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
-            self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
+            self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+            self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         else:
             self.q_norm = self.k_norm = nn.Identity()
 
@@ -242,11 +242,11 @@ def __init__(
             operation_settings=operation_settings,
         )
         self.layer_id = layer_id
-        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
-        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
-        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
         self.modulation = modulation
         if modulation:
@@ -431,7 +431,7 @@ def __init__(
 
         self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
         self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
+            operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
                 cap_feat_dim,
                 dim,
@@ -457,7 +457,7 @@ def __init__(
                 for layer_id in range(n_layers)
             ]
         )
-        self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
+        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
 
         assert (dim // n_heads) == sum(axes_dims)
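
The change drops the RMSNorm import from comfy.ldm.modules.diffusionmodules.mmdit and instead looks RMSNorm up on the "operations" object carried in operation_settings, passing device and dtype explicitly rather than spreading **operation_settings. Below is a minimal, self-contained sketch of that lookup pattern; the Operations container and the RMSNorm implementation are simplified stand-ins written for illustration (in ComfyUI the real operations classes come from comfy.ops), not the actual library code.

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    # Simplified stand-in RMSNorm: x * rsqrt(mean(x^2) + eps), with an optional learned scale.
    def __init__(self, dim, eps=1e-6, elementwise_affine=True, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) if elementwise_affine else None

    def forward(self, x):
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * self.weight if self.weight is not None else x


class Operations:
    # Hypothetical stand-in for the "operations" namespace: it exposes layer
    # classes so model code can look them up instead of importing concrete
    # implementations directly.
    RMSNorm = RMSNorm
    Linear = nn.Linear


operation_settings = {"operations": Operations, "device": "cpu", "dtype": torch.float32}

# Mirrors the new call sites in the diff: RMSNorm is fetched from the
# operations object and receives device/dtype explicitly at construction.
q_norm = operation_settings.get("operations").RMSNorm(
    64,
    elementwise_affine=True,
    device=operation_settings.get("device"),
    dtype=operation_settings.get("dtype"),
)
print(q_norm(torch.randn(2, 64)).shape)  # torch.Size([2, 64])

One practical effect of routing through the operations object: whichever RMSNorm variant the configured operations class provides is now used consistently across the Lumina model, and device/dtype are applied explicitly at construction instead of relying on **operation_settings expansion.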
