88import torch .nn .functional as F
99import comfy .ldm .common_dit
1010
11- from comfy .ldm .modules .diffusionmodules .mmdit import TimestepEmbedder , RMSNorm
11+ from comfy .ldm .modules .diffusionmodules .mmdit import TimestepEmbedder
1212from comfy .ldm .modules .attention import optimized_attention_masked
1313from comfy .ldm .flux .layers import EmbedND
1414
@@ -64,8 +64,8 @@ def __init__(
6464 )
6565
6666 if qk_norm :
67- self .q_norm = RMSNorm (self .head_dim , elementwise_affine = True , ** operation_settings )
68- self .k_norm = RMSNorm (self .head_dim , elementwise_affine = True , ** operation_settings )
67+ self .q_norm = operation_settings . get ( "operations" ). RMSNorm (self .head_dim , elementwise_affine = True , device = operation_settings . get ( "device" ), dtype = operation_settings . get ( "dtype" ) )
68+ self .k_norm = operation_settings . get ( "operations" ). RMSNorm (self .head_dim , elementwise_affine = True , device = operation_settings . get ( "device" ), dtype = operation_settings . get ( "dtype" ) )
6969 else :
7070 self .q_norm = self .k_norm = nn .Identity ()
7171
@@ -242,11 +242,11 @@ def __init__(
242242 operation_settings = operation_settings ,
243243 )
244244 self .layer_id = layer_id
245- self .attention_norm1 = RMSNorm (dim , eps = norm_eps , elementwise_affine = True , ** operation_settings )
246- self .ffn_norm1 = RMSNorm (dim , eps = norm_eps , elementwise_affine = True , ** operation_settings )
245+ self .attention_norm1 = operation_settings . get ( "operations" ). RMSNorm (dim , eps = norm_eps , elementwise_affine = True , device = operation_settings . get ( "device" ), dtype = operation_settings . get ( "dtype" ) )
246+ self .ffn_norm1 = operation_settings . get ( "operations" ). RMSNorm (dim , eps = norm_eps , elementwise_affine = True , device = operation_settings . get ( "device" ), dtype = operation_settings . get ( "dtype" ) )
247247
248- self .attention_norm2 = RMSNorm (dim , eps = norm_eps , elementwise_affine = True , ** operation_settings )
249- self .ffn_norm2 = RMSNorm (dim , eps = norm_eps , elementwise_affine = True , ** operation_settings )
248+ self .attention_norm2 = operation_settings . get ( "operations" ). RMSNorm (dim , eps = norm_eps , elementwise_affine = True , device = operation_settings . get ( "device" ), dtype = operation_settings . get ( "dtype" ) )
249+ self .ffn_norm2 = operation_settings . get ( "operations" ). RMSNorm (dim , eps = norm_eps , elementwise_affine = True , device = operation_settings . get ( "device" ), dtype = operation_settings . get ( "dtype" ) )
250250
251251 self .modulation = modulation
252252 if modulation :
@@ -431,7 +431,7 @@ def __init__(
431431
432432 self .t_embedder = TimestepEmbedder (min (dim , 1024 ), ** operation_settings )
433433 self .cap_embedder = nn .Sequential (
434- RMSNorm (cap_feat_dim , eps = norm_eps , elementwise_affine = True , ** operation_settings ),
434+ operation_settings . get ( "operations" ). RMSNorm (cap_feat_dim , eps = norm_eps , elementwise_affine = True , device = operation_settings . get ( "device" ), dtype = operation_settings . get ( "dtype" ) ),
435435 operation_settings .get ("operations" ).Linear (
436436 cap_feat_dim ,
437437 dim ,
@@ -457,7 +457,7 @@ def __init__(
457457 for layer_id in range (n_layers )
458458 ]
459459 )
460- self .norm_final = RMSNorm (dim , eps = norm_eps , elementwise_affine = True , ** operation_settings )
460+ self .norm_final = operation_settings . get ( "operations" ). RMSNorm (dim , eps = norm_eps , elementwise_affine = True , device = operation_settings . get ( "device" ), dtype = operation_settings . get ( "dtype" ) )
461461 self .final_layer = FinalLayer (dim , patch_size , self .out_channels , operation_settings = operation_settings )
462462
463463 assert (dim // n_heads ) == sum (axes_dims )
0 commit comments