6 changes: 2 additions & 4 deletions comfy/cldm/cldm.py
@@ -12,7 +12,6 @@

from ..ldm.modules.attention import SpatialTransformer
from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
from ..ldm.util import exists
from .control_types import UNION_CONTROLNET_TYPES
from collections import OrderedDict
import comfy.ops
@@ -234,12 +233,12 @@ def __init__(
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
if disable_self_attentions is not None:
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False

if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
if num_attention_blocks is None or nr < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
@@ -434,4 +433,3 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
out_middle.append(self.middle_block_out(h, emb, context))

return {"middle": out_middle, "output": out_output}

16 changes: 2 additions & 14 deletions comfy/gligen.py
@@ -5,19 +5,6 @@
import comfy.ops
ops = comfy.ops.manual_cast

def exists(val):
return val is not None


def uniq(arr):
return{el: True for el in arr}.keys()


def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d


# feedforward
class GEGLU(nn.Module):
@@ -34,7 +21,8 @@ class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
if dim_out is None:
dim_out = dim
project_in = nn.Sequential(
ops.Linear(dim, inner_dim),
nn.GELU()
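
The removed `default(val, d)` helper in gligen.py (and in ldm/util.py below) also accepted a callable fallback via `isfunction`. The call sites touched by this patch only pass plain values, so the explicit `if x is None:` rewrite is behaviour-preserving; a callable default would need the extra branch. A short sketch of both cases, with illustrative values:

```python
from inspect import isfunction

def default(val, d):
    if val is not None:
        return val
    return d() if isfunction(d) else d

dim, dim_out = 320, None  # illustrative values

# Plain-value fallback: identical to `if dim_out is None: dim_out = dim`.
assert default(dim_out, dim) == 320

# Callable fallback: `d` is only evaluated when `val` is None; a bare
# `if val is None: val = d` rewrite would assign the lambda itself instead.
assert default(None, lambda: dim * 4) == 1280
```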
41 changes: 9 additions & 32 deletions comfy/ldm/modules/attention.py
@@ -28,30 +28,6 @@ def get_attn_precision(attn_precision):
return FORCE_UPCAST_ATTENTION_DTYPE
return attn_precision

def exists(val):
return val is not None


def uniq(arr):
return{el: True for el in arr}.keys()


def default(val, d):
if exists(val):
return val
return d


def max_neg_value(t):
return -torch.finfo(t.dtype).max


def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor


# feedforward
class GEGLU(nn.Module):
@@ -68,7 +44,8 @@ class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
if dim_out is None:
dim_out = dim
project_in = nn.Sequential(
operations.Linear(dim, inner_dim, dtype=dtype, device=device),
nn.GELU()
@@ -121,7 +98,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape

del q, k

if exists(mask):
if mask is not None:
if mask.dtype == torch.bool:
mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
max_neg_value = -torch.finfo(sim.dtype).max
@@ -449,7 +426,8 @@ class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
if context_dim is None:
context_dim = query_dim
self.attn_precision = attn_precision

self.heads = heads
@@ -463,7 +441,8 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.

def forward(self, x, context=None, value=None, mask=None):
q = self.to_q(x)
context = default(context, x)
if context is None:
context = x
k = self.to_k(context)
if value is not None:
v = self.to_v(value)
@@ -649,7 +628,7 @@ def __init__(self, in_channels, n_heads, d_head,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
if context_dim is not None and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
@@ -799,7 +778,7 @@ def forward(
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
if context is not None:
spatial_context = context

if self.use_spatial_context:
@@ -861,5 +840,3 @@ def forward(
x = self.proj_out(x)
out = x + x_in
return out
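
The removed `max_neg_value()` helper survives inline in `attention_basic` (see the hunk above): a boolean mask is flattened, broadcast across heads, and turned into an additive bias of the most negative finite value for the score dtype, so masked positions vanish after softmax. A rough, self-contained sketch of that boolean-mask branch with toy shapes (the shapes and head count are assumptions for illustration):

```python
import torch
from einops import rearrange, repeat

heads = 2
sim = torch.randn(2 * heads, 5, 4)        # (batch*heads, q_len, k_len) attention scores
mask = torch.tensor([[True, True, False, False],
                     [True, True, True, False]])   # (batch, k_len), bool

max_neg_value = -torch.finfo(sim.dtype).max        # inlined max_neg_value(sim)
mask = rearrange(mask, 'b ... -> b (...)')         # flatten any extra mask dims
mask = repeat(mask, 'b j -> (b h) () j', h=heads)  # broadcast over heads and queries
sim.masked_fill_(~mask, max_neg_value)             # masked keys get a huge negative bias
attn = sim.softmax(dim=-1)                         # masked positions end up with ~0 weight
```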


13 changes: 4 additions & 9 deletions comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -11,10 +11,6 @@
import comfy.ops
import comfy.ldm.common_dit

def default(x, y):
if x is not None:
return x
return y

class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
@@ -713,14 +709,14 @@ def __init__(
self.learn_sigma = learn_sigma
self.in_channels = in_channels
default_out_channels = in_channels * 2 if learn_sigma else in_channels
self.out_channels = default(out_channels, default_out_channels)
self.out_channels = default_out_channels if out_channels is None else out_channels
self.patch_size = patch_size
self.pos_embed_scaling_factor = pos_embed_scaling_factor
self.pos_embed_offset = pos_embed_offset
self.pos_embed_max_size = pos_embed_max_size

# hidden_size = default(hidden_size, 64 * depth)
# num_heads = default(num_heads, hidden_size // 64)
# hidden_size = 64 * depth if hidden_size is None else hidden_size
# num_heads = hidden_size // 64 if num_heads is None else num_heads

# apply magic --> this defines a head_size of 64
self.hidden_size = 64 * depth
@@ -862,7 +858,7 @@ def forward_core_with_concat(
context = torch.cat(
(
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
default(context, torch.Tensor([]).type_as(x)),
torch.Tensor([]).type_as(x) if context is None else context,
),
1,
)
@@ -932,4 +928,3 @@ def forward(
**kwargs,
) -> torch.Tensor:
return super().forward(x, timesteps, context=context, y=y, control=control)
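
The mmdit.py hunks (and the openaimodel.py hunks below) use the conditional-expression form of the same rewrite, `y if x is None else x`, rather than the statement form used in attention.py and gligen.py; both are equivalent for plain-value defaults. A tiny illustration of the `out_channels` case above, with assumed values:

```python
in_channels, learn_sigma = 16, True   # assumed values for illustration
out_channels = None                   # not specified by the caller

default_out_channels = in_channels * 2 if learn_sigma else in_channels
out_channels = default_out_channels if out_channels is None else out_channels
assert out_channels == 32
```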

15 changes: 7 additions & 8 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -13,8 +13,7 @@
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
from ..attention import SpatialTransformer, SpatialVideoTransformer
import comfy.ops
ops = comfy.ops.disable_weight_init

@@ -301,11 +300,11 @@ def __init__(
)

self.time_stack = ResBlock(
default(out_channels, channels),
channels if out_channels is None else out_channels,
emb_channels,
dropout=dropout,
dims=3,
out_channels=default(out_channels, channels),
out_channels=channels if out_channels is None else out_channels,
use_scale_shift_norm=False,
use_conv=False,
up=False,
@@ -642,12 +641,12 @@ def get_resblock(
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
if disable_self_attentions is not None:
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False

if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
if num_attention_blocks is None or nr < num_attention_blocks[level]:
layers.append(get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
@@ -768,12 +767,12 @@ def get_resblock(
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
if disable_self_attentions is not None:
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False

if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
if num_attention_blocks is None or i < num_attention_blocks[level]:
layers.append(
get_attention_layer(
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
4 changes: 0 additions & 4 deletions comfy/ldm/modules/diffusionmodules/upscaling.py
@@ -4,7 +4,6 @@
from functools import partial

from .util import extract_into_tensor, make_beta_schedule
from comfy.ldm.util import default


class AbstractLowScaleModel(nn.Module):
@@ -80,6 +79,3 @@ def forward(self, x, noise_level=None, seed=None):
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level, seed=seed)
return z, noise_level



12 changes: 1 addition & 11 deletions comfy/ldm/util.py
@@ -44,16 +44,6 @@ def isimage(x):
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
return x is not None


def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d


def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
@@ -194,4 +184,4 @@ def step(self, closure=None):
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

return loss
return loss