Skip to content

Commit 7dbd273

Browse files
committed
feat: some architecture flexibilities added to uvit
1 parent 6b9b4a4 commit 7dbd273

7 files changed

Lines changed: 914 additions & 13 deletions

File tree

.github/workflows/python-publish.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
name: Upload Python Package
1010

1111
on:
12+
push:
13+
branches: [ "main" ]
1214
release:
1315
types: [published]
1416

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ datacache
1818
gcsfuse.yml
1919
*.csv
2020
*.tsv
21-
*.parquet
21+
*.parquet
22+
*.arrow

datasets/dataset preparations.ipynb

Lines changed: 212 additions & 2 deletions
Large diffs are not rendered by default.

datasets/datasets/laion2B-en-aesthetic-4.2_37M/dataset_info.json

Lines changed: 602 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
{
2+
"_data_files": [
3+
{
4+
"filename": "data-00000-of-00017.arrow"
5+
},
6+
{
7+
"filename": "data-00001-of-00017.arrow"
8+
},
9+
{
10+
"filename": "data-00002-of-00017.arrow"
11+
},
12+
{
13+
"filename": "data-00003-of-00017.arrow"
14+
},
15+
{
16+
"filename": "data-00004-of-00017.arrow"
17+
},
18+
{
19+
"filename": "data-00005-of-00017.arrow"
20+
},
21+
{
22+
"filename": "data-00006-of-00017.arrow"
23+
},
24+
{
25+
"filename": "data-00007-of-00017.arrow"
26+
},
27+
{
28+
"filename": "data-00008-of-00017.arrow"
29+
},
30+
{
31+
"filename": "data-00009-of-00017.arrow"
32+
},
33+
{
34+
"filename": "data-00010-of-00017.arrow"
35+
},
36+
{
37+
"filename": "data-00011-of-00017.arrow"
38+
},
39+
{
40+
"filename": "data-00012-of-00017.arrow"
41+
},
42+
{
43+
"filename": "data-00013-of-00017.arrow"
44+
},
45+
{
46+
"filename": "data-00014-of-00017.arrow"
47+
},
48+
{
49+
"filename": "data-00015-of-00017.arrow"
50+
},
51+
{
52+
"filename": "data-00016-of-00017.arrow"
53+
}
54+
],
55+
"_fingerprint": "9e2180a190e4d3ae",
56+
"_format_columns": null,
57+
"_format_kwargs": {},
58+
"_format_type": null,
59+
"_output_all_columns": false,
60+
"_split": "train"
61+
}

flaxdiff/models/simple_vit.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Callable, Any, Optional, Tuple
77
from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
88
from .attention import TransformerBlock
9-
from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
9+
from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init, ResidualBlock
1010
import einops
1111
from flax.typing import Dtype, PrecisionLike
1212
from functools import partial
@@ -68,6 +68,7 @@ class UViT(nn.Module):
6868
dtype: Optional[Dtype] = None
6969
precision: PrecisionLike = None
7070
kernel_init: Callable = partial(kernel_init, 1.0)
71+
add_residualblock_output: bool = False
7172

7273
def setup(self):
7374
if self.norm_groups > 0:
@@ -80,6 +81,8 @@ def __call__(self, x, temb, textcontext=None):
8081
# Time embedding
8182
temb = FourierEmbedding(features=self.emb_features)(temb)
8283
temb = TimeProjection(features=self.emb_features)(temb)
84+
85+
original_img = x
8386

8487
# Patch embedding
8588
x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
@@ -141,14 +144,36 @@ def __call__(self, x, temb, textcontext=None):
141144
x = x[:, 1 + num_text_tokens:, :]
142145
x = unpatchify(x, channels=self.output_channels)
143146
# print(f'Shape of x after final dense layer: {x.shape}')
144-
x = nn.Conv(
145-
features=self.output_channels,
146-
kernel_size=(3, 3),
147+
148+
if self.add_residualblock_output:
149+
# Concatenate the original image
150+
x = jnp.concatenate([original_img, x], axis=-1)
151+
152+
x = ResidualBlock(
153+
"conv",
154+
name="final_residual",
155+
features=64,
156+
kernel_init=self.kernel_init(1.0),
157+
kernel_size=(3,3),
158+
strides=(1, 1),
159+
activation=self.activation,
160+
norm_groups=self.norm_groups,
161+
dtype=self.dtype,
162+
precision=self.precision,
163+
named_norms=False
164+
)(x, temb)
165+
166+
x = self.norm()(x)
167+
x = self.activation(x)
168+
169+
x = ConvLayer(
170+
"conv",
171+
features=self.output_channels,
172+
kernel_size=(3, 3),
147173
strides=(1, 1),
148-
padding='SAME',
149-
dtype=self.dtype,
150-
precision=self.precision,
151-
kernel_init=kernel_init(0.0),
174+
# activation=jax.nn.mish
175+
kernel_init=self.kernel_init(0.0),
176+
dtype=self.dtype,
177+
precision=self.precision
152178
)(x)
153-
154179
return x

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
setup(
1212
name='flaxdiff',
1313
packages=find_packages(),
14-
version='0.1.27',
14+
version='0.1.28',
1515
description='A versatile and easy to understand Diffusion library',
1616
long_description=open('README.md').read(),
1717
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)