forked from EricGuo5513/text-to-motion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_decomp_v3.py
100 lines (81 loc) · 3.81 KB
/
train_decomp_v3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
from os.path import join as pjoin
import utils.paramUtil as paramUtil
from options.train_options import TrainDecompOptions
from utils.plot_script import *
from networks.modules import *
from networks.trainers import DecompTrainerV3
from data.dataset import MotionDatasetV2
from scripts.motion_process import *
from torch.utils.data import DataLoader
from utils.word_vectorizer import WordVectorizer, POS_enumerator
def plot_t2m(data, save_dir):
    """Render every motion in ``data`` to its own ``NN.mp4`` file in ``save_dir``.

    Passed to ``DecompTrainerV3.train`` as its plotting callback.

    Args:
        data: batch of motion feature arrays in the training dataset's
            normalized space. Each item goes through ``torch.from_numpy``, so
            items must be NumPy arrays (presumably (frames, dim_pose) — TODO
            confirm against the trainer).
        save_dir: directory that receives ``00.mp4``, ``01.mp4``, ... (one per
            item, zero-padded to two digits).

    NOTE: this function reads module-level globals that are only assigned in
    the ``__main__`` block of this file (``train_dataset``, ``opt``,
    ``kinematic_chain``, ``fps``, ``radius``); it raises ``NameError`` if the
    module is imported rather than run as a script.
    """
    # Map features out of the dataset's normalized space (inverse of the
    # mean/std transform the dataset was built with — presumably).
    data = train_dataset.inv_transform(data)
    for i, joint_data in enumerate(data):
        # recover_from_ric turns the feature vector into per-joint data that
        # plot_3d_motion can draw (presumably joint positions).
        joint = recover_from_ric(torch.from_numpy(joint_data).float(), opt.joints_num).numpy()
        save_path = pjoin(save_dir, f'{i:02d}.mp4')
        plot_3d_motion(save_path, kinematic_chain, joint, title="None", fps=fps, radius=radius)
if __name__ == '__main__':
    # Parse CLI options and pick the device; gpu_id == -1 selects the CPU.
    parser = TrainDecompOptions()
    opt = parser.parse()
    opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
    # NOTE(review): anomaly detection is a debugging aid that adds significant
    # autograd overhead; consider turning it off for full training runs.
    torch.autograd.set_detect_anomaly(True)
    if opt.gpu_id != -1:
        torch.cuda.set_device(opt.gpu_id)

    # Output layout: <checkpoints_dir>/<dataset>/<name>/{model,meta,animation},
    # plus logs under ./log/<dataset>/<name>.
    opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
    opt.model_dir = pjoin(opt.save_root, 'model')
    opt.meta_dir = pjoin(opt.save_root, 'meta')
    opt.eval_dir = pjoin(opt.save_root, 'animation')
    opt.log_dir = pjoin('./log', opt.dataset_name, opt.name)
    for out_dir in (opt.model_dir, opt.meta_dir, opt.eval_dir, opt.log_dir):
        os.makedirs(out_dir, exist_ok=True)

    # Per-dataset configuration. `radius`, `fps` and `kinematic_chain` stay
    # module-level names because plot_t2m reads them as globals.
    if opt.dataset_name == 't2m':
        opt.data_root = './dataset/HumanML3D'
        opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
        opt.text_dir = pjoin(opt.data_root, 'texts')
        opt.joints_num = 22
        opt.max_motion_length = 196
        dim_pose = 263
        radius = 4
        fps = 20
        kinematic_chain = paramUtil.t2m_kinematic_chain
    elif opt.dataset_name == 'kit':
        opt.data_root = './dataset/KIT-ML'
        opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
        opt.text_dir = pjoin(opt.data_root, 'texts')
        opt.joints_num = 21
        radius = 240 * 8
        fps = 12.5
        dim_pose = 251
        opt.max_motion_length = 196
        kinematic_chain = paramUtil.kit_kinematic_chain
    else:
        raise KeyError('Dataset Does Not Exist')

    # Normalization statistics handed to both datasets.
    mean = np.load(pjoin(opt.data_root, 'Mean.npy'))
    std = np.load(pjoin(opt.data_root, 'Std.npy'))

    train_split_file = pjoin(opt.data_root, 'train.txt')
    val_split_file = pjoin(opt.data_root, 'val.txt')

    # The encoder is built for dim_pose - 4 input channels while the decoder
    # outputs the full dim_pose vector.
    movement_enc = MovementConvEncoder(dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
    movement_dec = MovementConvDecoder(opt.dim_movement_latent, opt.dim_movement_dec_hidden, dim_pose)

    pc_mov_enc = sum(param.numel() for param in movement_enc.parameters())
    print(movement_enc)
    print("Total parameters of movement encoder: {}".format(pc_mov_enc))
    pc_mov_dec = sum(param.numel() for param in movement_dec.parameters())
    print(movement_dec)
    print("Total parameters of movement decoder: {}".format(pc_mov_dec))
    all_params = pc_mov_enc + pc_mov_dec
    print("Total parameters of all models: {}".format(all_params))

    trainer = DecompTrainerV3(opt, movement_enc, movement_dec)

    # `train_dataset` must stay module-level: plot_t2m reads it as a global.
    train_dataset = MotionDatasetV2(opt, mean, std, train_split_file)
    val_dataset = MotionDatasetV2(opt, mean, std, val_split_file)

    # Pinned host memory only speeds up host->GPU copies; on CPU runs it is
    # useless and PyTorch warns about it.
    use_pinned_memory = opt.gpu_id != -1
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, drop_last=True, num_workers=4,
                              shuffle=True, pin_memory=use_pinned_memory)
    val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, drop_last=True, num_workers=4,
                            shuffle=True, pin_memory=use_pinned_memory)

    # plot_t2m is handed to the trainer as its plotting callback.
    trainer.train(train_loader, val_loader, plot_t2m)