Commit 3d581cf

fix a bunch of bugs with tf/keras version
1 parent d5abbf3 · commit 3d581cf

File tree

3 files changed: +21 -20 lines


README.md

Lines changed: 5 additions & 1 deletion

````diff
@@ -63,15 +63,19 @@ from lambda_networks import λLayer
 <a href="https://github.com/shinel94">Shinel94</a> has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under `./lambda_networks/tfkeras.py` or make sure to install `tensorflow` and `keras` before running the following.

 ```python
+import tensorflow as tf
 from lambda_networks.tfkeras import LambdaLayer

 layer = LambdaLayer(
     dim_out = 32,
     r = 23,
     dim_k = 16,
     heads = 4,
-    dim_u = 4
+    dim_u = 1
 )
+
+x = tf.random.normal((1, 64, 64, 16))
+layer(x)
 ```

 ## Citations
````
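
With these fixes the README example should run end to end on a channels-last input: `dim_out = 32` plus the `padding='same'` position convolution mean a `(1, 64, 64, 16)` tensor should come back as `(1, 64, 64, 32)`. A minimal sketch of that shape check (the expected output shape is inferred from the diff, not stated in it):

```python
import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

layer = LambdaLayer(dim_out = 32, r = 23, dim_k = 16, heads = 4, dim_u = 1)

x = tf.random.normal((1, 64, 64, 16))  # NHWC, the TF/Keras default layout
y = layer(x)

# spatial dims are preserved; the channel axis becomes dim_out
assert y.shape == (1, 64, 64, 32)
```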

lambda_networks/tfkeras.py

Lines changed: 15 additions & 18 deletions

```diff
@@ -1,6 +1,6 @@
-from einops.layers.keras import Rearrange
-from keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
-from keras import initializers
+from einops.layers.tensorflow import Rearrange
+from tensorflow.keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
+from tensorflow.keras import initializers
 from tensorflow import einsum

 # helpers functions
```
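
The old imports pulled from the standalone `keras` package and einops' Keras shim, which don't line up with a TensorFlow-only install; everything now routes through `tensorflow.keras` and einops' TensorFlow backend. A minimal smoke test of the corrected imports (assumes `tensorflow` and `einops` are installed; not part of the commit):

```python
# verify the fixed import paths resolve without the standalone keras package
from einops.layers.tensorflow import Rearrange
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
)
from tensorflow.keras import initializers
from tensorflow import einsum

print(Rearrange, einsum)  # both resolve, so the module's dependencies line up
```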
```diff
@@ -12,7 +12,6 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d

-
 # lambda layer

 class LambdaLayer(Layer):
```

```diff
@@ -46,8 +45,7 @@ def __init__(
         self.local_contexts = exists(r)
         if exists(r):
             assert (r % 2) == 1, 'Receptive kernel size should be odd'
-            self.pos_padding = ZeroPadding3D(padding=(0, r//2, r//2))
-            self.pos_conv = Conv3D(dim_k, (1, r, r), padding='valid')
+            self.pos_conv = Conv3D(dim_k, (1, r, r), padding='same')
         else:
             assert exists(n), 'You must specify the total sequence length (h x w)'
             self.pos_emb = self.add_weight(name='pos_emb',
```
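
Folding the explicit `ZeroPadding3D` into the convolution works because, for an odd kernel, `padding='same'` applies exactly the symmetric zero padding `(0, r//2, r//2)` that the deleted pair added before its `padding='valid'` convolution. A shape-only sketch of that equivalence, with illustrative sizes not taken from the commit:

```python
import tensorflow as tf
from tensorflow.keras.layers import Conv3D, ZeroPadding3D

r, dim_k = 5, 16                          # illustrative hyperparameters
x = tf.random.normal((1, 4, 32, 32, 8))   # (batch, d1, height, width, channels)

# old formulation: explicit zero padding, then a 'valid' convolution
pad = ZeroPadding3D(padding=(0, r // 2, r // 2))
conv_valid = Conv3D(dim_k, (1, r, r), padding='valid')
y_old = conv_valid(pad(x))

# new formulation: let the convolution pad symmetrically itself
conv_same = Conv3D(dim_k, (1, r, r), padding='same')
y_new = conv_same(x)

# for odd r both keep the spatial dimensions intact
assert y_old.shape == y_new.shape == (1, 4, 32, 32, dim_k)
```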
```diff
@@ -56,7 +54,7 @@ def __init__(
                                            trainable=True)

     def call(self, inputs, **kwargs):
-        b, c, hh, ww = inputs.get_shape().as_list()
+        b, hh, ww, c = inputs.get_shape().as_list()
         u, h = self.u, self.heads
         x = inputs

```
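
TF/Keras convolutions emit channels-last (NHWC) tensors by default, so the static shape unpacks as batch, height, width, channels; the old channels-first unpack bound the height to `c` and shifted the rest. A quick illustration with an assumed input shape:

```python
import tensorflow as tf

x = tf.random.normal((2, 64, 48, 16))    # NHWC: batch, height, width, channels

b, hh, ww, c = x.get_shape().as_list()   # fixed unpack
print(b, hh, ww, c)                      # 2 64 48 16

b, c, hh, ww = x.get_shape().as_list()   # old channels-first unpack
print(c, hh, ww)                         # 64 48 16 -- the height landed in c
```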
```diff
@@ -67,33 +65,32 @@ def call(self, inputs, **kwargs):
         q = self.norm_q(q)
         v = self.norm_v(v)

-        q = Rearrange('b (h k) hh ww -> b h k (hh ww)', h=h)(q)
-        k = Rearrange('b (u k) hh ww -> b u k (hh ww)', u=u)(k)
-        v = Rearrange('b (u v) hh ww -> b u v (hh ww)', u=u)(v)
+        q = Rearrange('b hh ww (h k) -> b h k (hh ww)', h=h)(q)
+        k = Rearrange('b hh ww (u k) -> b u k (hh ww)', u=u)(k)
+        v = Rearrange('b hh ww (u v) -> b u v (hh ww)', u=u)(v)

         k = Softmax()(k)

         Lc = Lambda(lambda x: einsum('b u k m, b u v m -> b k v', x[0], x[1]))([k, v])
         Yc = Lambda(lambda x: einsum('b h k n, b k v -> b n h v', x[0], x[1]))([q, Lc])

         if self.local_contexts:
-            v = Rearrange('b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)(v)
-            Lp = self.pos_padding(v)
-            Lp = self.pos_conv(Lp)
-            Lp = Rearrange('b c k h w -> b c k (h w)')(Lp)
-            Yp = Lambda(lambda x: einsum('b h k n, b k v n -> b n h v', x[0], x[1]))([q, Lp])
+            v = Rearrange('b u v (hh ww) -> b v hh ww u', hh=hh, ww=ww)(v)
+            Lp = self.pos_conv(v)
+            Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
+            Yp = Lambda(lambda x: einsum('b h k n, b v k n -> b n h v', x[0], x[1]))([q, Lp])
         else:
             Lp = Lambda(lambda x: einsum('n m k u, b u v m -> b n k v', x[0], x[1]))([self.pos_emb, v])
             Yp = Lambda(lambda x: einsum('b h k n, b n k v -> b n h v', x[0], x[1]))([q, Lp])

         Y = Add()([Yc, Yp])
-        out = Rearrange('b (hh ww) h v -> b (h v) hh ww', hh = hh, ww = ww)(Y)
+        out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)
         return out

     def compute_output_shape(self, input_shape):
-        return (input_shape[0], self.out_dim, input_shape[2], input_shape[3])
+        return (input_shape[0], input_shape[1], input_shape[2], self.out_dim)

     def get_config(self):
-        config = {'output_dim': (self.input_shape[0], self.out_dim, self.input_shape[2], self.input_shape[3])}
+        config = {'output_dim': (self.input_shape[0], self.input_shape[1], self.input_shape[2], self.out_dim)}
         base_config = super(LambdaLayer, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
```
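
All the `Rearrange` patterns now split heads and intra-depth off the trailing channel axis rather than a leading one, matching the NHWC layout end to end. A shape-only sketch of the corrected q/k/v reshapes, with illustrative sizes (the size variables here are hypothetical, not the module's):

```python
import tensorflow as tf
from einops import rearrange

b, hh, ww = 1, 8, 8
h, dim_k, u, dim_v = 4, 16, 1, 32

# channels-last feature maps, as the 1x1 Conv2D projections produce them
q = tf.random.normal((b, hh, ww, h * dim_k))
k = tf.random.normal((b, hh, ww, u * dim_k))
v = tf.random.normal((b, hh, ww, u * dim_v))

# heads / intra-depth live in the LAST axis, so they are split
# out of the channels dimension on the right-hand side
q = rearrange(q, 'b hh ww (h k) -> b h k (hh ww)', h=h)
k = rearrange(k, 'b hh ww (u k) -> b u k (hh ww)', u=u)
v = rearrange(v, 'b hh ww (u v) -> b u v (hh ww)', u=u)

assert q.shape == (b, h, dim_k, hh * ww)
assert k.shape == (b, u, dim_k, hh * ww)
assert v.shape == (b, u, dim_v, hh * ww)
```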

setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'lambda-networks',
   packages = find_packages(),
-  version = '0.3.0',
+  version = '0.3.1',
   license='MIT',
   description = 'Lambda Networks - Pytorch',
   author = 'Phil Wang',
```
