Skip to content

Commit b996c18

Browse files
authored
Merge pull request #3 from LykosAI/add-experiments
Add comfyui_experiments module
2 parents 92bb404 + 3d6b3eb commit b996c18

File tree

10 files changed

+1040
-1
lines changed

10 files changed

+1040
-1
lines changed

src/inference_core_nodes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__all__ = ("__version__", "NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS")
22

3-
__version__ = "0.2.1"
3+
__version__ = "0.3.0"
44

55

66
def _get_node_mappings():

src/inference_core_nodes/comfyui_experiments/LICENSE

Lines changed: 674 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
## ComfyUI Experiments
2+
3+
Based on or modified from: [comfyanonymous/ComfyUI_experiments](https://github.com/comfyanonymous/ComfyUI_experiments) @ 934dba9d206e4738e0dac26a09b51f1dffcb4e44
4+
5+
License: GPL-3.0
6+
7+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import importlib
2+
import os
3+
4+
node_list = [ #Add list of .py files containing nodes here
5+
"advanced_model_merging",
6+
"reference_only",
7+
"sampler_rescalecfg",
8+
"sampler_tonemap",
9+
"sampler_tonemap_rescalecfg",
10+
"sdxl_model_merging"
11+
]
12+
13+
NODE_CLASS_MAPPINGS = {}
14+
NODE_DISPLAY_NAME_MAPPINGS = {}
15+
16+
for module_name in node_list:
17+
imported_module = importlib.import_module(".{}".format(module_name), __name__)
18+
19+
NODE_CLASS_MAPPINGS = {**NODE_CLASS_MAPPINGS, **imported_module.NODE_CLASS_MAPPINGS}
20+
if hasattr(imported_module, "NODE_DISPLAY_NAME_MAPPINGS"):
21+
NODE_DISPLAY_NAME_MAPPINGS = {**NODE_DISPLAY_NAME_MAPPINGS, **imported_module.NODE_DISPLAY_NAME_MAPPINGS}
22+
23+
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import comfy_extras.nodes_model_merging
2+
3+
class ModelMergeBlockNumber(comfy_extras.nodes_model_merging.ModelMergeBlocks):
4+
@classmethod
5+
def INPUT_TYPES(s):
6+
arg_dict = { "model1": ("MODEL",),
7+
"model2": ("MODEL",)}
8+
9+
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
10+
11+
arg_dict["time_embed."] = argument
12+
arg_dict["label_emb."] = argument
13+
14+
for i in range(12):
15+
arg_dict["input_blocks.{}.".format(i)] = argument
16+
17+
for i in range(3):
18+
arg_dict["middle_block.{}.".format(i)] = argument
19+
20+
for i in range(12):
21+
arg_dict["output_blocks.{}.".format(i)] = argument
22+
23+
arg_dict["out."] = argument
24+
25+
return {"required": arg_dict}
26+
27+
28+
NODE_CLASS_MAPPINGS = {
29+
"ModelMergeBlockNumber": ModelMergeBlockNumber,
30+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
3+
class ReferenceOnlySimple:
4+
@classmethod
5+
def INPUT_TYPES(s):
6+
return {"required": { "model": ("MODEL",),
7+
"reference": ("LATENT",),
8+
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64})
9+
}}
10+
11+
RETURN_TYPES = ("MODEL", "LATENT")
12+
FUNCTION = "reference_only"
13+
14+
CATEGORY = "custom_node_experiments"
15+
16+
def reference_only(self, model, reference, batch_size):
17+
model_reference = model.clone()
18+
size_latent = list(reference["samples"].shape)
19+
size_latent[0] = batch_size
20+
latent = {}
21+
latent["samples"] = torch.zeros(size_latent)
22+
23+
batch = latent["samples"].shape[0] + reference["samples"].shape[0]
24+
def reference_apply(q, k, v, extra_options):
25+
k = k.clone().repeat(1, 2, 1)
26+
offset = 0
27+
if q.shape[0] > batch:
28+
offset = batch
29+
30+
for o in range(0, q.shape[0], batch):
31+
for x in range(1, batch):
32+
k[x + o, q.shape[1]:] = q[o,:]
33+
34+
return q, k, k
35+
36+
model_reference.set_model_attn1_patch(reference_apply)
37+
out_latent = torch.cat((reference["samples"], latent["samples"]))
38+
if "noise_mask" in latent:
39+
mask = latent["noise_mask"]
40+
else:
41+
mask = torch.ones((64,64), dtype=torch.float32, device="cpu")
42+
43+
if len(mask.shape) < 3:
44+
mask = mask.unsqueeze(0)
45+
if mask.shape[0] < latent["samples"].shape[0]:
46+
print(latent["samples"].shape, mask.shape)
47+
mask = mask.repeat(latent["samples"].shape[0], 1, 1)
48+
49+
out_mask = torch.zeros((1,mask.shape[1],mask.shape[2]), dtype=torch.float32, device="cpu")
50+
return (model_reference, {"samples": out_latent, "noise_mask": torch.cat((out_mask, mask))})
51+
52+
NODE_CLASS_MAPPINGS = {
53+
"ReferenceOnlySimple": ReferenceOnlySimple,
54+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
3+
4+
class RescaleClassifierFreeGuidance:
5+
@classmethod
6+
def INPUT_TYPES(s):
7+
return {"required": { "model": ("MODEL",),
8+
"multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
9+
}}
10+
RETURN_TYPES = ("MODEL",)
11+
FUNCTION = "patch"
12+
13+
CATEGORY = "custom_node_experiments"
14+
15+
def patch(self, model, multiplier):
16+
17+
def rescale_cfg(args):
18+
cond = args["cond"]
19+
uncond = args["uncond"]
20+
cond_scale = args["cond_scale"]
21+
22+
x_cfg = uncond + cond_scale * (cond - uncond)
23+
ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True)
24+
ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True)
25+
26+
x_rescaled = x_cfg * (ro_pos / ro_cfg)
27+
x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
28+
29+
return x_final
30+
31+
m = model.clone()
32+
m.set_model_sampler_cfg_function(rescale_cfg)
33+
return (m, )
34+
35+
36+
NODE_CLASS_MAPPINGS = {
37+
"RescaleClassifierFreeGuidanceTest": RescaleClassifierFreeGuidance,
38+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import torch
2+
3+
4+
class ModelSamplerTonemapNoiseTest:
5+
@classmethod
6+
def INPUT_TYPES(s):
7+
return {"required": { "model": ("MODEL",),
8+
"multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
9+
}}
10+
RETURN_TYPES = ("MODEL",)
11+
FUNCTION = "patch"
12+
13+
CATEGORY = "custom_node_experiments"
14+
15+
def patch(self, model, multiplier):
16+
17+
def sampler_tonemap_reinhard(args):
18+
cond = args["cond"]
19+
uncond = args["uncond"]
20+
cond_scale = args["cond_scale"]
21+
noise_pred = (cond - uncond)
22+
noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:,None]
23+
noise_pred /= noise_pred_vector_magnitude
24+
25+
mean = torch.mean(noise_pred_vector_magnitude, dim=(1,2,3), keepdim=True)
26+
std = torch.std(noise_pred_vector_magnitude, dim=(1,2,3), keepdim=True)
27+
28+
top = (std * 3 + mean) * multiplier
29+
30+
#reinhard
31+
noise_pred_vector_magnitude *= (1.0 / top)
32+
new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
33+
new_magnitude *= top
34+
35+
return uncond + noise_pred * new_magnitude * cond_scale
36+
37+
m = model.clone()
38+
m.set_model_sampler_cfg_function(sampler_tonemap_reinhard)
39+
return (m, )
40+
41+
42+
NODE_CLASS_MAPPINGS = {
43+
"ModelSamplerTonemapNoiseTest": ModelSamplerTonemapNoiseTest,
44+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
3+
4+
class TonemapNoiseWithRescaleCFG:
5+
@classmethod
6+
def INPUT_TYPES(s):
7+
return {"required": {"model": ("MODEL",),
8+
"tonemap_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
9+
"rescale_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
10+
}}
11+
RETURN_TYPES = ("MODEL",)
12+
FUNCTION = "patch"
13+
14+
CATEGORY = "custom_node_experiments"
15+
16+
def patch(self, model, tonemap_multiplier, rescale_multiplier):
17+
18+
def tonemap_noise_rescale_cfg(args):
19+
cond = args["cond"]
20+
uncond = args["uncond"]
21+
cond_scale = args["cond_scale"]
22+
23+
# Tonemap
24+
noise_pred = (cond - uncond)
25+
noise_pred_vector_magnitude = (torch.linalg.vector_norm(noise_pred, dim=(1)) + 0.0000000001)[:, None]
26+
noise_pred /= noise_pred_vector_magnitude
27+
28+
mean = torch.mean(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
29+
std = torch.std(noise_pred_vector_magnitude, dim=(1, 2, 3), keepdim=True)
30+
31+
top = (std * 3 + mean) * tonemap_multiplier
32+
33+
# Reinhard
34+
noise_pred_vector_magnitude *= (1.0 / top)
35+
new_magnitude = noise_pred_vector_magnitude / (noise_pred_vector_magnitude + 1.0)
36+
new_magnitude *= top
37+
38+
# Rescale CFG
39+
x_cfg = uncond + (noise_pred * new_magnitude * cond_scale)
40+
ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True)
41+
ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True)
42+
43+
x_rescaled = x_cfg * (ro_pos / ro_cfg)
44+
x_final = rescale_multiplier * x_rescaled + (1.0 - rescale_multiplier) * x_cfg
45+
46+
return x_final
47+
48+
m = model.clone()
49+
m.set_model_sampler_cfg_function(tonemap_noise_rescale_cfg)
50+
return (m, )
51+
52+
53+
NODE_CLASS_MAPPINGS = {
54+
"TonemapNoiseWithRescaleCFG": TonemapNoiseWithRescaleCFG,
55+
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import comfy_extras.nodes_model_merging
2+
3+
class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
4+
@classmethod
5+
def INPUT_TYPES(s):
6+
arg_dict = { "model1": ("MODEL",),
7+
"model2": ("MODEL",)}
8+
9+
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
10+
11+
arg_dict["time_embed."] = argument
12+
arg_dict["label_emb."] = argument
13+
14+
for i in range(9):
15+
arg_dict["input_blocks.{}".format(i)] = argument
16+
17+
for i in range(3):
18+
arg_dict["middle_block.{}".format(i)] = argument
19+
20+
for i in range(9):
21+
arg_dict["output_blocks.{}".format(i)] = argument
22+
23+
arg_dict["out."] = argument
24+
25+
return {"required": arg_dict}
26+
27+
28+
class ModelMergeSDXLTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
29+
@classmethod
30+
def INPUT_TYPES(s):
31+
arg_dict = { "model1": ("MODEL",),
32+
"model2": ("MODEL",)}
33+
34+
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
35+
36+
arg_dict["time_embed."] = argument
37+
arg_dict["label_emb."] = argument
38+
39+
transformers = {4: 2, 5:2, 7:10, 8:10}
40+
41+
for i in range(9):
42+
arg_dict["input_blocks.{}.0.".format(i)] = argument
43+
if i in transformers:
44+
arg_dict["input_blocks.{}.1.".format(i)] = argument
45+
for j in range(transformers[i]):
46+
arg_dict["input_blocks.{}.1.transformer_blocks.{}.".format(i, j)] = argument
47+
48+
for i in range(3):
49+
arg_dict["middle_block.{}.".format(i)] = argument
50+
if i == 1:
51+
for j in range(10):
52+
arg_dict["middle_block.{}.transformer_blocks.{}.".format(i, j)] = argument
53+
54+
transformers = {3:2, 4: 2, 5:2, 6:10, 7:10, 8:10}
55+
for i in range(9):
56+
arg_dict["output_blocks.{}.0.".format(i)] = argument
57+
t = 8 - i
58+
if t in transformers:
59+
arg_dict["output_blocks.{}.1.".format(i)] = argument
60+
for j in range(transformers[t]):
61+
arg_dict["output_blocks.{}.1.transformer_blocks.{}.".format(i, j)] = argument
62+
63+
arg_dict["out."] = argument
64+
65+
return {"required": arg_dict}
66+
67+
class ModelMergeSDXLDetailedTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
68+
@classmethod
69+
def INPUT_TYPES(s):
70+
arg_dict = { "model1": ("MODEL",),
71+
"model2": ("MODEL",)}
72+
73+
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
74+
75+
arg_dict["time_embed."] = argument
76+
arg_dict["label_emb."] = argument
77+
78+
transformers = {4: 2, 5:2, 7:10, 8:10}
79+
transformers_args = ["norm1", "attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out", "ff.net", "norm2", "attn2.to_q", "attn2.to_k", "attn2.to_v", "attn2.to_out", "norm3"]
80+
81+
for i in range(9):
82+
arg_dict["input_blocks.{}.0.".format(i)] = argument
83+
if i in transformers:
84+
arg_dict["input_blocks.{}.1.".format(i)] = argument
85+
for j in range(transformers[i]):
86+
for x in transformers_args:
87+
arg_dict["input_blocks.{}.1.transformer_blocks.{}.{}".format(i, j, x)] = argument
88+
89+
for i in range(3):
90+
arg_dict["middle_block.{}.".format(i)] = argument
91+
if i == 1:
92+
for j in range(10):
93+
for x in transformers_args:
94+
arg_dict["middle_block.{}.transformer_blocks.{}.{}".format(i, j, x)] = argument
95+
96+
transformers = {3:2, 4: 2, 5:2, 6:10, 7:10, 8:10}
97+
for i in range(9):
98+
arg_dict["output_blocks.{}.0.".format(i)] = argument
99+
t = 8 - i
100+
if t in transformers:
101+
arg_dict["output_blocks.{}.1.".format(i)] = argument
102+
for j in range(transformers[t]):
103+
for x in transformers_args:
104+
arg_dict["output_blocks.{}.1.transformer_blocks.{}.{}".format(i, j, x)] = argument
105+
106+
arg_dict["out."] = argument
107+
108+
return {"required": arg_dict}
109+
110+
NODE_CLASS_MAPPINGS = {
111+
"ModelMergeSDXL": ModelMergeSDXL,
112+
"ModelMergeSDXLTransformers": ModelMergeSDXLTransformers,
113+
"ModelMergeSDXLDetailedTransformers": ModelMergeSDXLDetailedTransformers,
114+
}

0 commit comments

Comments
 (0)