-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathargs.py
96 lines (85 loc) · 5.16 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import copy
import json
import os

from consts import *
def str2bool(value):
    """Convert a command-line string to a bool, for use as an argparse ``type=``.

    ``type=bool`` must not be used for flags that take a value: ``bool("False")``
    is ``True`` because every non-empty string is truthy.

    Args:
        value: The raw argument string, or an already-boolean value.

    Returns:
        True for "true"/"t"/"yes"/"y"/"1".
        False for "false"/"f"/"no"/"n"/"0" or the empty string.
        Matching is case-insensitive and ignores surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: For any other string, so argparse reports a
            usage error instead of silently guessing.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # '' maps to False to match the old bool('') behaviour.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def parse_opt():
    """Parse command-line options and attach machine- and model-specific settings.

    Reads ``machine_specific_config.json`` from the directory containing this
    file. Its keys are used as follows:

    - ``log_with_wandb``, ``batch_size`` and ``pseudo_dataset_len`` supply
      argument defaults.
    - ``b3d_path`` and ``use_server`` are copied onto the result as
      ``data_path_parent`` and ``use_server``.

    Finally ``set_with_arm_opt`` is applied with the parsed ``--with_arm``
    value to derive column lists and data paths.

    Returns:
        argparse.Namespace: the parsed arguments plus the derived attributes.
    """
    config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'machine_specific_config.json')
    # Context manager so the config file handle is closed as soon as it is read.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        machine_specific_config = json.load(config_file)

    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_name", default="no_data_filter", help="save to project/name")
    # str2bool instead of bool so that e.g. "--with_arm False" actually yields False.
    parser.add_argument("--with_arm", type=str2bool, default=False, help="whether osim model has arm DoFs")
    parser.add_argument("--with_kinematics_vel", type=str2bool, default=True, help="whether to include 1st derivative of kinematics")
    parser.add_argument("--log_with_wandb", type=str2bool, default=machine_specific_config['log_with_wandb'], help="log with wandb")
    parser.add_argument("--epochs", type=int, default=7680)
    parser.add_argument("--target_sampling_rate", type=int, default=100)
    parser.add_argument("--window_len", type=int, default=150)
    parser.add_argument("--guide_x_start_the_beginning_step", type=int, default=-10)  # negative value means no guidance
    parser.add_argument("--project", default="runs/train", help="project/name")
    parser.add_argument(
        "--processed_data_dir",
        type=str,
        default="dataset_backups/",
        help="Dataset backup path",
    )
    parser.add_argument("--feature_type", type=str, default="jukebox")
    parser.add_argument(
        "--wandb_pj_name", type=str, default="MotionModel", help="project name"
    )
    parser.add_argument("--batch_size", type=int, default=machine_specific_config['batch_size'], help="batch size")
    parser.add_argument("--batch_size_inference", type=int, default=128, help="batch size during inference")
    parser.add_argument("--pseudo_dataset_len", type=int, default=machine_specific_config['pseudo_dataset_len'], help="pseudo dataset length")
    parser.add_argument(
        "--force_reload", action="store_true", help="force reloads the datasets"
    )
    parser.add_argument(
        "--no_cache", action="store_true", help="don't reuse / cache loaded dataset"
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=50,
        help='Log model after every "save_interval" epoch',
    )
    parser.add_argument("--ema_interval", type=int, default=1, help="ema every x steps")
    parser.add_argument(
        "--checkpoint", type=str, default="", help="trained checkpoint path (optional)"
    )
    # NOTE(review): help text duplicates --checkpoint; what the "_bl" checkpoint is for is not visible here.
    parser.add_argument(
        "--checkpoint_bl", type=str, default="", help="trained checkpoint path (optional)"
    )
    opt = parser.parse_args()

    # Machine-specific settings that are not exposed as CLI flags.
    opt.data_path_parent = machine_specific_config['b3d_path']
    opt.use_server = machine_specific_config['use_server']
    # Derive column lists and data paths for the selected skeleton variant.
    set_with_arm_opt(opt, opt.with_arm)
    return opt
def set_with_arm_opt(opt, with_arm):
    """Populate ``opt`` with column lists and data paths for the chosen skeleton.

    Args:
        opt: Namespace updated in place. Must already provide
            ``data_path_parent`` and ``with_kinematics_vel``.
        with_arm: Truthy to configure the variant with arm DoFs, falsy for the
            arm-free variant.

    Sets the following attributes on ``opt``:

    - ``with_arm``
    - ``osim_dof_columns`` and ``joints_3d``
    - ``data_path_osim_model``, ``data_path_train`` and ``data_path_test``
    - ``model_states_column_names``
    - the ``*_col_loc`` index lists

    Returns None.
    """
    opt.with_arm = bool(with_arm)
    if opt.with_arm:
        osim_dofs = OSIM_DOF_ALL
        # Shallow copy so edits to opt.joints_3d cannot leak into the shared constant.
        opt.joints_3d = dict(JOINTS_3D_ALL)
        dataset_dir = 'b3d_with_arm'
        osim_model_file = 'unscaled_generic_with_arm.osim'
        model_states = MODEL_STATES_COLUMN_NAMES_WITH_ARM
    else:
        # NOTE(review): assumes the first 23 entries of OSIM_DOF_ALL are exactly
        # the non-arm DoFs -- confirm against consts.
        osim_dofs = OSIM_DOF_ALL[:23]
        opt.joints_3d = {
            name: value for name, value in JOINTS_3D_ALL.items()
            if name in ('pelvis', 'hip_r', 'hip_l', 'lumbar')
        }
        dataset_dir = 'b3d_no_arm'
        osim_model_file = 'unscaled_generic_no_arm.osim'
        model_states = MODEL_STATES_COLUMN_NAMES_NO_ARM

    # Deep copies keep the module-level constant lists untouched when extended below.
    opt.osim_dof_columns = copy.deepcopy(osim_dofs + KINETICS_ALL)
    opt.model_states_column_names = copy.deepcopy(model_states)

    # os.path.join gives one separator whether or not data_path_parent ends with '/'.
    # String concatenation produced '//' in one branch and no separator in the other.
    # The trailing '' keeps the trailing slash the original string paths had.
    data_path = os.path.join(opt.data_path_parent, dataset_dir, '')
    opt.data_path_osim_model = os.path.join(opt.data_path_parent, 'osim_model', osim_model_file)
    opt.data_path_train = os.path.join(data_path, 'train_cleaned', '')
    opt.data_path_test = os.path.join(data_path, 'test_cleaned', '')

    # Three angular-velocity columns (x, y, z) for every joint kept in opt.joints_3d.
    opt.model_states_column_names += [
        f'{joint_name}_{axis}_angular_vel'
        for joint_name in opt.joints_3d
        for axis in ('x', 'y', 'z')
    ]

    if opt.with_kinematics_vel:
        # Add a '<col>_vel' column for each column whose name contains none of
        # these substrings. This skips force columns, 'pelvis_' columns and
        # columns that already hold a velocity.
        # NOTE(review): the meaning of the '_0'..'_5' exclusions is not visible
        # here -- confirm in consts.
        excluded_terms = ('force', 'pelvis_', '_vel', '_0', '_1', '_2', '_3', '_4', '_5')
        # The comprehension is fully built before += extends the list, so it
        # only sees the pre-existing columns.
        opt.model_states_column_names += [
            f'{col}_vel' for col in opt.model_states_column_names
            if not any(term in col for term in excluded_terms)
        ]

    state_columns = opt.model_states_column_names
    opt.knee_diffusion_col_loc = [i for i, col in enumerate(state_columns) if 'knee' in col]
    opt.ankle_diffusion_col_loc = [i for i, col in enumerate(state_columns) if 'ankle' in col]
    opt.hip_diffusion_col_loc = [i for i, col in enumerate(state_columns) if 'hip' in col]
    opt.kinematic_diffusion_col_loc = [i for i, col in enumerate(state_columns) if 'force' not in col]
    # Complement of the kinematic columns, computed directly instead of via an
    # O(n^2) "not in list" scan.
    opt.kinetic_diffusion_col_loc = [i for i, col in enumerate(state_columns) if 'force' in col]

    osim_columns = opt.osim_dof_columns
    opt.grf_osim_col_loc = [i for i, col in enumerate(osim_columns) if 'force' in col and '_cop_' not in col]
    opt.cop_osim_col_loc = [i for i, col in enumerate(osim_columns) if '_cop_' in col]
    opt.kinematic_osim_col_loc = [i for i, col in enumerate(osim_columns) if 'force' not in col]