Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions comfy_extras/nodes_tomesd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#Taken from: https://github.com/dbolya/tomesd

import torch
from typing import Tuple, Callable
from typing import Tuple, Callable, Optional
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import math

def do_nothing(x: torch.Tensor, mode:str=None):
Expand Down Expand Up @@ -144,33 +146,45 @@ def get_functions(x, ratio, original_shape):



class TomePatchModel:
class TomePatchModel(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
def define_schema(cls):
return io.Schema(
node_id="TomePatchModel",
category="model_patches/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
],
outputs=[io.Model.Output()],
)

CATEGORY = "model_patches/unet"

def patch(self, model, ratio):
self.u = None
@classmethod
def execute(cls, model, ratio) -> io.NodeOutput:
u: Optional[Callable] = None
def tomesd_m(q, k, v, extra_options):
nonlocal u
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
#however from my basic testing it seems that using q instead gives better results
m, self.u = get_functions(q, ratio, extra_options["original_shape"])
m, u = get_functions(q, ratio, extra_options["original_shape"])
return m(q), k, v
def tomesd_u(n, extra_options):
return self.u(n)
nonlocal u
return u(n)

m = model.clone()
m.set_model_attn1_patch(tomesd_m)
m.set_model_attn1_output_patch(tomesd_u)
return (m, )
return io.NodeOutput(m)


class TomePatchModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TomePatchModel,
]


NODE_CLASS_MAPPINGS = {
"TomePatchModel": TomePatchModel,
}
async def comfy_entrypoint() -> TomePatchModelExtension:
return TomePatchModelExtension()