Skip to content

Commit d5abbf3

Browse files
committed
add tf/keras version to pypi and update readme
1 parent 865b7ab commit d5abbf3

File tree

3 files changed

+36
-30
lines changed

3 files changed

+36
-30
lines changed

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,22 @@ For fun, you can also import this as follows
5858
from lambda_networks import λLayer
5959
```
6060

61+
## Tensorflow / Keras version
62+
63+
<a href="https://github.com/shinel94">Shinel94</a> has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under `./lambda_networks/tfkeras.py` or make sure to install `tensorflow` and `keras` before running the following.
64+
65+
```python
66+
from lambda_networks.tfkeras import LambdaLayer
67+
68+
layer = LambdaLayer(
69+
dim_out = 32,
70+
r = 23,
71+
dim_k = 16,
72+
heads = 4,
73+
dim_u = 4
74+
)
75+
```
76+
6177
## Citations
6278

6379
```bibtex

lambda_networks/tfkeras_lambda_networks.py renamed to lambda_networks/tfkeras.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from einops.layers.keras import Rearrange
2-
from keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add
2+
from keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
3+
from keras import initializers
34
from tensorflow import einsum
45

56
# helpers functions
@@ -16,15 +17,16 @@ def default(val, d):
1617

1718
class LambdaLayer(Layer):
1819
def __init__(
19-
self,
20-
*,
21-
dim_k,
22-
n=None,
23-
r=None,
24-
heads=4,
25-
dim_out=None,
26-
dim_u=1):
20+
self,
21+
*,
22+
dim_k,
23+
n = None,
24+
r = None,
25+
heads = 4,
26+
dim_out = None,
27+
dim_u = 1):
2728
super(LambdaLayer, self).__init__()
29+
2830
self.out_dim = dim_out
2931
self.u = dim_u # intra-depth dimension
3032
self.heads = heads
@@ -33,9 +35,13 @@ def __init__(
3335
self.dim_v = dim_out // heads
3436
self.dim_k = dim_k
3537
self.heads = heads
36-
self.dim_u = dim_u
37-
self.r = r
38-
self.n = n
38+
39+
self.to_q = Conv2D(self.dim_k * heads, 1, use_bias=False)
40+
self.to_k = Conv2D(self.dim_k * dim_u, 1, use_bias=False)
41+
self.to_v = Conv2D(self.dim_v * dim_u, 1, use_bias=False)
42+
43+
self.norm_q = BatchNormalization()
44+
self.norm_v = BatchNormalization()
3945

4046
self.local_contexts = exists(r)
4147
if exists(r):
@@ -49,27 +55,11 @@ def __init__(
4955
initializer=initializers.random_normal,
5056
trainable=True)
5157

52-
self.to_q = Conv2D(self.dim_k * self.heads, 1, bias=False)
53-
self.to_k = Conv2D(self.dim_k * self.dim_u, 1, bias=False)
54-
self.to_v = Conv2D(self.dim_v * self.dim_u, 1, bias=False)
55-
self.norm_q = BatchNormalization()
56-
self.norm_v = BatchNormalization()
57-
self.local_contexts = exists(self.r)
58-
if exists(self.r):
59-
assert (self.r % 2) == 1, 'Receptive kernel size should be odd'
60-
self.pos_padding = ZeroPadding3D(padding=(0, self.r // 2, self.r // 2))
61-
self.pos_conv = Conv3D(self.dim_k, (1, self.r, self.r), padding='valid')
62-
else:
63-
assert exists(self.n), 'You must specify the total sequence length (h x w)'
64-
self.pos_emb = self.add_weight(name='pos_emb',
65-
shape=(self.n, self.n, self.dim_k, self.dim_u),
66-
initializer=initializers.random_normal,
67-
trainable=True)
68-
6958
def call(self, inputs, **kwargs):
7059
b, c, hh, ww = inputs.get_shape().as_list()
7160
u, h = self.u, self.heads
7261
x = inputs
62+
7363
q = self.to_q(x)
7464
k = self.to_k(x)
7565
v = self.to_v(x)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'lambda-networks',
55
packages = find_packages(),
6-
version = '0.2.2',
6+
version = '0.3.0',
77
license='MIT',
88
description = 'Lambda Networks - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments (0)