-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvit.py
359 lines (303 loc) · 10.6 KB
/
vit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
from typing import List, Tuple, Optional
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
from jaxtyping import Key, Array, Float
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
class MultiHeadAttentionLayer(eqx.Module):
    """Multi-head scaled dot-product attention.

    Query, key and value sequences are projected by separate linear layers,
    split into ``n_heads`` heads of width ``head_dim``, attended per head, then
    merged back and passed through an output projection.
    """

    embed_dim: int
    n_heads: int
    head_dim: int  # embed_dim // n_heads
    _q: eqx.nn.Linear  # Query projection
    _k: eqx.nn.Linear  # Key projection
    _v: eqx.nn.Linear  # Value projection
    _o: eqx.nn.Linear  # Output projection applied after merging heads
    dropout: eqx.nn.Dropout  # Applied to the attention weights

    def __init__(
        self,
        embed_dim: int,
        n_heads: int,
        dropout_rate: float,
        *,
        key: Key
    ):
        """
        Args:
            embed_dim: Token embedding size; must be divisible by ``n_heads``.
            n_heads: Number of attention heads (positive).
            dropout_rate: Dropout probability applied to the attention weights.
            key: PRNG key used to initialise the four projections.

        Raises:
            ValueError: If ``n_heads`` is not positive or does not divide
                ``embed_dim`` (the per-head split in ``__call__`` needs it).
        """
        if n_heads <= 0 or embed_dim % n_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"n_heads ({n_heads})"
            )
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        keys = jr.split(key, 4)
        self._q = eqx.nn.Linear(embed_dim, embed_dim, key=keys[0])
        self._k = eqx.nn.Linear(embed_dim, embed_dim, key=keys[1])
        self._v = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
        self._o = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        query: Float[Array, "q self.embed_dim"],
        _key: Float[Array, "k self.embed_dim"],
        value: Float[Array, "k self.embed_dim"],
        *,
        key: Optional[Key] = None
    ) -> Tuple[Array, Array]:
        """Attend from ``query`` over ``_key``/``value``.

        Args:
            query: Query tokens, ``[query_len, embed_dim]``.
            _key: Key tokens, ``[key_len, embed_dim]`` (named ``_key`` so it
                does not clash with the PRNG ``key`` argument).
            value: Value tokens, ``[key_len, embed_dim]``.
            key: PRNG key forwarded to the attention-weight dropout.
                NOTE(review): with a non-zero dropout rate Equinox's Dropout
                presumably needs a key unless inference mode is enabled — confirm.

        Returns:
            ``(output, attention)``: ``output`` is ``[query_len, embed_dim]``;
            ``attention`` holds the softmax weights
            ``[n_heads, query_len, key_len]`` *before* dropout.
        """
        # Project every token independently.
        Q = jax.vmap(self._q)(query)  # [query_len, embed_dim]
        K = jax.vmap(self._k)(_key)  # [key_len, embed_dim]
        V = jax.vmap(self._v)(value)  # [value_len, embed_dim]
        # Split embed_dim = n_heads * head_dim into separate heads.
        Q = rearrange(Q, 'l (h d) -> h l d', h=self.n_heads)  # [n_heads, query_len, head_dim]
        K = rearrange(K, 'l (h d) -> h l d', h=self.n_heads)  # [n_heads, key_len, head_dim]
        V = rearrange(V, 'l (h d) -> h l d', h=self.n_heads)  # [n_heads, value_len, head_dim]
        # Scaled dot-product attention (logits divided by sqrt(head_dim)).
        weight = Q @ rearrange(K, 'h l d -> h d l') / jnp.sqrt(self.head_dim)  # [n_heads, query_len, key_len]
        attention = jax.nn.softmax(weight, axis=-1)  # [n_heads, query_len, key_len]
        # Attention-weighted sum of values per head; dropout hits the weights only.
        c = self.dropout(attention, key=key) @ V  # [n_heads, query_len, head_dim]
        # Merge heads back into one embedding per query token.
        c = rearrange(c, 'h l d -> l (h d)')  # [query_len, embed_dim]
        output = jax.vmap(self._o)(c)  # [query_len, embed_dim]
        return output, attention
