diff --git a/main.py b/main.py index f2ee4729..4fb91f34 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,6 @@ import tensorflow as tf config = tf.ConfigProto() -config.gpu_options.allow_growth = False +config.gpu_options.allow_growth = True session = tf.Session(config = config) import os, logging, shutil, datetime @@ -430,6 +430,31 @@ def transfer(self, x_encode, c_encode, c_decode): self.cn: c_encode, self.c_generator: c_decode}) + def prepare_tranfer(self, src_img_path, src_joints_path, tar_img_path, tar_joints_path): + from batches import load_img, make_joint_img, preprocess, normalize + joint_order = ['cnose', 'cneck', 'rshoulder', 'relbow', 'rwrist', + 'lshoulder', 'lelbow', 'lwrist', 'rhip', + 'rknee', 'rankle', 'lhip', 'lknee', 'lankle', + 'reye', 'leye', 'rear', 'lear'] + imgs = [load_img(src_img_path, target_size = self.img_shape), load_img(tar_img_path, target_size = self.img_shape)] + imgs = np.stack(imgs) + imgs = preprocess(imgs) + + h,w = self.img_shape[:2] + wh = np.array([w,h]) + joints_coordinates = [np.load(src_joints_path)*wh, np.load(tar_joints_path)*wh] + joints = [] + for joints_c in joints_coordinates: + joints.append(make_joint_img(self.img_shape, joint_order, joints_c)) + joints = np.stack(joints) + joints = preprocess(joints) + + nimgs, njoints = normalize(imgs, joints_coordinates, joints, joint_order, 2) + x_encode = np.stack([nimgs[0]]) + c_encode = np.stack([njoints[0]]) + c_decode = np.stack([joints[1]]) + return x_encode, c_encode, c_decode + if __name__ == "__main__": default_log_dir = os.path.join(os.getcwd(), "log") @@ -441,6 +466,10 @@ def transfer(self, x_encode, c_encode, c_decode): parser.add_argument("--log_dir", default = default_log_dir, help = "path to log into") parser.add_argument("--checkpoint", help = "path to checkpoint to restore") parser.add_argument("--retrain", dest = "retrain", action = "store_true", help = "reset global_step to zero") + parser.add_argument("--src_img", help = "path to src_img") + parser.add_argument("--tar_img", help = "path to tar_img") + parser.add_argument("--src_jo", help = "path to src_jo") + parser.add_argument("--tar_jo", help = "path to tar_jo") parser.set_defaults(retrain = False) opt = parser.parse_args() @@ -474,5 +503,16 @@ def transfer(self, x_encode, c_encode, c_decode): if opt.retrain: model.reset_global_step() model.fit(batches, valid_batches) + elif opt.mode == "transfer": + if not opt.checkpoint: + raise Exception("transfer requires --checkpoint") + config['batch_size'] = 1 + config['box_factor'] = 2 + model = Model(config, out_dir, logger) + model.restore_graph(opt.checkpoint) + x_encode, c_encode, c_decode = model.prepare_tranfer(opt.src_img, opt.src_jo, opt.tar_img, opt.tar_jo) + x_gen = model.transfer(x_encode, c_encode, c_decode) + plot_batch(x_gen, os.path.join(out_dir, "testing.png")) + else: raise NotImplemented() diff --git a/sample/img_a.jpg b/sample/img_a.jpg new file mode 100755 index 00000000..be44378d Binary files /dev/null and b/sample/img_a.jpg differ diff --git a/sample/img_b.jpg b/sample/img_b.jpg new file mode 100755 index 00000000..0a8abe4f Binary files /dev/null and b/sample/img_b.jpg differ diff --git a/sample/jo_a.npy b/sample/jo_a.npy new file mode 100644 index 00000000..2ede6961 Binary files /dev/null and b/sample/jo_a.npy differ diff --git a/sample/jo_b.npy b/sample/jo_b.npy new file mode 100644 index 00000000..cd4e668c Binary files /dev/null and b/sample/jo_b.npy differ diff --git a/sample/testing.png b/sample/testing.png new file mode 100644 index 00000000..105e3d44 Binary files /dev/null and b/sample/testing.png differ