Skip to content

Commit f02de13

Browse files
authored
Add TCFG node (#8730)
1 parent c46268b commit f02de13

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

comfy_extras/nodes_tcfg.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)
2+
3+
import torch
4+
5+
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
6+
7+
8+
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
9+
"""Drop tangential components from uncond score to align with cond score."""
10+
# (B, 1, ...)
11+
batch_num = cond_score.shape[0]
12+
cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
13+
uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
14+
15+
# Score matrix A (B, 2, ...)
16+
score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
17+
try:
18+
_, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
19+
except RuntimeError:
20+
# Fallback to CPU
21+
_, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
22+
23+
# Drop the tangential components
24+
v1 = Vh[:, 0:1, :].to(uncond_score_flat.device) # (B, 1, ...)
25+
uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
26+
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
27+
28+
29+
class TCFG(ComfyNodeABC):
30+
@classmethod
31+
def INPUT_TYPES(cls) -> InputTypeDict:
32+
return {
33+
"required": {
34+
"model": (IO.MODEL, {}),
35+
}
36+
}
37+
38+
RETURN_TYPES = (IO.MODEL,)
39+
RETURN_NAMES = ("patched_model",)
40+
FUNCTION = "patch"
41+
42+
CATEGORY = "advanced/guidance"
43+
DESCRIPTION = "TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality."
44+
45+
def patch(self, model):
46+
m = model.clone()
47+
48+
def tangential_damping_cfg(args):
49+
# Assume [cond, uncond, ...]
50+
x = args["input"]
51+
conds_out = args["conds_out"]
52+
if len(conds_out) <= 1 or None in args["conds"][:2]:
53+
# Skip when either cond or uncond is None
54+
return conds_out
55+
cond_pred = conds_out[0]
56+
uncond_pred = conds_out[1]
57+
uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
58+
uncond_pred_td = x - uncond_td
59+
return [cond_pred, uncond_pred_td] + conds_out[2:]
60+
61+
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
62+
return (m,)
63+
64+
65+
NODE_CLASS_MAPPINGS = {
66+
"TCFG": TCFG,
67+
}
68+
69+
NODE_DISPLAY_NAME_MAPPINGS = {
70+
"TCFG": "Tangential Damping CFG",
71+
}

nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2283,6 +2283,7 @@ def init_builtin_extra_nodes():
22832283
"nodes_string.py",
22842284
"nodes_camera_trajectory.py",
22852285
"nodes_edit_model.py",
2286+
"nodes_tcfg.py"
22862287
]
22872288

22882289
import_failed = []

0 commit comments

Comments
 (0)