-from einops.layers.keras import Rearrange
-from keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
-from keras import initializers
+from einops.layers.tensorflow import Rearrange
+from tensorflow.keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
+from tensorflow.keras import initializers
 from tensorflow import einsum
 
 # helpers functions
@@ -12,7 +12,6 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
-
 # lambda layer
 
 class LambdaLayer(Layer):
@@ -46,8 +45,7 @@ def __init__(
         self.local_contexts = exists(r)
         if exists(r):
             assert (r % 2) == 1, 'Receptive kernel size should be odd'
-            self.pos_padding = ZeroPadding3D(padding = (0, r // 2, r // 2))
-            self.pos_conv = Conv3D(dim_k, (1, r, r), padding = 'valid')
+            self.pos_conv = Conv3D(dim_k, (1, r, r), padding = 'same')
         else:
             assert exists(n), 'You must specify the total sequence length (h x w)'
             self.pos_emb = self.add_weight(name = 'pos_emb',
@@ -56,7 +54,7 @@ def __init__(
                                            trainable = True)
 
     def call(self, inputs, **kwargs):
-        b, c, hh, ww = inputs.get_shape().as_list()
+        b, hh, ww, c = inputs.get_shape().as_list()
         u, h = self.u, self.heads
         x = inputs
 
@@ -67,33 +65,32 @@ def call(self, inputs, **kwargs):
         q = self.norm_q(q)
         v = self.norm_v(v)
 
-        q = Rearrange('b (h k) hh ww -> b h k (hh ww)', h = h)(q)
-        k = Rearrange('b (u k) hh ww -> b u k (hh ww)', u = u)(k)
-        v = Rearrange('b (u v) hh ww -> b u v (hh ww)', u = u)(v)
+        q = Rearrange('b hh ww (h k) -> b h k (hh ww)', h = h)(q)
+        k = Rearrange('b hh ww (u k) -> b u k (hh ww)', u = u)(k)
+        v = Rearrange('b hh ww (u v) -> b u v (hh ww)', u = u)(v)
 
         k = Softmax()(k)
 
         Lc = Lambda(lambda x: einsum('b u k m, b u v m -> b k v', x[0], x[1]))([k, v])
         Yc = Lambda(lambda x: einsum('b h k n, b k v -> b n h v', x[0], x[1]))([q, Lc])
 
         if self.local_contexts:
-            v = Rearrange('b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)(v)
-            Lp = self.pos_padding(v)
-            Lp = self.pos_conv(Lp)
-            Lp = Rearrange('b c k h w -> b c k (h w)')(Lp)
-            Yp = Lambda(lambda x: einsum('b h k n, b k v n -> b n h v', x[0], x[1]))([q, Lp])
+            v = Rearrange('b u v (hh ww) -> b v hh ww u', hh = hh, ww = ww)(v)
+            Lp = self.pos_conv(v)
+            Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
+            Yp = Lambda(lambda x: einsum('b h k n, b v k n -> b n h v', x[0], x[1]))([q, Lp])
         else:
             Lp = Lambda(lambda x: einsum('n m k u, b u v m -> b n k v', x[0], x[1]))([self.pos_emb, v])
             Yp = Lambda(lambda x: einsum('b h k n, b n k v -> b n h v', x[0], x[1]))([q, Lp])
 
         Y = Add()([Yc, Yp])
-        out = Rearrange('b (hh ww) h v -> b (h v) hh ww', hh = hh, ww = ww)(Y)
+        out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)
         return out
 
     def compute_output_shape(self, input_shape):
-        return (input_shape[0], self.out_dim, input_shape[2], input_shape[3])
+        return (input_shape[0], input_shape[1], input_shape[2], self.out_dim)
 
     def get_config(self):
-        config = {'output_dim': (self.input_shape[0], self.out_dim, self.input_shape[2], self.input_shape[3])}
+        config = {'output_dim': (self.input_shape[0], self.input_shape[1], self.input_shape[2], self.out_dim)}
         base_config = super(LambdaLayer, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
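
For reference, a minimal usage sketch of the layer after this change, feeding it a channels-last tensor. The constructor arguments shown (dim_out, dim_k, r, heads, dim_u) are assumed from the surrounding code and the upstream lambda-networks Keras port; they are not part of this diff.

import tensorflow as tf

# Assumed constructor signature; r must be odd per the assertion above.
layer = LambdaLayer(dim_out = 32, dim_k = 16, r = 15, heads = 4, dim_u = 1)

x = tf.random.normal((1, 64, 64, 32))  # (batch, height, width, channels), i.e. channels-last
y = layer(x)                           # expected output shape: (1, 64, 64, 32)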