Skip to content

Commit 9d92545

Browse files
committed
cleanup
1 parent 50fe8cd commit 9d92545

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

lambda_networks/lambda_networks.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,20 @@ def forward(self, x):
6060
v = self.norm_v(v)
6161

6262
q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
63-
k = rearrange(k, 'b (k u) hh ww -> b k u (hh ww)', u = u)
64-
v = rearrange(v, 'b (v u) hh ww -> b v u (hh ww)', u = u)
63+
k = rearrange(k, 'b (k u) hh ww -> b u k (hh ww)', u = u)
64+
v = rearrange(v, 'b (v u) hh ww -> b u v (hh ww)', u = u)
6565

6666
k = k.softmax(dim=-1)
6767

68-
λc = einsum('b k u m, b v u m -> b k v', k, v)
68+
λc = einsum('b u k m, b u v m -> b k v', k, v)
6969
Yc = einsum('b h k n, b k v -> b n h v', q, λc)
7070

7171
if self.local_contexts:
72-
v = rearrange(v, 'b v u (hh ww) -> b u v hh ww', hh = hh, ww = ww)
72+
v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
7373
λp = F.conv3d(v, self.R, padding = (0, self.padding, self.padding))
7474
Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
7575
else:
76-
λp = einsum('n m k u, b v u m -> b n k v', self.pos_emb, v)
76+
λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
7777
Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
7878

7979
Y = Yc + Yp

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'lambda-networks',
55
packages = find_packages(),
6-
version = '0.1.0',
6+
version = '0.1.1',
77
license='MIT',
88
description = 'Lambda Networks - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)