class TokenMLP(eqx.Module):
    """Per-token feed-forward network: Linear -> GELU -> Dropout -> Linear.

    The same two linear layers are applied to every token of the sequence.
    """

    linear1: eqx.nn.Linear  # embed_dim -> hidden_dim
    linear2: eqx.nn.Linear  # hidden_dim -> embed_dim
    dropout: eqx.nn.Dropout  # Applied to the hidden (post-GELU) activations

    def __init__(
        self,
        embed_dim: int,
        hidden_dim: int,
        dropout_rate: float,
        *,
        key: Key
    ):
        """
        Args:
            embed_dim: Input and output token size.
            hidden_dim: Width of the hidden layer.
            dropout_rate: Dropout probability for the hidden activations.
            key: PRNG key used to initialise both linear layers.
        """
        keys = jr.split(key)
        self.linear1 = eqx.nn.Linear(embed_dim, hidden_dim, key=keys[0])
        self.linear2 = eqx.nn.Linear(hidden_dim, embed_dim, key=keys[1])
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(self, x: Array, key: Optional[Key] = None) -> Array:
        """Apply the MLP to each token.

        Args:
            x: Token sequence, ``[seq_len, embed_dim]``.
            key: PRNG key for the hidden-layer dropout.

        Returns:
            Array of shape ``[seq_len, embed_dim]``.
        """
        x = jax.nn.gelu(jax.vmap(self.linear1)(x))  # [seq_len, hidden_dim]
        x = self.dropout(x, key=key)
        x = jax.vmap(self.linear2)(x)  # [seq_len, embed_dim]
        return x
class EncoderLayer(eqx.Module):
    """Pre-norm transformer encoder block.

    Computes ``x = x + Dropout(MHA(LN1(x)))`` and then
    ``x = x + Dropout(MLP(LN2(x)))``.
    """

    embed_dim: int
    layernorm1: eqx.nn.LayerNorm  # Applied before self-attention
    layernorm2: eqx.nn.LayerNorm  # Applied before the MLP
    multihead_attention_layer: MultiHeadAttentionLayer
    token_mlp: TokenMLP
    dropout: eqx.nn.Dropout  # Applied to each sub-block output before the residual add

    def __init__(
        self,
        embed_dim: int,
        n_heads: int,
        hidden_dim: int,
        dropout_rate: float,
        *,
        key: Key
    ):
        """
        Args:
            embed_dim: Token embedding size.
            n_heads: Number of attention heads.
            hidden_dim: Hidden width of the token MLP.
            dropout_rate: Dropout probability used throughout the block.
            key: PRNG key used to initialise attention and MLP weights.
        """
        self.embed_dim = embed_dim
        keys = jr.split(key)
        self.layernorm1 = eqx.nn.LayerNorm(embed_dim)
        self.layernorm2 = eqx.nn.LayerNorm(embed_dim)
        self.multihead_attention_layer = MultiHeadAttentionLayer(
            embed_dim, n_heads, dropout_rate, key=keys[0]
        )
        self.token_mlp = TokenMLP(
            embed_dim, hidden_dim, dropout_rate, key=keys[1]
        )
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        x: Float[Array, "s self.embed_dim"],
        *,
        key: Optional[Key] = None
    ) -> Tuple[Float[Array, "s self.embed_dim"], Float[Array, "h s s"]]:
        """Run one encoder block.

        Args:
            x: Token sequence, ``[seq_len, embed_dim]``.
            key: PRNG key; split into four subkeys for the stochastic ops.
                When ``None`` every op receives ``None``.

        Returns:
            ``(x, attention)``: ``x`` is ``[seq_len, embed_dim]``. ``attention``
            holds the self-attention weights ``[n_heads, seq_len, seq_len]``.
        """
        # One subkey per stochastic op, drawn in the same order as before.
        if key is None:
            attn_key = attn_drop_key = mlp_key = mlp_drop_key = None
        else:
            attn_key, attn_drop_key, mlp_key, mlp_drop_key = jr.split(key, 4)

        # Self-attention sub-block (query = key = value = normalised x).
        _x = jax.vmap(self.layernorm1)(x)
        _x, attention = self.multihead_attention_layer(_x, _x, _x, key=attn_key)
        x = x + self.dropout(_x, key=attn_drop_key)

        # Feed-forward sub-block.
        _x = jax.vmap(self.layernorm2)(x)
        _x = self.token_mlp(_x, key=mlp_key)
        x = x + self.dropout(_x, key=mlp_drop_key)  # [seq_len, embed_dim]
        return x, attention
