 from einops.layers.torch import Rearrange
 from torch import nn

-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 from comfy.ldm.modules.attention import optimized_attention


@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
     return t_out


-def get_normalization(name: str, channels: int, weight_args={}):
+def get_normalization(name: str, channels: int, weight_args={}, operations=None):
     if name == "I":
         return nn.Identity()
     elif name == "R":
-        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
+        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
     else:
         raise ValueError(f"Normalization {name} not found")

@@ -120,15 +119,15 @@ def __init__(

         self.to_q = nn.Sequential(
             operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[0], norm_dim),
+            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_k = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[1], norm_dim),
+            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_v = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[2], norm_dim),
+            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
         )

         self.to_out = nn.Sequential(
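Net effect of the diff: `get_normalization` no longer hard-binds the RMSNorm implementation at import time; call sites supply it through the `operations` namespace, matching the pattern already used for `operations.Linear`, and `weight_args` is now forwarded so device/dtype placement stays consistent across the QKV projections and their norms. A minimal sketch of driving the new signature, assuming it runs with `get_normalization` imported from the patched module and that `comfy.ops.disable_weight_init` is the ops namespace the caller passes (that choice is an assumption; the diff only shows `operations` being threaded through):

```python
import torch
import comfy.ops  # assumed source of the ops namespace with an RMSNorm variant

# "R" now resolves RMSNorm through the supplied ops namespace instead of the
# removed mmdit import; "I" still returns a plain nn.Identity.
norm = get_normalization("R", channels=128, operations=comfy.ops.disable_weight_init)

x = torch.randn(2, 16, 128)
print(norm(x).shape)  # expected: torch.Size([2, 16, 128])
```

Routing the class through `operations` lets ComfyUI substitute init-skipping or weight-casting module variants without further edits to this file.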