Commit bb44c2e

Merge branch 'master' into worksplit-multigpu
2 parents: efcd828 + dd611a7


58 files changed: +3120 additions, -1048 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
 - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
 - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
+- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
 - Image Editing Models
 - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
 - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)

comfy/audio_encoders/audio_encoders.py

Lines changed: 51 additions & 2 deletions
@@ -1,4 +1,5 @@
 from .wav2vec2 import Wav2Vec2Model
+from .whisper import WhisperLargeV3
 import comfy.model_management
 import comfy.ops
 import comfy.utils
@@ -11,7 +12,18 @@ def __init__(self, config):
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
-        self.model = Wav2Vec2Model(dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast)
+        model_type = config.pop("model_type")
+        model_config = dict(config)
+        model_config.update({
+            "dtype": self.dtype,
+            "device": offload_device,
+            "operations": comfy.ops.manual_cast
+        })
+
+        if model_type == "wav2vec2":
+            self.model = Wav2Vec2Model(**model_config)
+        elif model_type == "whisper3":
+            self.model = WhisperLargeV3(**model_config)
         self.model.eval()
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
         self.model_sample_rate = 16000
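
For orientation: AudioEncoderModel.__init__ now keys off config["model_type"] to pick the backbone class and forwards every remaining config entry as a keyword argument. A minimal sketch of building the encoder by hand, using the 768-dim wav2vec2 preset that load_audio_encoder_from_sd (below) normally derives from the checkpoint; the hand-written dict is for illustration only:

import comfy.audio_encoders.audio_encoders as ae

# Illustrative config; in normal use load_audio_encoder_from_sd() builds
# this dict by sniffing the checkpoint's state-dict keys.
config = {
    "model_type": "wav2vec2",  # popped by __init__ to select the class
    "embed_dim": 768,
    "num_heads": 12,
    "num_layers": 12,
    "conv_norm": False,
    "conv_bias": False,
    "do_normalize": False,
    "do_stable_layer_norm": False,
}
encoder = ae.AudioEncoderModel(config)

Note that config.pop("model_type") mutates the dict the caller passed in, while dict(config) copies the surviving keys before the dtype, device, and operations entries are merged on top.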
@@ -29,14 +41,51 @@ def encode_audio(self, audio, sample_rate):
         outputs = {}
         outputs["encoded_audio"] = out
         outputs["encoded_audio_all_layers"] = all_layers
+        outputs["audio_samples"] = audio.shape[2]
         return outputs


 def load_audio_encoder_from_sd(sd, prefix=""):
-    audio_encoder = AudioEncoderModel(None)
     sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
+    if "encoder.layer_norm.bias" in sd: #wav2vec2
+        embed_dim = sd["encoder.layer_norm.bias"].shape[0]
+        if embed_dim == 1024:# large
+            config = {
+                "model_type": "wav2vec2",
+                "embed_dim": 1024,
+                "num_heads": 16,
+                "num_layers": 24,
+                "conv_norm": True,
+                "conv_bias": True,
+                "do_normalize": True,
+                "do_stable_layer_norm": True
+            }
+        elif embed_dim == 768: # base
+            config = {
+                "model_type": "wav2vec2",
+                "embed_dim": 768,
+                "num_heads": 12,
+                "num_layers": 12,
+                "conv_norm": False,
+                "conv_bias": False,
+                "do_normalize": False, # chinese-wav2vec2-base has this False
+                "do_stable_layer_norm": False
+            }
+        else:
+            raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
+    elif "model.encoder.embed_positions.weight" in sd:
+        sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
+        config = {
+            "model_type": "whisper3",
+        }
+    else:
+        raise RuntimeError("ERROR: audio encoder not supported.")
+
+    audio_encoder = AudioEncoderModel(config)
     m, u = audio_encoder.load_sd(sd)
     if len(m) > 0:
         logging.warning("missing audio encoder: {}".format(m))
+    if len(u) > 0:
+        logging.warning("unexpected audio encoder: {}".format(u))

     return audio_encoder
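
A sketch of driving the loader end to end; the checkpoint path is hypothetical, and comfy.utils.load_torch_file is the state-dict loader ComfyUI uses for this elsewhere:

import torch
import comfy.utils
from comfy.audio_encoders.audio_encoders import load_audio_encoder_from_sd

# Hypothetical path; any wav2vec2 checkpoint (embed_dim 768 or 1024) or a
# Whisper large v3 checkpoint matching the key sniffing above should load.
sd = comfy.utils.load_torch_file("models/audio_encoders/wav2vec2_large.safetensors", safe_load=True)
encoder = load_audio_encoder_from_sd(sd)

waveform = torch.randn(1, 2, 44100)  # (batch, channels, samples) stand-in audio
out = encoder.encode_audio(waveform, sample_rate=44100)
print(out["encoded_audio"].shape, out["audio_samples"])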

comfy/audio_encoders/wav2vec2.py

Lines changed: 65 additions & 20 deletions
@@ -13,19 +13,49 @@ def forward(self, x):
         x = self.conv(x)
         return torch.nn.functional.gelu(self.layer_norm(x.transpose(-2, -1)).transpose(-2, -1))

+class LayerGroupNormConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+        self.layer_norm = operations.GroupNorm(num_groups=out_channels, num_channels=out_channels, affine=True, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.nn.functional.gelu(self.layer_norm(x))
+
+class ConvNoNorm(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=False, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.conv = operations.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=bias, device=device, dtype=dtype)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return torch.nn.functional.gelu(x)
+

 class ConvFeatureEncoder(nn.Module):
-    def __init__(self, conv_dim, dtype=None, device=None, operations=None):
+    def __init__(self, conv_dim, conv_bias=False, conv_norm=True, dtype=None, device=None, operations=None):
         super().__init__()
-        self.conv_layers = nn.ModuleList([
-            LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-            LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=True, device=device, dtype=dtype, operations=operations),
-        ])
+        if conv_norm:
+            self.conv_layers = nn.ModuleList([
+                LayerNormConv(1, conv_dim, kernel_size=10, stride=5, bias=True, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                LayerNormConv(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+            ])
+        else:
+            self.conv_layers = nn.ModuleList([
+                LayerGroupNormConv(1, conv_dim, kernel_size=10, stride=5, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=3, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+                ConvNoNorm(conv_dim, conv_dim, kernel_size=2, stride=2, bias=conv_bias, device=device, dtype=dtype, operations=operations),
+            ])

     def forward(self, x):
         x = x.unsqueeze(1)
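
Both stacks keep the original geometry: kernel sizes (10, 3, 3, 3, 3, 2, 2) with strides (5, 2, 2, 2, 2, 2, 2), for a total downsampling factor of 5 * 2^6 = 320, i.e. roughly 50 feature frames per second of 16 kHz audio. A quick check of the output-length arithmetic (valid convolution, no padding):

# Output length of a Conv1d with no padding: floor((L - kernel) / stride) + 1.
def conv_out_len(length, kernel, stride):
    return (length - kernel) // stride + 1

layers = [(10, 5), (3, 2), (3, 2), (3, 2), (3, 2), (2, 2), (2, 2)]
length = 16000  # one second of 16 kHz audio
for kernel, stride in layers:
    length = conv_out_len(length, kernel, stride)
print(length)  # 49 frames, ~50 per second (total stride 5 * 2**6 = 320)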
@@ -76,6 +106,7 @@ def __init__(
                 num_heads=12,
                 num_layers=12,
                 mlp_ratio=4.0,
+                do_stable_layer_norm=True,
                 dtype=None, device=None, operations=None
     ):
         super().__init__()
@@ -86,20 +117,25 @@
                 embed_dim=embed_dim,
                 num_heads=num_heads,
                 mlp_ratio=mlp_ratio,
+                do_stable_layer_norm=do_stable_layer_norm,
                 device=device, dtype=dtype, operations=operations
             )
             for _ in range(num_layers)
         ])

         self.layer_norm = operations.LayerNorm(embed_dim, eps=1e-05, device=device, dtype=dtype)
+        self.do_stable_layer_norm = do_stable_layer_norm

     def forward(self, x, mask=None):
         x = x + self.pos_conv_embed(x)
         all_x = ()
+        if not self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         for layer in self.layers:
             all_x += (x,)
             x = layer(x, mask)
-        x = self.layer_norm(x)
+        if self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         all_x += (x,)
         return x, all_x
@@ -145,6 +181,7 @@ def __init__(
         embed_dim=768,
         num_heads=12,
         mlp_ratio=4.0,
+        do_stable_layer_norm=True,
         dtype=None, device=None, operations=None
     ):
         super().__init__()
@@ -154,15 +191,19 @@
         self.layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
         self.feed_forward = FeedForward(embed_dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
+        self.do_stable_layer_norm = do_stable_layer_norm

     def forward(self, x, mask=None):
         residual = x
-        x = self.layer_norm(x)
+        if self.do_stable_layer_norm:
+            x = self.layer_norm(x)
         x = self.attention(x, mask=mask)
         x = residual + x
-
-        x = x + self.feed_forward(self.final_layer_norm(x))
-        return x
+        if not self.do_stable_layer_norm:
+            x = self.layer_norm(x)
+            return self.final_layer_norm(x + self.feed_forward(x))
+        else:
+            return x + self.feed_forward(self.final_layer_norm(x))


 class Wav2Vec2Model(nn.Module):
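
The flag mirrors the two layer-norm orderings in reference wav2vec2 implementations: do_stable_layer_norm=True is the pre-LN arrangement of the large checkpoints (normalize before attention), while False is the post-LN arrangement of the base checkpoints (normalize after the residual add). A toy version of the branch above, with identity stand-ins for the sublayers so the ordering is runnable in isolation:

import torch
import torch.nn as nn

# Stand-ins only; the real module uses multi-head attention and an MLP.
attn, ff = nn.Identity(), nn.Identity()
layer_norm, final_layer_norm = nn.LayerNorm(4), nn.LayerNorm(4)

def encoder_layer(x, do_stable_layer_norm):
    residual = x
    if do_stable_layer_norm:      # pre-LN: normalize before attention
        x = layer_norm(x)
    x = residual + attn(x)
    if not do_stable_layer_norm:  # post-LN: normalize after the residual add
        x = layer_norm(x)
        return final_layer_norm(x + ff(x))
    return x + ff(final_layer_norm(x))

x = torch.randn(1, 3, 4)
print(encoder_layer(x, True).shape, encoder_layer(x, False).shape)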
@@ -174,34 +215,38 @@ def __init__(
         final_dim=256,
         num_heads=16,
         num_layers=24,
+        conv_norm=True,
+        conv_bias=True,
+        do_normalize=True,
+        do_stable_layer_norm=True,
         dtype=None, device=None, operations=None
     ):
         super().__init__()

         conv_dim = 512
-        self.feature_extractor = ConvFeatureEncoder(conv_dim, device=device, dtype=dtype, operations=operations)
+        self.feature_extractor = ConvFeatureEncoder(conv_dim, conv_norm=conv_norm, conv_bias=conv_bias, device=device, dtype=dtype, operations=operations)
         self.feature_projection = FeatureProjection(conv_dim, embed_dim, device=device, dtype=dtype, operations=operations)

         self.masked_spec_embed = nn.Parameter(torch.empty(embed_dim, device=device, dtype=dtype))
+        self.do_normalize = do_normalize

         self.encoder = TransformerEncoder(
             embed_dim=embed_dim,
             num_heads=num_heads,
             num_layers=num_layers,
+            do_stable_layer_norm=do_stable_layer_norm,
             device=device, dtype=dtype, operations=operations
         )

     def forward(self, x, mask_time_indices=None, return_dict=False):
-
         x = torch.mean(x, dim=1)

-        x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)
+        if self.do_normalize:
+            x = (x - x.mean()) / torch.sqrt(x.var() + 1e-7)

         features = self.feature_extractor(x)
         features = self.feature_projection(features)
-
         batch_size, seq_len, _ = features.shape

         x, all_x = self.encoder(features)
-
         return x, all_x
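
Putting the pieces together, a possible shape-level smoke test of the base-style configuration (comfy's manual_cast ops skip weight initialization, so only tensor shapes are meaningful here):

import torch
import comfy.ops
from comfy.audio_encoders.wav2vec2 import Wav2Vec2Model

model = Wav2Vec2Model(embed_dim=768, num_heads=12, num_layers=12,
                      conv_norm=False, conv_bias=False,
                      do_normalize=False, do_stable_layer_norm=False,
                      dtype=torch.float32, device="cpu",
                      operations=comfy.ops.manual_cast).eval()

waveform = torch.randn(1, 2, 16000)  # channels are averaged in forward()
with torch.no_grad():
    out, all_layers = model(waveform)
print(out.shape)  # expect (1, 49, 768): 16000 samples / ~320x downsampling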
