diff --git a/tools/demo.py b/tools/demo.py index bf171fe..e437c1a 100644 --- a/tools/demo.py +++ b/tools/demo.py @@ -10,7 +10,7 @@ import numpy as np import cv2 -import lib.transform_cv2 as T +import lib.data.transform_cv2 as T from lib.models import model_factory from configs import set_cfg_from_file @@ -34,7 +34,7 @@ palette = np.random.randint(0, 256, (256, 3), dtype=np.uint8) # define model -net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='pred') +net = model_factory[cfg.model_type](cfg.n_cats, aux_mode='eval') net.load_state_dict(torch.load(args.weight_path, map_location='cpu'), strict=False) net.eval() net.cuda() @@ -53,8 +53,9 @@ # inference im = F.interpolate(im, size=new_size, align_corners=False, mode='bilinear') -out = net(im) +out = net(im)[0] out = F.interpolate(out, size=org_size, align_corners=False, mode='bilinear') +out = out.argmax(dim=1) # visualize out = out.squeeze().detach().cpu().numpy() diff --git a/tools/demo_video.py b/tools/demo_video.py index 59cc12d..d3df05f 100644 --- a/tools/demo_video.py +++ b/tools/demo_video.py @@ -5,13 +5,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.multiprocessing import Process, Queue +import torch.multiprocessing as mp import time from PIL import Image import numpy as np import cv2 -import lib.transform_cv2 as T +import lib.data.transform_cv2 as T from lib.models import model_factory from configs import set_cfg_from_file @@ -40,7 +40,7 @@ def get_model(): # fetch frames -def get_func(inpth, in_q): +def get_func(inpth, in_q, done): cap = cv2.VideoCapture(args.input) width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # type is float height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # type is float @@ -59,7 +59,8 @@ def get_func(inpth, in_q): in_q.put(frame) in_q.put('quit') - while not in_q.empty(): continue + done.wait() + cap.release() time.sleep(1) print('input queue done') @@ -105,14 +106,15 @@ def infer_batch(frames): if __name__ == '__main__': - torch.multiprocessing.set_start_method('spawn') + mp.set_start_method('spawn') - in_q = Queue(1024) - out_q = Queue(1024) + in_q = mp.Queue(1024) + out_q = mp.Queue(1024) + done = mp.Event() - in_worker = Process(target=get_func, - args=(args.input, in_q)) - out_worker = Process(target=save_func, + in_worker = mp.Process(target=get_func, + args=(args.input, in_q, done)) + out_worker = mp.Process(target=save_func, args=(args.input, args.output, out_q)) in_worker.start() @@ -133,6 +135,7 @@ def infer_batch(frames): infer_batch(frames) out_q.put('quit') + done.set() out_worker.join() in_worker.join()