diff --git a/demo.py b/demo.py
index c1bb3bf..77c4409 100644
--- a/demo.py
+++ b/demo.py
@@ -20,6 +20,13 @@
 parser.add_argument('--style_seg_path', default=[])
 parser.add_argument('--output_image_path', default='./results/example1.png')
 parser.add_argument('--cuda', type=int, default=1, help='Enable CUDA.')
+parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
+parser.add_argument("--engine", type=str, help="run serialized TRT engine")
+parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
+parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose output')
+parser.add_argument("-d", "--data_type", default=32, type=int, choices=[8, 16, 32], help="supported data type, i.e. 8, 16, or 32 bit")
+
+
 args = parser.parse_args()
 
 # Load model
@@ -32,7 +39,7 @@
 if args.cuda:
     p_wct.cuda(0)
 
-    
+
 process_stylization.stylization(
     p_wct=p_wct,
     content_image_path=args.content_image_path,
@@ -41,4 +48,5 @@
     style_seg_path=args.style_seg_path,
     output_image_path=args.output_image_path,
     cuda=args.cuda,
+    args=args
 )
diff --git a/demo.sh b/demo.sh
index 929e0c3..acb700d 100755
--- a/demo.sh
+++ b/demo.sh
@@ -7,4 +7,4 @@ axel -n 1 https://vignette.wikia.nocookie.net/strangerthings8338/images/e/e0/Wik
 convert -resize 25% content1.png content1.png;
 convert -resize 50% style1.png style1.png;
 cd ..;
-python demo.py;
+python demo.py "$@";
diff --git a/photo_wct.py b/photo_wct.py
index 2fb6fda..7fffced 100644
--- a/photo_wct.py
+++ b/photo_wct.py
@@ -23,7 +23,7 @@ def __init__(self):
         self.d3 = VGGDecoder(3)
         self.e4 = VGGEncoder(4)
         self.d4 = VGGDecoder(4)
-    
+
     def transform(self, cont_img, styl_img, cont_seg, styl_seg):
         self.__compute_label_info(cont_seg, styl_seg)
@@ -53,8 +53,15 @@
         csF1 = self.__feature_wct(cF1, sF1, cont_seg, styl_seg)
         Im1 = self.d1(csF1)
         return Im1
+
+    def forward(self, args):
+        [cont_img, styl_img, cont_seg, styl_seg] = args
+        print(cont_img, styl_img, cont_seg, styl_seg)
+        return self.transform(cont_img, styl_img, cont_seg, styl_seg)
 
     def __compute_label_info(self, cont_seg, styl_seg):
+        cont_seg = cont_seg.numpy()
+        styl_seg = styl_seg.numpy()
         if cont_seg.size == False or styl_seg.size == False:
             return
         max_label = np.max(cont_seg) + 1
@@ -69,6 +76,8 @@
         self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size)
 
     def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
+        cont_seg = cont_seg.numpy()
+        styl_seg = styl_seg.numpy()
         cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
         styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
         cont_feat_view = cont_feat.view(cont_c, -1).clone()
diff --git a/process_stylization.py b/process_stylization.py
index c4ecd09..b900e4d 100644
--- a/process_stylization.py
+++ b/process_stylization.py
@@ -8,8 +8,10 @@
 import time
 
 import numpy as np
+import torch
 from PIL import Image
 from torch.autograd import Variable
+from torch.onnx import export
 import torchvision.transforms as transforms
 import torchvision.utils as utils
 
@@ -32,8 +34,7 @@ def __exit__(self, exc_type, exc_value, exc_tb):
         print(self.msg % (time.time() - self.start_time))
 
 
-def stylization(p_wct, content_image_path, style_image_path, content_seg_path, style_seg_path, output_image_path,
-                cuda):
+def stylization(p_wct, content_image_path, style_image_path, content_seg_path, style_seg_path, output_image_path, cuda, args):
     # Load image
    cont_img = Image.open(content_image_path).convert('RGB')
    styl_img = Image.open(style_image_path).convert('RGB')
@@ -52,12 +53,16 @@ def stylization(p_wct, content_image_path, style_image_path, content_seg_path, s
         styl_img = styl_img.cuda(0)
         p_wct.cuda(0)
 
-    cont_img = Variable(cont_img, volatile=True)
-    styl_img = Variable(styl_img, volatile=True)
-
-    cont_seg = np.asarray(cont_seg)
-    styl_seg = np.asarray(styl_seg)
-
+    cont_img = Variable(cont_img, requires_grad=False)
+    styl_img = Variable(styl_img, requires_grad=False)
+    cont_seg = torch.FloatTensor(np.asarray(cont_seg))
+    styl_seg = torch.FloatTensor(np.asarray(styl_seg))
+
+    if args.export_onnx:
+        assert args.export_onnx.endswith(".onnx"), "Export model file should end with .onnx"
+        export(p_wct, [cont_img, styl_img, cont_seg, styl_seg],
+               f=args.export_onnx, verbose=args.verbose)
+
     with Timer("Elapsed time in stylization: %f"):
         stylized_img = p_wct.transform(cont_img, styl_img, cont_seg, styl_seg)
     utils.save_image(stylized_img.data.cpu().float(), output_image_path, nrow=1)
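
Since demo.sh now forwards its command-line arguments to demo.py, the export path can be exercised end to end. A minimal usage sketch; the output file name photo_wct.onnx is an arbitrary example, not something mandated by the patch:

    # Download the demo images, run stylization, and trace the model out to ONNX.
    # PhotoWCT.forward takes the [cont_img, styl_img, cont_seg, styl_seg] list as
    # its single argument, which is why export() passes the inputs as one list.
    ./demo.sh --export_onnx photo_wct.onnx --verbose

Note that only --export_onnx and --verbose are consumed in process_stylization.py; --engine, --onnx, and --data_type are merely registered here for the TRT runtime path.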