
Commit f71feac

Merge branch 'master' into attention-select
2 parents 66c4eb0 + 18de0b2

75 files changed: +7117 / -1294 lines changed

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+name: Execution Tests
+
+on:
+  push:
+    branches: [ main, master ]
+  pull_request:
+    branches: [ main, master ]
+
+jobs:
+  test:
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest, macos-latest]
+    runs-on: ${{ matrix.os }}
+    continue-on-error: true
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python
+      uses: actions/setup-python@v4
+      with:
+        python-version: '3.12'
+    - name: Install requirements
+      run: |
+        python -m pip install --upgrade pip
+        pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+        pip install -r requirements.txt
+        pip install -r tests-unit/requirements.txt
+    - name: Run Execution Tests
+      run: |
+        python -m pytest tests/execution -v --skip-timing-checks
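
The workflow's pytest invocation passes --skip-timing-checks, which is not a built-in pytest flag; it has to be registered by the test suite itself. Below is a minimal sketch of how such a flag is typically wired up in a conftest.py; the option name is taken from the command above, but the fixture name and the repository's actual conftest contents are assumptions, not part of this commit.

# conftest.py (illustrative sketch, not the repository's actual file)
import pytest

def pytest_addoption(parser):
    # Registers the custom flag so "pytest --skip-timing-checks" is accepted.
    parser.addoption("--skip-timing-checks", action="store_true", default=False,
                     help="Skip assertions that depend on wall-clock timing.")

@pytest.fixture
def skip_timing_checks(request):
    # Tests can depend on this fixture and branch on the flag's value.
    return request.config.getoption("--skip-timing-checks")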

comfy/cli_args.py

Lines changed: 2 additions & 1 deletion
@@ -143,8 +143,9 @@ class PerformanceFeature(enum.Enum):
     Fp16Accumulation = "fp16_accumulation"
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
+    AutoTune = "autotune"

-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))

 parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
 parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

comfy/clip_model.py

Lines changed: 11 additions & 1 deletion
@@ -61,15 +61,25 @@ def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate
     def forward(self, x, mask=None, intermediate_output=None):
         optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

+        all_intermediate = None
         if intermediate_output is not None:
-            if intermediate_output < 0:
+            if intermediate_output == "all":
+                all_intermediate = []
+                intermediate_output = None
+            elif intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output

         intermediate = None
         for i, l in enumerate(self.layers):
             x = l(x, mask, optimized_attention)
             if i == intermediate_output:
                 intermediate = x.clone()
+            if all_intermediate is not None:
+                all_intermediate.append(x.unsqueeze(1).clone())
+
+        if all_intermediate is not None:
+            intermediate = torch.cat(all_intermediate, dim=1)
+
         return x, intermediate

 class CLIPEmbeddings(torch.nn.Module):
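
With intermediate_output == "all", every layer output is unsqueezed to add a layer axis and the results are concatenated, so the second return value becomes a [batch, num_layers, tokens, dim] tensor instead of a single layer's activations. A toy shape illustration (the loop stands in for the real CLIP layers, it is not part of the commit):

import torch

batch, tokens, dim, num_layers = 2, 5, 8, 3
x = torch.zeros(batch, tokens, dim)
all_intermediate = []
for _ in range(num_layers):
    x = x + 1                                        # stand-in for a transformer layer
    all_intermediate.append(x.unsqueeze(1).clone())  # each entry is [batch, 1, tokens, dim]
intermediate = torch.cat(all_intermediate, dim=1)
print(intermediate.shape)  # torch.Size([2, 3, 5, 8])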

comfy/clip_vision.py

Lines changed: 20 additions & 4 deletions
@@ -50,7 +50,13 @@ def __init__(self, json_config):
         self.image_size = config.get("image_size", 224)
         self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
         self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
-        model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
+        model_type = config.get("model_type", "clip_vision_model")
+        model_class = IMAGE_ENCODERS.get(model_type)
+        if model_type == "siglip_vision_model":
+            self.return_all_hidden_states = True
+        else:
+            self.return_all_hidden_states = False
+
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -68,12 +74,18 @@ def get_sd(self):
     def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
         pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
-        out = self.model(pixel_values=pixel_values, intermediate_output=-2)
+        out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)

         outputs = Output()
         outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
         outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
-        outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+        if self.return_all_hidden_states:
+            all_hs = out[1].to(comfy.model_management.intermediate_device())
+            outputs["penultimate_hidden_states"] = all_hs[:, -2]
+            outputs["all_hidden_states"] = all_hs
+        else:
+            outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+
         outputs["mm_projected"] = out[3]
         return outputs

@@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
-    elif "embeddings.patch_embeddings.projection.weight" in sd:
+
+    # Dinov2
+    elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
         json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
+    elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
+        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
     else:
         return None

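When return_all_hidden_states is set (the SigLIP path), out[1] carries the stacked per-layer states, and the familiar penultimate output is just index -2 along the layer axis. A shape-only sketch, with dimensions that are illustrative rather than taken from any specific checkpoint:

import torch

batch, num_layers, tokens, dim = 1, 27, 256, 1152  # illustrative sizes only
all_hs = torch.randn(batch, num_layers, tokens, dim)

penultimate = all_hs[:, -2]   # same tensor the old intermediate_output=-2 path returned
print(penultimate.shape)      # torch.Size([1, 256, 1152])
print(all_hs.shape)           # every layer stays available via outputs["all_hidden_states"]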

comfy/controlnet.py

Lines changed: 13 additions & 3 deletions
@@ -253,7 +253,10 @@ def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
                 to_concat = []
                 for c in self.extra_concat_orig:
                     c = c.to(self.cond_hint.device)
-                    c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
+                    c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
+                    if c.ndim < self.cond_hint.ndim:
+                        c = c.unsqueeze(2)
+                        c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
                     to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
                 self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)

@@ -585,11 +588,18 @@ def load_controlnet_flux_instantx(sd, model_options={}):

 def load_controlnet_qwen_instantx(sd, model_options={}):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
-    control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
+
+    extra_condition_channels = 0
+    concat_mask = False
+    if control_latent_channels == 68: #inpaint controlnet
+        extra_condition_channels = control_latent_channels - 64
+        concat_mask = True
+    control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     latent_format = comfy.latent_formats.Wan21()
     extra_conds = []
-    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control

 def convert_mistoline(sd):
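
The extra-concat fix handles the case where cond_hint is a 5D video latent while the extra conditioning image is still 4D: the image gains a temporal axis and is repeated along it before concatenation. A rough sketch with plain torch.Tensor.repeat standing in for comfy.utils.repeat_to_batch_size (the shapes are made up for illustration):

import torch

cond_hint = torch.randn(1, 4, 16, 64, 64)  # hypothetical [b, c, t, h, w] video hint
c = torch.randn(1, 4, 64, 64)              # image-only extra concat, [b, c, h, w]

if c.ndim < cond_hint.ndim:
    c = c.unsqueeze(2)                            # -> [1, 4, 1, 64, 64]
    c = c.repeat(1, 1, cond_hint.shape[2], 1, 1)  # tile along the temporal axis

print(torch.cat([cond_hint, c], dim=1).shape)     # torch.Size([1, 8, 16, 64, 64])

The Qwen-Image loader change follows a similar infer-from-weights pattern: a 68-channel controlnet_x_embedder means 64 latent channels plus 4 extra condition channels, so the loader enables the inpaint mask-concat path.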

comfy/image_encoders/dino2.py

Lines changed: 26 additions & 7 deletions
@@ -31,6 +31,20 @@ def __init__(self, dim, dtype, device, operations):
     def forward(self, x):
         return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)

+class Dinov2MLP(torch.nn.Module):
+    def __init__(self, hidden_size: int, dtype, device, operations):
+        super().__init__()
+
+        mlp_ratio = 4
+        hidden_features = int(hidden_size * mlp_ratio)
+        self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
+        self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
+
+    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        hidden_state = self.fc1(hidden_state)
+        hidden_state = torch.nn.functional.gelu(hidden_state)
+        hidden_state = self.fc2(hidden_state)
+        return hidden_state

 class SwiGLUFFN(torch.nn.Module):
     def __init__(self, dim, dtype, device, operations):
@@ -50,12 +64,15 @@ def forward(self, x):


 class Dino2Block(torch.nn.Module):
-    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
+    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
         super().__init__()
         self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
         self.layer_scale1 = LayerScale(dim, dtype, device, operations)
         self.layer_scale2 = LayerScale(dim, dtype, device, operations)
-        self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+        if use_swiglu_ffn:
+            self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+        else:
+            self.mlp = Dinov2MLP(dim, dtype, device, operations)
         self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
         self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

@@ -66,9 +83,10 @@ def forward(self, x, optimized_attention):


 class Dino2Encoder(torch.nn.Module):
-    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
+    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
         super().__init__()
-        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
+        self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
+                                          for _ in range(num_layers)])

     def forward(self, x, intermediate_output=None):
         optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
@@ -78,8 +96,8 @@ def forward(self, x, intermediate_output=None):
                 intermediate_output = len(self.layer) + intermediate_output

         intermediate = None
-        for i, l in enumerate(self.layer):
-            x = l(x, optimized_attention)
+        for i, layer in enumerate(self.layer):
+            x = layer(x, optimized_attention)
             if i == intermediate_output:
                 intermediate = x.clone()
         return x, intermediate
@@ -128,9 +146,10 @@ def __init__(self, config_dict, dtype, device, operations):
         dim = config_dict["hidden_size"]
         heads = config_dict["num_attention_heads"]
         layer_norm_eps = config_dict["layer_norm_eps"]
+        use_swiglu_ffn = config_dict["use_swiglu_ffn"]

         self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
-        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
+        self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
         self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

     def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
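
The new Dinov2MLP is an ordinary two-layer GELU MLP with a 4x expansion, selected whenever the config sets use_swiglu_ffn to false, while the existing SwiGLUFFN path is kept for checkpoints that use it. A minimal stand-alone equivalent using plain torch.nn layers (ComfyUI's operations wrapper is deliberately omitted, so this is a sketch rather than the repository's class):

import torch

class TinyDinov2MLP(torch.nn.Module):
    def __init__(self, hidden_size: int, mlp_ratio: int = 4):
        super().__init__()
        self.fc1 = torch.nn.Linear(hidden_size, hidden_size * mlp_ratio)
        self.fc2 = torch.nn.Linear(hidden_size * mlp_ratio, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fc1 -> GELU -> fc2, mirroring the Dinov2MLP forward pass in the diff above
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

x = torch.randn(1, 10, 1024)
print(TinyDinov2MLP(1024)(x).shape)  # torch.Size([1, 10, 1024])
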
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+{
+    "hidden_size": 1024,
+    "use_mask_token": true,
+    "patch_size": 14,
+    "image_size": 518,
+    "num_channels": 3,
+    "num_attention_heads": 16,
+    "initializer_range": 0.02,
+    "attention_probs_dropout_prob": 0.0,
+    "hidden_dropout_prob": 0.0,
+    "hidden_act": "gelu",
+    "mlp_ratio": 4,
+    "model_type": "dinov2",
+    "num_hidden_layers": 24,
+    "layer_norm_eps": 1e-6,
+    "qkv_bias": true,
+    "use_swiglu_ffn": false,
+    "layerscale_value": 1.0,
+    "drop_path_rate": 0.0,
+    "image_mean": [0.485, 0.456, 0.406],
+    "image_std": [0.229, 0.224, 0.225]
+}
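
Judging by the loader change in comfy/clip_vision.py above, this new config is most likely the dino2_large.json file it references: hidden_size 1024 and 24 layers match a DINOv2-large ViT, and "use_swiglu_ffn": false routes its blocks to the plain Dinov2MLP. A small sketch of how the relevant keys feed the model constructor (the file path is assumed from that diff, not confirmed by this page):

import json

# Path inferred from the loader diff; the commit page itself does not show the filename.
with open("comfy/image_encoders/dino2_large.json") as f:
    cfg = json.load(f)

print(cfg["hidden_size"], cfg["num_hidden_layers"], cfg["use_swiglu_ffn"])
# 1024 24 False  -> Dino2Block is built with the plain GELU MLP rather than SwiGLU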
