@@ -60,20 +60,20 @@ def forward(self, x):
6060 v = self .norm_v (v )
6161
6262 q = rearrange (q , 'b (h k) hh ww -> b h k (hh ww)' , h = h )
63- k = rearrange (k , 'b (k u) hh ww -> b k u (hh ww)' , u = u )
64- v = rearrange (v , 'b (v u) hh ww -> b v u (hh ww)' , u = u )
63+ k = rearrange (k , 'b (k u) hh ww -> b u k (hh ww)' , u = u )
64+ v = rearrange (v , 'b (v u) hh ww -> b u v (hh ww)' , u = u )
6565
6666 k = k .softmax (dim = - 1 )
6767
68- λc = einsum ('b k u m, b v u m -> b k v' , k , v )
68+ λc = einsum ('b u k m, b u v m -> b k v' , k , v )
6969 Yc = einsum ('b h k n, b k v -> b n h v' , q , λc )
7070
7171 if self .local_contexts :
72- v = rearrange (v , 'b v u (hh ww) -> b u v hh ww' , hh = hh , ww = ww )
72+ v = rearrange (v , 'b u v (hh ww) -> b u v hh ww' , hh = hh , ww = ww )
7373 λp = F .conv3d (v , self .R , padding = (0 , self .padding , self .padding ))
7474 Yp = einsum ('b h k n, b k v n -> b n h v' , q , λp .flatten (3 ))
7575 else :
76- λp = einsum ('n m k u, b v u m -> b n k v' , self .pos_emb , v )
76+ λp = einsum ('n m k u, b u v m -> b n k v' , self .pos_emb , v )
7777 Yp = einsum ('b h k n, b n k v -> b n h v' , q , λp )
7878
7979 Y = Yc + Yp
0 commit comments