class Encoder(eqx.Module):
    """A stack of ``n_layers`` EncoderLayer blocks applied one after another."""

    embed_dim: int
    layers: list[EncoderLayer]

    def __init__(
        self,
        embed_dim: int,
        n_layers: int,
        n_heads: int,
        hidden_dim: int,
        dropout_rate: float,
        *,
        key: Key
    ):
        """
        Args:
            embed_dim: Token embedding size.
            n_layers: Number of stacked encoder blocks.
            n_heads: Attention heads per block.
            hidden_dim: MLP hidden width per block.
            dropout_rate: Dropout probability per block.
            key: PRNG key, split into one initialisation key per block.
        """
        self.embed_dim = embed_dim
        init_keys = jr.split(key, n_layers)
        self.layers = [
            EncoderLayer(embed_dim, n_heads, hidden_dim, dropout_rate, key=k)
            for k in init_keys
        ]

    def __call__(
        self,
        x: Float[Array, "s self.embed_dim"],
        key: Optional[Key] = None
    ) -> Tuple[Float[Array, "s self.embed_dim"], List[Float[Array, "s a"]]]:
        """Pass ``x`` (``[seq_len, embed_dim]``) through every block.

        Block ``i`` gets ``fold_in(key, i)`` as its PRNG key, or ``None`` when
        no key is given.

        Returns:
            ``(x, attentions)``: the final ``[seq_len, embed_dim]`` tokens and
            one attention-weight array per block, in block order.
        """
        depth = len(self.layers)
        if key is None:
            per_layer_keys = [None] * depth
        else:
            per_layer_keys = [jr.fold_in(key, idx) for idx in range(depth)]

        attention_maps = []
        for block, block_key in zip(self.layers, per_layer_keys):
            x, attn = block(x, key=block_key)
            attention_maps.append(attn)
        return x, attention_maps
class ImageEmbedding(eqx.Module):
    """Split an image into square patches, embed them, and prepend a class token."""

    patch_size: int
    linear: eqx.nn.Linear  # (patch_size * patch_size * channel) -> embed_dim
    cls_token: Array  # Learnable class token, shape [1, embed_dim]

    def __init__(
        self,
        channel: int,
        patch_size: int,
        embed_dim: int,
        *,
        key: Key
    ):
        """
        Args:
            channel: Number of image channels.
            patch_size: Side length of each square patch.
            embed_dim: Output embedding size per patch.
            key: PRNG key for the patch projection and class-token init.
        """
        keys = jr.split(key)
        self.patch_size = patch_size
        # [patch, patch_size * patch_size * channel] -> [patch, embed_dim]
        self.linear = eqx.nn.Linear(
            channel * patch_size * patch_size, embed_dim, key=keys[0]
        )
        # Class token, initialised from a standard normal.
        self.cls_token = jr.normal(keys[1], (1, embed_dim))

    def __call__(self, x: Float[Array, "c h w"]) -> Float[Array, "s e"]:
        """Embed ``x`` as ``[cls_token, patch_1, ..., patch_N]``.

        Args:
            x: Channels-first image; both spatial dims must be multiples of
                ``patch_size`` (einops raises otherwise).

        Returns:
            Array of shape ``[1 + num_patches, embed_dim]`` whose first row is
            the class token.
        """
        c = x.shape[0]
        # Flatten each patch to a vector ordered (row-in-patch, col-in-patch, channel).
        flatten_patches = rearrange(
            x,
            'c (n_w p1) (n_h p2) -> (n_w n_h) (p1 p2 c) ',
            c=c,
            p1=self.patch_size,
            p2=self.patch_size
        )  # [patch, patch_size * patch_size * channel]
        embedded_patches = jax.vmap(self.linear)(flatten_patches)  # [patch, embed_dim]
        # Prepend the class token; the classifier head reads row 0 later.
        embedded_patches = jnp.concatenate([self.cls_token, embedded_patches])
        return embedded_patches  # [1 + patch, embed_dim]
class PatchEmbedding(eqx.Module):
    """Split a channels-first image into square patches and project each one.

    Unlike ImageEmbedding, no class token is added, and each patch vector is
    flattened in (channel, row, col) order.
    """

    linear: eqx.nn.Linear  # (patch_size**2 * input_channels) -> output_shape
    patch_size: int

    def __init__(
        self,
        input_channels: int,
        output_shape: int,
        patch_size: int,
        key: Key,
    ):
        """
        Args:
            input_channels: Number of image channels.
            output_shape: Embedding size per patch.
            patch_size: Side length of each square patch.
            key: PRNG key for the projection.
        """
        self.patch_size = patch_size
        self.linear = eqx.nn.Linear(
            self.patch_size ** 2 * input_channels,
            output_shape,
            key=key,
        )

    def __call__(self, x: Array) -> Array:
        """Map ``x`` of shape ``[c, H, W]`` to ``[num_patches, output_shape]``.

        ``H`` and ``W`` must be multiples of ``patch_size`` (einops raises
        otherwise).
        """
        x = rearrange(
            x,
            "c (h ph) (w pw) -> (h w) (c ph pw)",
            ph=self.patch_size,
            pw=self.patch_size,
        )  # [num_patches, c * patch_size * patch_size]
        x = jax.vmap(self.linear)(x)
        return x
