# run_maml.py
from eval_utils import *
from maml import (MAML,
MAML_ID, REPTILE_ID, REPTILIAN_MAML_ID,
ID_TO_NAME, NAME_TO_ID)
from multitask_env import MultiTaskEnv
from grasp_env import GraspEnv
from reach_task import ReachTargetCustom
from progress_callback import ProgressCallback
from utils import parse_arguments
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo.policies import MlpPolicy
from stable_baselines3 import PPO, HER
from rlbench.action_modes import ArmActionMode
import torch
import numpy as np
import time
import datetime
import copy
import sys
import ipdb
import wandb
import random
import ray
ray.init()
"""Commands:
Train:
python run_maml.py --train
Eval:
python run_maml.py --eval --model_path=models/reptile_randomized_targets/320_iters.zip
"""
class CustomPolicy(MlpPolicy):
    def __init__(self, *args, **kwargs):
        super(CustomPolicy, self).__init__(*args, **kwargs,
                                           net_arch=[64, 64, dict(pi=[64, 64], vf=[64, 64])])
        # self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        # self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')

    def _get_torch_save_params(self):
        state_dicts = ["policy", "policy.optimizer", "policy.lr_scheduler"]
        return state_dicts, []
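# Note on CustomPolicy above: in the stable-baselines3 version this was written against,
# a net_arch of [64, 64, dict(pi=[...], vf=[...])] builds two shared 64-unit layers followed
# by separate 64-64 heads for the policy (pi) and value function (vf); recent SB3 releases
# changed the expected net_arch format.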
if __name__ == "__main__":
# sys.stdout = open("outputs.txt", "w")
# seed = 12345
# seed = 320
# seed = 420 # lel
# Args
args = parse_arguments()
try:
random.seed(args.seed) # python random seed
torch.manual_seed(args.seed) # pytorch random seed
np.random.seed(args.seed) # numpy random seed
torch.backends.cudnn.deterministic = True
except:
print("NEED TO SPECIFY RANDOM SEED with --seed")
render = args.render
is_train = args.train
model_path = args.model_path
num_episodes = args.num_episodes
lr = args.lr
# lr_scheduler = None
timestamp = int(time.time())
print(args)
    # MAML iterations:
    #   for iter = 1:num_iters
    #       for task in task_batch (size=task_batch_size)
    #           During PPO's adaptation to the specific task:
    #               for t = 1:total_timesteps
    #                   for e = 1:num_episodes:
    #                       Gen + add new episode of max episode_length
    #                   for epoch = 1:n_epochs
    #                       for batch (size=batch_size) in batches
    #                           Calc loss over collected episodes, step gradients
    #           Post-update collection of new data and gradients:
    #               for e = 1:num_episodes:
    #                   Gen + add new episode of max episode_length
    #           Gradients += this task's PPO gradients
    #       Step with the summed gradients

    # PPO adaptation parameters
    episode_length = 200  # horizon H
    num_episodes = 5      # "K" in K-shot learning
    n_steps = num_episodes * episode_length
    n_epochs = 2
    batch_size = 64
    num_iters = 300
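    # Rough per-iteration data budget (illustrative arithmetic, based on the defaults
    # above and task_batch_size = 8 set below): each adaptation collects
    # n_steps = num_episodes * episode_length = 5 * 200 = 1000 environment steps per
    # task, so one meta-iteration touches roughly 8 * 1000 = 8000 steps before the
    # outer update.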
    algo_name = args.algo_name.upper()
    algo_type = NAME_TO_ID[algo_name]
    if algo_type == MAML_ID:
        # For MAML, shrink the inner loop to a single episode/epoch and scale up
        # the number of outer iterations to compensate.
        num_iters = num_iters * batch_size * num_episodes * n_epochs
        episode_length = 200
        num_episodes = 1
        batch_size = None
        n_epochs = 1

    total_timesteps = 1 * n_steps  # number of "epochs"
    action_size = 3                # only control EE position
    manual_terminate = True
    penalize_illegal = True

    # Logistical parameters
    verbose = 1
    save_targets = True  # save the train targets (loaded or generated)
    save_freq = 1        # save model weights every save_freq iterations

    # MAML parameters
    # Algorithm IDs: MAML_ID, REPTILE_ID, REPTILIAN_MAML_ID
    num_tasks = 10
    task_batch_size = 8  # Reptile automatically uses 1 during training
    act_mode = ArmActionMode.DELTA_EE_POSE_PLAN_WORLD_FRAME
    alpha = 1e-3  # inner-loop (adaptation) learning rate
    beta = 1e-3   # outer-loop (meta) learning rate
    vf_coef = 0.5
    ent_coef = 0.01

    base_init_kwargs = {'policy': CustomPolicy, 'n_steps': n_steps, 'n_epochs': n_epochs,
                        'learning_rate': alpha, 'batch_size': batch_size, 'verbose': verbose,
                        'vf_coef': vf_coef, 'ent_coef': ent_coef}
    base_adapt_kwargs = {'total_timesteps': total_timesteps, 'n_steps': n_steps}

    render_mode = "human" if render else None
    env_kwargs = {'task_class': ReachTargetCustom, 'act_mode': act_mode, 'render_mode': render_mode,
                  'epsiode_length': episode_length, 'action_size': action_size,
                  'manual_terminate': manual_terminate, 'penalize_illegal': penalize_illegal}
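    # base_init_kwargs mirrors the stable-baselines3 PPO constructor arguments;
    # base_adapt_kwargs is presumably forwarded by the MAML wrapper to each per-task
    # adaptation call (total_timesteps is the usual .learn() argument, and n_steps is
    # assumed to be consumed by MAML itself).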
    # Log the run configuration
    config = {
        "num_tasks": num_tasks,
        "task_batch_size": task_batch_size,
        "alpha": alpha,
        "beta": beta,
        "algo_name": algo_name,
        "seed": args.seed,
    }
    run_title = "IDL - Train" if is_train else "IDL - Eval"
    run = wandb.init(project=run_title, entity="idl-project", config=config)
    wandb.save("maml.py")
    wandb.save("run_maml.py")
    wandb.save("multitask_env.py")
    wandb.save("grasp_env.py")
    print("Run Name:", run.name)

    save_path = "models/" + str(run.name)
    save_kwargs = {'save_freq': save_freq, 'save_path': save_path,
                   'tensorboard_log': save_path, 'save_targets': save_targets}

    # Load in targets and cap the task batch at the number of available train targets
    train_targets = MultiTaskEnv.targets
    task_batch_size = min(task_batch_size, len(train_targets))
    test_targets = MultiTaskEnv.test_targets
    if is_train:
        # Create the MAML wrapper, which spawns multiple agents and sim environments
        model = MAML(BaseAlgo=PPO, EnvClass=GraspEnv, algo_type=algo_type,
                     num_tasks=num_tasks, task_batch_size=task_batch_size, targets=train_targets,
                     alpha=alpha, beta=beta, model_path=model_path,
                     env_kwargs=env_kwargs, base_init_kwargs=base_init_kwargs,
                     base_adapt_kwargs=base_adapt_kwargs)
        model.learn(num_iters=num_iters, save_kwargs=save_kwargs)
    else:
        # Create the MAML wrapper, which spawns multiple agents and sim environments
        model = MAML(BaseAlgo=PPO, EnvClass=GraspEnv, algo_type=algo_type,
                     num_tasks=num_tasks, task_batch_size=task_batch_size, targets=train_targets,
                     alpha=alpha, beta=beta, model_path='',  # <--- empty path for random weights
                     env_kwargs=env_kwargs, base_init_kwargs=base_init_kwargs,
                     base_adapt_kwargs=base_adapt_kwargs)

        # Baseline: adaptation from randomly-initialized weights on the test tasks
        assert model.model_path == ''  # testing randomly-initialized weights
        rand_init_metrics = model.eval_performance(
            model_type="random",
            save_kwargs=save_kwargs,
            num_iters=num_iters,
            targets=test_targets)
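        # eval_performance appears to return one metrics record per evaluation
        # iteration; the fields unpacked below (reward, success_rate, entropy_loss,
        # pg_loss, value_loss, loss) are taken from those records.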
        rand_init_rewards = [v.reward for v in rand_init_metrics]
        rand_init_success = [v.success_rate for v in rand_init_metrics]
        rand_init_e_loss = [v.entropy_loss for v in rand_init_metrics]
        rand_init_pg_loss = [v.pg_loss for v in rand_init_metrics]
        rand_init_v_loss = [v.value_loss for v in rand_init_metrics]
        rand_init_loss = [v.loss for v in rand_init_metrics]

        # Pretrained (meta-learned) weights on the test tasks
        assert model_path != ""  # maml or reptile
        pretrained_metrics = model.eval_performance(
            model_type=algo_name,  # e.g. "MAML", "REPTILE"
            save_kwargs=save_kwargs,
            num_iters=num_iters,
            targets=test_targets,
            model_path=model_path)
        pretrained_rewards = [v.reward for v in pretrained_metrics]
        pretrained_success = [v.success_rate for v in pretrained_metrics]
        pretrained_e_loss = [v.entropy_loss for v in pretrained_metrics]
        pretrained_pg_loss = [v.pg_loss for v in pretrained_metrics]
        pretrained_v_loss = [v.value_loss for v in pretrained_metrics]
        pretrained_loss = [v.loss for v in pretrained_metrics]

        np.savez(f"final_results_{algo_name}",
                 rand_init_rewards=rand_init_rewards,
                 rand_init_success=rand_init_success,
                 rand_init_e_loss=rand_init_e_loss,
                 rand_init_pg_loss=rand_init_pg_loss,
                 rand_init_v_loss=rand_init_v_loss,
                 rand_init_loss=rand_init_loss,
                 pretrained_rewards=pretrained_rewards,
                 pretrained_success=pretrained_success,
                 pretrained_e_loss=pretrained_e_loss,
                 pretrained_pg_loss=pretrained_pg_loss,
                 pretrained_v_loss=pretrained_v_loss,
                 pretrained_loss=pretrained_loss,
                 )
    model.close()
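    # The arrays above are written with np.savez, which appends ".npz" to the file name.
    # Minimal loading sketch (assuming an eval run with --algo_name=reptile):
    #   data = np.load("final_results_REPTILE.npz")
    #   print(data["pretrained_success"])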