-
Notifications
You must be signed in to change notification settings - Fork 125
/
Copy pathinfer_hf.py
70 lines (54 loc) · 1.87 KB
/
infer_hf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import cv2
import torch
import argparse
from tqdm import tqdm
from huggingface_hub import PyTorchModelHubMixin
from ddcolor_model import DDColor
from infer import ImageColorizationPipeline
class DDColorHF(DDColor, PyTorchModelHubMixin):
    """DDColor network combined with ``PyTorchModelHubMixin`` (used via ``from_pretrained``).

    ``config`` is expanded into keyword arguments for the next initializer in
    the MRO (``DDColor``). It is presumably the stored hub config dict supplied
    by ``PyTorchModelHubMixin.from_pretrained``. NOTE(review): confirm against
    the installed huggingface_hub version.
    """

    def __init__(self, config):
        # Explicit two-argument form; resolves identically to bare super().
        super(DDColorHF, self).__init__(**config)
class ImageColorizationPipelineHF(ImageColorizationPipeline):
    """Colorization pipeline that wraps an already-constructed model object.

    ``ImageColorizationPipeline.__init__`` is deliberately not called. This
    initializer sets ``input_size``, ``device`` and ``model`` directly.
    Presumably the parent would otherwise build/load its own model.
    NOTE(review): confirm the parent's ``process`` needs no other attributes.
    """

    def __init__(self, model, input_size):
        self.input_size = input_size
        # Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Move the model onto the chosen device and switch it to inference mode.
        self.model = model.to(self.device)
        self.model.eval()
def main():
    """Command-line entry point: colorize every image file in ``--input``.

    Each readable image in the ``--input`` folder is passed through the
    DDColor pipeline. The result is written to ``--output`` under the same
    file name.

    ``--model_name`` is used as-is when it names an existing local path.
    Otherwise it is prefixed with ``piddnad/`` (presumably a Hugging Face Hub
    repo id) before calling ``DDColorHF.from_pretrained``.

    Exits with a usage error (status 2) when ``--input`` contains no files.
    Files that OpenCV cannot decode are skipped with a notice instead of
    crashing the run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="ddcolor_modelscope")
    parser.add_argument(
        "--input",
        type=str,
        default="figure/",
        help="input test image folder or video path",
    )
    parser.add_argument(
        "--output", type=str, default="results", help="output folder or video path"
    )
    parser.add_argument(
        "--input_size", type=int, default=512, help="input size for model"
    )
    args = parser.parse_args()

    # A local directory wins; otherwise resolve the name under the piddnad namespace.
    if os.path.exists(args.model_name):
        model_name = args.model_name
    else:
        model_name = f"piddnad/{args.model_name}"
    ddcolor_model = DDColorHF.from_pretrained(model_name)

    print(f"Output path: {args.output}")
    os.makedirs(args.output, exist_ok=True)

    # Only regular files are candidates (subdirectories would make cv2.imread
    # return None). Sort the names so processing order is deterministic.
    img_list = sorted(
        name
        for name in os.listdir(args.input)
        if os.path.isfile(os.path.join(args.input, name))
    )
    if not img_list:
        # Real validation instead of `assert`, which is stripped under `python -O`.
        parser.error(f"no input files found in {args.input!r}")

    colorizer = ImageColorizationPipelineHF(
        model=ddcolor_model, input_size=args.input_size
    )

    for name in tqdm(img_list):
        img = cv2.imread(os.path.join(args.input, name))
        if img is None:
            # cv2.imread signals unreadable/non-image files by returning None
            # rather than raising; skip them instead of crashing in process().
            tqdm.write(f"Skipping unreadable image: {name}")
            continue
        image_out = colorizer.process(img)
        cv2.imwrite(os.path.join(args.output, name), image_out)


if __name__ == "__main__":
    main()