class TokenPositionalEmbedding(eqx.Module):
    """Patch + class-token embedding plus a learned absolute position embedding."""

    embed_dim: int
    token_embedding: ImageEmbedding
    position_embedding: eqx.nn.Embedding  # One learned vector per position
    dropout: eqx.nn.Dropout  # Applied after adding position embeddings

    def __init__(
        self,
        c: int,
        p: int,
        embed_dim: int,
        dropout_rate: float,
        *,
        key: Key,
        max_len: int = 100
    ):
        """
        Args:
            c: Number of image channels.
            p: Patch size (square patches).
            embed_dim: Token embedding size.
            dropout_rate: Dropout probability applied to the summed embedding.
            key: PRNG key.
            max_len: Number of learned positions, i.e. the largest supported
                sequence length ``1 + num_patches``. The default of 100 keeps
                the previously hard-coded size. For an ``H x W`` image use
                ``(H // p) * (W // p) + 1``.
        """
        self.embed_dim = embed_dim
        keys = jr.split(key)
        self.token_embedding = ImageEmbedding(c, p, embed_dim, key=keys[0])
        self.position_embedding = eqx.nn.Embedding(max_len, embed_dim, key=keys[1])
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        x: Float[Array, "c h w"],
        *,
        key: Optional[Key] = None
    ) -> Float[Array, "s self.embed_dim"]:
        """Embed a channels-first image into ``[1 + num_patches, embed_dim]``.

        Raises:
            ValueError: If the token sequence is longer than the position
                table. JAX indexing does not raise on out-of-range indices, so
                without this check extra tokens would silently get wrong
                position embeddings.
        """
        x = self.token_embedding(x)  # [seq_len, embed_dim]
        seq_len = x.shape[0]
        max_len = self.position_embedding.weight.shape[0]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the {max_len} learned "
                f"positions; construct with a larger max_len"
            )
        pos = jnp.arange(0, seq_len)  # [seq_len]
        # NOTE(review): scaling token embeddings by sqrt(embed_dim) follows the
        # original Transformer recipe rather than standard ViT — kept for
        # compatibility; confirm intended.
        x = x * jnp.sqrt(self.embed_dim) + jax.vmap(self.position_embedding)(pos)
        x = self.dropout(x, key=key)
        return x  # [seq_len, embed_dim]
class VisionTransformer(eqx.Module):
    """ViT classifier: patch embedding -> transformer encoder -> linear head.

    The head reads the first token (the class token), layer-normalises it,
    and projects it to ``output_dim`` logits.
    """

    embedding: TokenPositionalEmbedding
    encoder: Encoder
    layernorm: eqx.nn.LayerNorm  # Applied to the class token before the head
    linear: eqx.nn.Linear  # embed_dim -> output_dim

    def __init__(
        self,
        c,
        p,
        embed_dim,
        n_layers,
        n_heads,
        hidden_dim,
        dropout_rate,
        output_dim,
        *,
        key
    ):
        """
        Args:
            c: Number of image channels.
            p: Patch size.
            embed_dim: Token embedding size.
            n_layers: Number of encoder blocks.
            n_heads: Attention heads per block.
            hidden_dim: MLP hidden width per block.
            dropout_rate: Dropout probability used throughout.
            output_dim: Size of the classifier output.
            key: PRNG key for parameter initialisation.
        """
        embed_key, encoder_key, head_key = jr.split(key, 3)
        self.embedding = TokenPositionalEmbedding(
            c, p, embed_dim, dropout_rate, key=embed_key
        )
        self.encoder = Encoder(
            embed_dim, n_layers, n_heads, hidden_dim, dropout_rate, key=encoder_key
        )
        self.layernorm = eqx.nn.LayerNorm(embed_dim)
        self.linear = eqx.nn.Linear(embed_dim, output_dim, key=head_key)

    def __call__(
        self,
        x: Float[Array, "c h w"],
        *,
        key: Optional[Key] = None
    ) -> Tuple[Float[Array, "o"], List[Float[Array, "s a"]]]:
        """Classify one channels-first image.

        Args:
            x: Image array.
            key: PRNG key for dropout; ``None`` passes ``None`` everywhere.

        Returns:
            ``(logits, attentions)``: ``logits`` has shape ``[output_dim]``;
            ``attentions`` has one attention-weight array per encoder block.
        """
        if key is None:
            embed_key = encoder_key = None
        else:
            embed_key, encoder_key = jr.split(key)

        tokens = self.embedding(x, key=embed_key)  # [seq_len, embed_dim]
        encoded, attentions = self.encoder(tokens, key=encoder_key)
        cls_repr = self.layernorm(encoded[0])  # [embed_dim]
        return self.linear(cls_repr), attentions