-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGenGAN.py
111 lines (83 loc) · 3.26 KB
/
GenGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import numpy as np
import cv2
import os
import pickle
import sys
import math
import matplotlib.pyplot as plt
from torchvision.io import read_image
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from VideoSkeleton import VideoSkeleton
from VideoReader import VideoReader
from Skeleton import Skeleton
from GenVanillaNN import *
class Discriminator(nn.Module):
def __init__(self, ngpu=0):
super(Discriminator, self).__init__()
self.ngpu = ngpu
def forward(self, input):
pass
#return self.model(input)
class GenGAN():
""" class that Generate a new image from videoSke from a new skeleton posture
Fonc generator(Skeleton)->Image
"""
def __init__(self, videoSke, loadFromFile=False):
self.netG = GenNNSkeToImage()
self.netD = Discriminator()
self.real_label = 1.
self.fake_label = 0.
self.filename = 'data/Dance/DanceGenGAN.pth'
tgt_transform = transforms.Compose(
[transforms.Resize((64, 64)),
#transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transforms.CenterCrop(64),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
self.dataset = VideoSkeletonDataset(videoSke, ske_reduced=True, target_transform=tgt_transform)
self.dataloader = torch.utils.data.DataLoader(dataset=self.dataset, batch_size=32, shuffle=True)
if loadFromFile and os.path.isfile(self.filename):
print("GenGAN: Load=", self.filename, " Current Working Directory=", os.getcwd())
self.netG = torch.load(self.filename)
def train(self, n_epochs=20):
pass
def generate(self, ske): # TP-TODO
""" generator of image from skeleton """
pass
# ske_t = torch.from_numpy( ske.__array__(reduced=True).flatten() )
# ske_t = ske_t.to(torch.float32)
# ske_t = ske_t.reshape(1,Skeleton.reduced_dim,1,1) # ske.reshape(1,Skeleton.full_dim,1,1)
# normalized_output = self.netG(ske_t)
# res = self.dataset.tensor2image(normalized_output[0])
# return res
if __name__ == '__main__':
    # Command line: optional video filename, optional "true" to force.
    force = False
    if len(sys.argv) > 1:
        filename = sys.argv[1]
        if len(sys.argv) > 2:
            force = sys.argv[2].lower() == "true"
    else:
        filename = "data/taichi1.mp4"
    print("GenGAN: Current Working Directory=", os.getcwd())
    print("GenGAN: Filename=", filename)

    targetVideoSke = VideoSkeleton(filename)

    # Toggle between training from scratch and loading a saved generator.
    TRAIN = True
    if TRAIN:
        gen = GenGAN(targetVideoSke, False)
        gen.train(4)  # 5) #200)
    else:
        # Load the previously trained generator from disk.
        gen = GenGAN(targetVideoSke, loadFromFile=True)

    # Replay every skeleton through the generator and display the result.
    for i in range(targetVideoSke.skeCount()):
        image = gen.generate(targetVideoSke.ske[i])
        #image = image*255
        display_size = (256, 256)
        image = cv2.resize(image, display_size)
        cv2.imshow('Image', image)
        key = cv2.waitKey(-1)