-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathinference_image.py
81 lines (59 loc) · 2.47 KB
/
inference_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
""" ************************************************
* fileName: inference_image.py
* desc: run SimDeblur inference on a single blurry image
* author: mingdeng_cao
* date: 2022/03/08 19:31
* last revised: None
************************************************ """
import os
import time
import torch
import argparse
import cv2
import numpy as np
from easydict import EasyDict as edict
from torchvision.utils import save_image
from simdeblur.config.build import build_config
from simdeblur.model.build import build_backbone, build_meta_arch
def parse_arguments():
    """Parse command-line arguments for single-image SimDeblur inference.

    Returns:
        argparse.Namespace with ``config_file``, ``ckpt_file``, ``img`` and
        ``save_path`` (``None`` when ``--save_path`` is not given).
    """
    parser = argparse.ArgumentParser(description="Parameters during inference of SimDeblur")
    # Positional arguments are always required by argparse, so a ``default``
    # on them is never used; it is omitted to avoid implying they are optional.
    parser.add_argument("config_file", help="the path of config file")
    parser.add_argument("ckpt_file", help="the path of checkpoint file")
    # Inference cannot proceed without an input image, so reject a missing
    # --img at parse time with a usage message instead of failing later when
    # the image is read.
    parser.add_argument("--img", required=True, help="the path of input blurry image")
    parser.add_argument("--save_path", default=None, help="the dir to save inference results")
    return parser.parse_args()
def _load_checkpoint(arch, ckpt_file, device):
    """Best-effort load of trained weights from *ckpt_file* into *arch*.

    Returns True on success. On any failure the error is printed and False is
    returned, so inference continues with whatever weights *arch* already has.
    """
    try:
        # Map saved tensors onto the selected device. Hard-coding
        # ``storage.cuda(0)`` breaks loading on CPU-only machines.
        ckpt = torch.load(os.path.abspath(ckpt_file), map_location=device)
        arch.load_ckpt(ckpt, strict=True)
    except Exception as e:  # deliberate best-effort: report and continue
        print(e)
        print(f"Checkpoint loading failed, cannot load ckpt file from {ckpt_file}.")
        return False
    print(f"Using checkpoint loaded from {ckpt_file} for testing.")
    return True


def _read_image(img_path, device):
    """Read *img_path* as an RGB float tensor scaled to [0, 1] on *device*.

    The result has shape (1, 1, 3, H, W); presumably (batch, num_frames, C, H, W)
    as expected by the meta-arch's "input_frames" -- TODO confirm.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Cannot read input image from {img_path}.")
    # OpenCV loads BGR; reverse the channel axis to get RGB.
    rgb = np.ascontiguousarray(bgr[..., ::-1] / 255.)
    return torch.Tensor(rgb).permute(2, 0, 1).unsqueeze(0).unsqueeze(0).to(device)


def _resolve_save_dir(save_path):
    """Return the output directory (a default when *save_path* is None), creating it."""
    if save_path is None:
        # NOTE: the misspelled default directory name is kept so existing
        # output locations do not change.
        save_path = "./inference_resutls"
    os.makedirs(save_path, exist_ok=True)
    return save_path


def inference():
    """Deblur a single image using the config/checkpoint given on the command line.

    Builds the meta-architecture from the config and loads the checkpoint
    (best-effort). It then runs the image given by ``--img`` through
    preprocess -> inference/model -> postprocess and writes
    ``infer_output.png`` into ``--save_path`` (or the default directory).
    """
    # read arguments
    args = parse_arguments()
    cfg = build_config(args.config_file)
    cfg.args = edict(vars(args))

    # construct model; NOTE(review): assumes build_meta_arch places the model
    # on the same device as ``device`` -- confirm.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    arch = build_meta_arch(cfg)
    _load_checkpoint(arch, cfg.args.ckpt_file, device)
    arch.model.eval()

    input_image = {"input_frames": _read_image(args.img, device)}
    with torch.no_grad():
        # Prefer the meta-arch's dedicated inference path when it provides one.
        forward = arch.inference if hasattr(arch, "inference") else arch.model
        outputs = arch.postprocess(forward(arch.preprocess(input_image)))

    save_dir = _resolve_save_dir(args.save_path)
    save_image(outputs.clamp(0, 1), os.path.join(save_dir, "infer_output.png"))
if __name__ == "__main__":
    # Run inference only when executed as a script, never on import.
    inference()