@@ -264,7 +264,10 @@ def window_partition(x: ms.Tensor, window_size: int) -> Tuple[ms.Tensor, Tuple[i
264
264
pad_h = (window_size - H % window_size ) % window_size
265
265
pad_w = (window_size - W % window_size ) % window_size
266
266
if pad_h > 0 or pad_w > 0 :
267
- x = ops .pad (x , (0 , 0 , 0 , pad_w , 0 , pad_h ))
267
+ # replace ops.pad with ops.concat for better performance
268
+ pad_mat1 = ops .zeros ((B , H , pad_w , C ), x .dtype )
269
+ pad_mat2 = ops .zeros ((B , pad_h , W + pad_w , C ), x .dtype )
270
+ x = ops .concat ([ops .concat ([x , pad_mat1 ], axis = 2 ), pad_mat2 ], axis = 1 )
268
271
Hp , Wp = H + pad_h , W + pad_w
269
272
270
273
x = x .view (B , Hp // window_size , window_size , Wp // window_size , window_size , C )
@@ -401,7 +404,8 @@ def __init__(
401
404
)
402
405
403
406
def construct (self , x : ms .Tensor ) -> ms .Tensor :
404
- x = ops .pad (x , (self .padding [0 ], self .padding [0 ], self .padding [1 ], self .padding [1 ])) # to align with torch
407
+ if sum (self .padding ) > 0 :
408
+ x = ops .pad (x , (self .padding [0 ], self .padding [0 ], self .padding [1 ], self .padding [1 ])) # to align with torch
405
409
x = self .proj (x )
406
410
# B C H W -> B H W C
407
411
x = x .permute (0 , 2 , 3 , 1 )
0 commit comments