
Commit 07a7fa0

add lambda convolutions, as described in the paper
Parent: 3ecddf5

3 files changed: +43, -9 lines

README.md

Lines changed: 22 additions & 2 deletions
@@ -10,6 +10,8 @@ $ pip install lambda-networks
 
 ## Usage
 
+Global context
+
 ```python
 import torch
 from lambda_networks import LambdaLayer
@@ -18,7 +20,6 @@ layer = LambdaLayer(
     dim = 32,       # channels going in
     dim_out = 32,   # channels out
     n = 64 * 64,    # number of input pixels (64 x 64 image)
-    m = 64 * 64,    # number of context (64 x 64 global)
     dim_k = 16,     # key dimension
     heads = 4,      # number of heads, for multi-query
     dim_u = 1       # 'intra-depth' dimension
@@ -28,9 +29,28 @@ x = torch.randn(1, 32, 64, 64)
 layer(x) # (1, 32, 64, 64)
 ```
 
+Localized context
+
+```python
+import torch
+from lambda_networks import LambdaLayer
+
+layer = LambdaLayer(
+    dim = 32,
+    dim_out = 32,
+    r = 23,         # the receptive field for relative positional encoding (23 x 23)
+    dim_k = 16,
+    heads = 4,
+    dim_u = 4
+)
+
+x = torch.randn(1, 32, 64, 64)
+layer(x) # (1, 32, 64, 64)
+```
+
 ## Todo
 
-- [ ] Lambda layers with structured context
+- [x] Lambda layers with structured context
 - [ ] Document hyperparameters and put some sensible defaults
 - [ ] Test it out
 
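A note on the two modes shown above: the global-context version learns a fixed `n × n` positional embedding, so `n` must match the input, while the localized version computes positional lambdas with an `r × r` convolution and is not tied to any particular input size. A minimal sketch (not part of this commit) illustrating that property:

```python
import torch
from lambda_networks import LambdaLayer

# local context: positional lambdas come from a convolution,
# so no fixed sequence length n is baked into the parameters
layer = LambdaLayer(dim = 32, dim_out = 32, r = 23, dim_k = 16, heads = 4, dim_u = 4)

for size in (32, 64):                      # same layer, two different resolutions
    x = torch.randn(1, 32, size, size)
    assert layer(x).shape == (1, 32, size, size)
```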

lambda_networks/lambda_networks.py

Lines changed: 20 additions & 6 deletions
@@ -18,9 +18,9 @@ def __init__(
         self,
         dim,
         *,
-        n,
-        m,
         dim_k,
+        n = None,
+        r = None,
         heads = 4,
         dim_out = None,
         dim_u = 1):
@@ -35,11 +35,20 @@ def __init__(
         self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
         self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
         self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)
-        self.pos_emb = nn.Parameter(torch.randn(n, m, dim_k, dim_u))
 
         self.norm_q = nn.BatchNorm2d(dim_k * heads)
         self.norm_v = nn.BatchNorm2d(dim_v * dim_u)
 
+        self.local_contexts = exists(r)
+        if exists(r):
+            assert (r % 2) == 1, 'Receptive kernel size should be odd'
+            self.padding = r // 2
+            self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
+        else:
+            assert exists(n), 'You must specify the total sequence length (h x w)'
+            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))
+
+
     def forward(self, x):
         b, c, hh, ww, u, h = *x.shape, self.u, self.heads
 
@@ -57,10 +66,15 @@ def forward(self, x):
         k = k.softmax(dim=-1)
 
         λc = einsum('b k u m, b v u m -> b k v', k, v)
-        λp = einsum('n m k u, b v u m -> b n k v', self.pos_emb, v)
-
         Yc = einsum('b h k n, b k v -> b n h v', q, λc)
-        Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
+
+        if self.local_contexts:
+            v = rearrange(v, 'b v u (hh ww) -> b u v hh ww', hh = hh, ww = ww)
+            λp = F.conv3d(v, self.R, padding = (0, self.padding, self.padding))
+            Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
+        else:
+            λp = einsum('n m k u, b v u m -> b n k v', self.pos_emb, v)
+            Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
 
         Y = Yc + Yp
         out = rearrange(Y, 'b (hh ww) h v -> b (h v) hh ww', hh = hh, ww = ww)
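In the new localized branch, `v` is rearranged so that the intra-depth dimension `u` becomes the convolution's input channels and `dim_v` its depth axis, while `R` has shape `(dim_k, dim_u, 1, r, r)`; the 3D convolution therefore slides only over the two spatial axes and produces one `k × v` positional lambda per pixel. A standalone shape walk-through, using illustrative dimensions that are assumptions for clarity and not part of the commit:

```python
import torch
import torch.nn.functional as F

# illustrative sizes: batch, intra-depth, key dim, value dim, height, width, receptive field
b, u, dim_k, dim_v, hh, ww, r = 1, 4, 16, 8, 64, 64, 23

v = torch.randn(b, u, dim_v, hh, ww)  # values after rearrange: u as channels, dim_v as depth
R = torch.randn(dim_k, u, 1, r, r)    # kernel: size 1 along dim_v, r x r over space

lam_p = F.conv3d(v, R, padding = (0, r // 2, r // 2))

print(lam_p.shape)             # torch.Size([1, 16, 8, 64, 64]) -- one lambda per position
print(lam_p.flatten(3).shape)  # torch.Size([1, 16, 8, 4096])   == (b, k, v, n)
```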

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'lambda-networks',
   packages = find_packages(),
-  version = '0.0.2',
+  version = '0.1.0',
   license='MIT',
   description = 'Lambda Networks - Pytorch',
   author = 'Phil Wang',
