diff --git a/main_test_swinir.py b/main_test_swinir.py index f06b5f39d..21eb9b050 100644 --- a/main_test_swinir.py +++ b/main_test_swinir.py @@ -67,8 +67,8 @@ def main(): with torch.no_grad(): # pad input image to be a multiple of window_size _, _, h_old, w_old = img_lq.size() - h_pad = (h_old // window_size + 1) * window_size - h_old - w_pad = (w_old // window_size + 1) * window_size - w_old + h_pad = (window_size - h_old % window_size) % window_size + w_pad = (window_size - w_old % window_size) % window_size img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :] img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad] output = test(img_lq, model, args, window_size) diff --git a/predict.py b/predict.py index c0f6b715d..384d8b623 100644 --- a/predict.py +++ b/predict.py @@ -129,8 +129,8 @@ def predict(self, image, task_type='Real-World Image Super-Resolution', jpeg=40, with torch.no_grad(): # pad input image to be a multiple of window_size _, _, h_old, w_old = img_lq.size() - h_pad = (h_old // window_size + 1) * window_size - h_old - w_pad = (w_old // window_size + 1) * window_size - w_old + h_pad = (window_size - h_old % window_size) % window_size + w_pad = (window_size - w_old % window_size) % window_size img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :] img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad] output = model(img_lq)