Commit e946667

Some fixes/cleanups to pixart code.
Commented out the masking-related code because it is never used in this implementation.
1 parent d7969cb commit e946667

File tree

comfy/ldm/pixart/blocks.py
comfy/ldm/pixart/pixart.py
comfy/ldm/pixart/pixartms.py
comfy/model_base.py

4 files changed: +43 -43 lines changed

comfy/ldm/pixart/blocks.py

Lines changed: 32 additions & 31 deletions
@@ -46,32 +46,33 @@ def forward(self, x, cond, mask=None):
         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
         k, v = kv.unbind(2)

-        # TODO: xformers needs separate mask logic here
-        if model_management.xformers_enabled():
-            attn_bias = None
-            if mask is not None:
-                attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
-            x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
-        else:
-            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
-            attn_mask = None
-            if mask is not None and len(mask) > 1:
-                # Create equivalent of xformer diagonal block mask, still only correct for square masks
-                # But depth doesn't matter as tensors can expand in that dimension
-                attn_mask_template = torch.ones(
-                    [q.shape[2] // B, mask[0]],
-                    dtype=torch.bool,
-                    device=q.device
-                )
-                attn_mask = torch.block_diag(attn_mask_template)
-
-                # create a mask on the diagonal for each mask in the batch
-                for _ in range(B - 1):
-                    attn_mask = torch.block_diag(attn_mask, attn_mask_template)
-
-            x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
-
-        x = x.view(B, -1, C)
+        assert mask is None  # TODO?
+        # # TODO: xformers needs separate mask logic here
+        # if model_management.xformers_enabled():
+        #     attn_bias = None
+        #     if mask is not None:
+        #         attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
+        #     x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
+        # else:
+        #     q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
+        #     attn_mask = None
+        #     mask = torch.ones(())
+        #     if mask is not None and len(mask) > 1:
+        #         # Create equivalent of xformer diagonal block mask, still only correct for square masks
+        #         # But depth doesn't matter as tensors can expand in that dimension
+        #         attn_mask_template = torch.ones(
+        #             [q.shape[2] // B, mask[0]],
+        #             dtype=torch.bool,
+        #             device=q.device
+        #         )
+        #         attn_mask = torch.block_diag(attn_mask_template)
+        #
+        #         # create a mask on the diagonal for each mask in the batch
+        #         for _ in range(B - 1):
+        #             attn_mask = torch.block_diag(attn_mask, attn_mask_template)
+        #     x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
+
+        x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -155,9 +156,9 @@ def forward(self, x, mask=None, HW=None, block_id=None):
             k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
             v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)

-        q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
-        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
-        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
+        q = q.reshape(B, N, self.num_heads, C // self.num_heads)
+        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
+        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)

         if mask is not None:
             raise NotImplementedError("Attn mask logic not added for self attention")
@@ -209,9 +210,9 @@ def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None

     def forward(self, x, t):
         dtype = x.dtype
-        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
         x = t2i_modulate(self.norm_final(x), shift, scale)
-        x = self.linear(x.to(dtype))
+        x = self.linear(x)
         return x

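Net effect of the cross-attention change above, as a standalone sketch: q/k/v arrive flattened as (1, B*L, num_heads, head_dim), get re-viewed to (B, L, C), and go through a single unmasked attention call. torch.nn.functional.scaled_dot_product_attention stands in for ComfyUI's optimized_attention here, and all function/variable names and shapes are illustrative, not the file's actual code.

import torch
import torch.nn.functional as F

def cross_attention_sketch(q, k, v, num_heads, B, C):
    # mirror the new call: flatten (1, B*L, heads, head_dim) -> (B, L, C)
    head_dim = C // num_heads
    q, k, v = (t.reshape(B, -1, C) for t in (q, k, v))
    # split heads: (B, L, C) -> (B, heads, L, head_dim)
    q, k, v = (t.view(B, -1, num_heads, head_dim).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)  # no attention mask, matching `assert mask is None`
    return out.transpose(1, 2).reshape(B, -1, C)

B, N, M, num_heads, head_dim = 2, 16, 8, 4, 8
C = num_heads * head_dim
q = torch.randn(1, B * N, num_heads, head_dim)   # image tokens
k = torch.randn(1, B * M, num_heads, head_dim)   # condition tokens
v = torch.randn(1, B * M, num_heads, head_dim)
print(cross_attention_sketch(q, k, v, num_heads, B, C).shape)  # torch.Size([2, 16, 32])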
comfy/ldm/pixart/pixart.py

Lines changed: 5 additions & 10 deletions
@@ -127,12 +127,8 @@ def forward_raw(self, x, t, y, mask=None, data_info=None):
         t: (N,) tensor of diffusion timesteps
         y: (N, 1, 120, C) tensor of class labels
         """
-        x = x.to(self.dtype)
-        timestep = t.to(self.dtype)
-        y = y.to(self.dtype)
-        pos_embed = self.pos_embed.to(self.dtype)
         x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
-        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
+        t = self.t_embedder(timestep)  # (N, D)
         t0 = self.t_block(t)
         y = self.y_embedder(y, self.training)  # (N, 1, L, D)
         if mask is not None:
@@ -142,7 +138,7 @@ def forward_raw(self, x, t, y, mask=None, data_info=None):
             y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
             y_lens = mask.sum(dim=1).tolist()
         else:
-            y_lens = [y.shape[2]] * y.shape[0]
+            y_lens = None
             y = y.squeeze(1).view(1, -1, x.shape[-1])
         for block in self.blocks:
             x = block(x, y, t0, y_lens)  # (N, T, D)
@@ -164,13 +160,12 @@ def forward(self, x, timesteps, context, y=None, **kwargs):

         ## run original forward pass
         out = self.forward_raw(
-            x = x.to(self.dtype),
-            t = timesteps.to(self.dtype),
-            y = context.to(self.dtype),
+            x = x,
+            t = timesteps,
+            y = context,
         )

         ## only return EPS
-        out = out.to(torch.float)
         eps, _ = out[:, :self.in_channels], out[:, self.in_channels:]
         return eps

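The pixart.py changes drop the explicit .to(self.dtype)/.to(torch.float) casts (dtype handling is left to the caller and the embedders) and pass y_lens = None when there is no mask. A small sketch of what the y_lens part amounts to, using the shapes from the docstring above (illustrative values only):

import torch

B, L, C = 2, 120, 1152          # batch, text tokens, hidden size (illustrative)
y = torch.randn(B, 1, L, C)     # y: (N, 1, L, C) conditioning, as in the docstring

# old unmasked branch: one length per batch element, used to build a block-diagonal mask
y_lens_old = [y.shape[2]] * y.shape[0]     # [120, 120]

# new unmasked branch: no lengths at all, since the cross-attention now asserts mask is None
y_lens_new = None

# either way the conditioning is flattened into one long sequence for the blocks
y_flat = y.squeeze(1).view(1, -1, C)       # (1, B*L, C)
print(y_lens_old, y_lens_new, y_flat.shape)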
comfy/ldm/pixart/pixartms.py

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size
     def forward(self, x, y, t, mask=None, HW=None, **kwargs):
         B, N, C = x.shape

-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(x.dtype) + t.reshape(B, 6, -1)).chunk(6, dim=1)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
         x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
         x = x + self.cross_attn(x, y, mask)
         x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
@@ -196,7 +196,7 @@ def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs
             y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
             y_lens = mask.sum(dim=1).tolist()
         else:
-            y_lens = [y.shape[2]] * y.shape[0]
+            y_lens = None
             y = y.squeeze(1).view(1, -1, x.shape[-1])
         for block in self.blocks:
             x = block(x, y, t0, y_lens, (H, W), **kwargs)  # (N, T, D)

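The pixartms.py change only widens the cast on scale_shift_table from .to(x.dtype) to .to(dtype=x.dtype, device=x.device), so the table also follows the activations' device (relevant when parameters sit on a different device than the compute). A minimal sketch with standalone tensors (hypothetical sizes, not the module itself):

import torch

hidden = 8
scale_shift_table = torch.randn(6, hidden)            # parameter, e.g. fp32 on its own device
x = torch.randn(2, 16, hidden, dtype=torch.float16)   # activations
t = torch.randn(2, 6 * hidden, dtype=torch.float16)   # timestep embedding projected to 6*hidden

# match both dtype and device of x before mixing with t
table = scale_shift_table[None].to(dtype=x.dtype, device=x.device)   # (1, 6, hidden)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
    table + t.reshape(2, 6, -1)
).chunk(6, dim=1)
print(shift_msa.shape, shift_msa.dtype)   # torch.Size([2, 1, 8]) torch.float16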
comfy/model_base.py

Lines changed: 4 additions & 0 deletions
@@ -726,6 +726,10 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None):
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)

+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
         width = kwargs.get("width", None)
         height = kwargs.get("height", None)
         if width is not None and height is not None:

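The model_base.py hunk registers the text-encoder output as c_crossattn for the PixArt model class. A hedged sketch of that hook in isolation (the class name below is illustrative, and comfy.conds is ComfyUI's own module, so this assumes a ComfyUI environment):

import torch
import comfy.conds

class PixArtExtraCondsSketch:
    # simplified stand-in reproducing only the new cross_attn branch
    def extra_conds(self, **kwargs):
        out = {}
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            # CONDRegular wraps the conditioning tensor so the sampler can batch it
            # and hand it to the diffusion model as cross-attention context
            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
        return out

conds = PixArtExtraCondsSketch().extra_conds(cross_attn=torch.randn(1, 120, 4096))
print(list(conds.keys()))   # ['c_crossattn']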