Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions comfy_extras/nodes_eps.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
from typing_extensions import override

from comfy.k_diffusion.sampling import sigma_to_half_log_snr
from comfy_api.latest import ComfyExtension, io


Expand Down Expand Up @@ -63,12 +65,105 @@ def epsilon_scaling_function(args):
return io.NodeOutput(model_clone)


def compute_tsr_rescaling_factor(
snr: torch.Tensor, tsr_k: float, tsr_variance: float
) -> torch.Tensor:
"""Compute the rescaling score ratio in Temporal Score Rescaling.
See equation (6) in https://arxiv.org/pdf/2510.01184v1.
"""
posinf_mask = torch.isposinf(snr)
rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k


class TemporalScoreRescaling(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TemporalScoreRescaling",
display_name="TSR - Temporal Score Rescaling",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input(
"tsr_k",
tooltip=(
"Controls the rescaling strength.\n"
"Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling."
),
default=0.95,
min=0.01,
max=100.0,
step=0.001,
display_mode=io.NumberDisplay.number,
),
io.Float.Input(
"tsr_sigma",
tooltip=(
"Controls how early rescaling takes effect.\n"
"Larger values take effect earlier."
),
default=1.0,
min=0.01,
max=100.0,
step=0.001,
display_mode=io.NumberDisplay.number,
),
],
outputs=[
io.Model.Output(
display_name="patched_model",
),
],
description=(
"[Post-CFG Function]\n"
"TSR - Temporal Score Rescaling (2510.01184)\n\n"
"Rescaling the model's score or noise to steer the sampling diversity.\n"
),
)

@classmethod
def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
tsr_variance = tsr_sigma**2

def temporal_score_rescaling(args):
denoised = args["denoised"]
x = args["input"]
sigma = args["sigma"]
curr_model = args["model"]

# No rescaling (r = 1) or no noise
if tsr_k == 1 or sigma == 0:
return denoised

model_sampling = curr_model.current_patcher.get_model_object("model_sampling")
half_log_snr = sigma_to_half_log_snr(sigma, model_sampling)
snr = (2 * half_log_snr).exp()

# No rescaling needed (r = 1)
if snr == 0:
return denoised

rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance)

# Derived from scaled_denoised = (x - r * sigma * noise) / alpha
alpha = sigma * half_log_snr.exp()
return torch.lerp(x / alpha, denoised, rescaling_r)

m = model.clone()
m.set_model_sampler_post_cfg_function(temporal_score_rescaling)
return io.NodeOutput(m)


class EpsilonScalingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EpsilonScaling,
TemporalScoreRescaling,
]


async def comfy_entrypoint() -> EpsilonScalingExtension:
return EpsilonScalingExtension()