|
12 | 12 | T2IFinalLayer, |
13 | 13 | SizeEmbedder, |
14 | 14 | ) |
15 | | -from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp |
16 | | -from .pixart import PixArt, get_2d_sincos_pos_embed_torch |
| 15 | +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch |
17 | 16 |
|
18 | 17 |
|
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
    """Build a 2D sine/cosine positional embedding for an h-by-w grid.

    Each axis is normalized so a `base_size`-sized grid spans the canonical
    range, then scaled down by `pe_interpolation` (used to interpolate
    pretrained position embeddings to other resolutions). Half of
    `embed_dim` encodes the x (width) coordinate and half the y (height)
    coordinate; the halves are concatenated x-first.

    Returns a tensor of shape (h*w, embed_dim).
    """
    half = embed_dim // 2
    # Normalized per-axis coordinates. Divisions are kept sequential
    # (matching the original) to preserve exact float behavior.
    ys = torch.arange(h, device=device, dtype=dtype) / (h / base_size) / pe_interpolation
    xs = torch.arange(w, device=device, dtype=dtype) / (w / base_size) / pe_interpolation
    pos_y, pos_x = torch.meshgrid(ys, xs, indexing='ij')
    emb_y = get_1d_sincos_pos_embed_from_grid_torch(half, pos_y, device=device, dtype=dtype)
    emb_x = get_1d_sincos_pos_embed_from_grid_torch(half, pos_x, device=device, dtype=dtype)
    # x-embedding first, then y-embedding  -> (H*W, D)
    return torch.cat([emb_x, emb_y], dim=1)
| 28 | + |
19 | 29 | class PixArtMSBlock(nn.Module): |
20 | 30 | """ |
21 | 31 | A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning. |
@@ -53,7 +63,7 @@ def forward(self, x, y, t, mask=None, HW=None, **kwargs): |
53 | 63 |
|
54 | 64 |
|
55 | 65 | ### Core PixArt Model ### |
56 | | -class PixArtMS(PixArt): |
| 66 | +class PixArtMS(nn.Module): |
57 | 67 | """ |
58 | 68 | Diffusion model with a Transformer backbone. |
59 | 69 | """ |
|
0 commit comments