@@ -46,32 +46,33 @@ def forward(self, x, cond, mask=None):
         kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
         k, v = kv.unbind(2)
 
-        # TODO: xformers needs separate mask logic here
-        if model_management.xformers_enabled():
-            attn_bias = None
-            if mask is not None:
-                attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
-            x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
-        else:
-            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
-            attn_mask = None
-            if mask is not None and len(mask) > 1:
-                # Create equivalent of xformer diagonal block mask, still only correct for square masks
-                # But depth doesn't matter as tensors can expand in that dimension
-                attn_mask_template = torch.ones(
-                    [q.shape[2] // B, mask[0]],
-                    dtype=torch.bool,
-                    device=q.device
-                )
-                attn_mask = torch.block_diag(attn_mask_template)
-
-                # create a mask on the diagonal for each mask in the batch
-                for _ in range(B - 1):
-                    attn_mask = torch.block_diag(attn_mask, attn_mask_template)
-
-            x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
-
-        x = x.view(B, -1, C)
+        assert mask is None  # TODO?
+        # # TODO: xformers needs separate mask logic here
+        # if model_management.xformers_enabled():
+        #     attn_bias = None
+        #     if mask is not None:
+        #         attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
+        #     x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
+        # else:
+        #     q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
+        #     attn_mask = None
+        #     mask = torch.ones(())
+        #     if mask is not None and len(mask) > 1:
+        #         # Create equivalent of xformer diagonal block mask, still only correct for square masks
+        #         # But depth doesn't matter as tensors can expand in that dimension
+        #         attn_mask_template = torch.ones(
+        #             [q.shape[2] // B, mask[0]],
+        #             dtype=torch.bool,
+        #             device=q.device
+        #         )
+        #         attn_mask = torch.block_diag(attn_mask_template)
+        #
+        #         # create a mask on the diagonal for each mask in the batch
+        #         for _ in range(B - 1):
+        #             attn_mask = torch.block_diag(attn_mask, attn_mask_template)
+        #     x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
+
+        x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -155,9 +156,9 @@ def forward(self, x, mask=None, HW=None, block_id=None):
             k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
             v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
 
-        q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
-        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
-        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
+        q = q.reshape(B, N, self.num_heads, C // self.num_heads)
+        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
+        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
 
         if mask is not None:
             raise NotImplementedError("Attn mask logic not added for self attention")
@@ -209,9 +210,9 @@ def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=Non
 
     def forward(self, x, t):
         dtype = x.dtype
-        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
         x = t2i_modulate(self.norm_final(x), shift, scale)
-        x = self.linear(x.to(dtype))
+        x = self.linear(x)
         return x
 
 
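
Note on the first hunk: the disabled fallback path built a block-diagonal attention mask with torch.block_diag, one all-ones block per batch element, so each image's queries could only attend to that image's condition tokens. A minimal standalone sketch of that construction, using hypothetical sizes (not part of the commit; the real code derives them from q.shape and mask):

import torch

# Hypothetical sizes: 3 images in the batch, 4 query tokens per image,
# 5 condition tokens per image.
B, tokens_per_image, cond_len = 3, 4, 5

# One all-ones block per image; block_diag stacks the blocks on the diagonal,
# so positions that would let one image attend to another image's condition
# tokens stay False (masked out).
block = torch.ones([tokens_per_image, cond_len], dtype=torch.bool)
attn_mask = block
for _ in range(B - 1):
    attn_mask = torch.block_diag(attn_mask, block)

print(attn_mask.shape)  # torch.Size([12, 15]) == (B * tokens_per_image, B * cond_len)

The replacement path instead asserts that mask is None and reshapes q, k, and v back to (B, -1, C) so optimized_attention runs without any mask.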