Skip to content

Commit

Permalink
Add reset_after_rollout option to adapt.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sliu2019 committed Jan 17, 2021
1 parent b5fc7cf commit 5fded47
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 51 deletions.
50 changes: 29 additions & 21 deletions adapt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
'decoder.output_projection.weight', 'decoder.output_projection.bias', ]


def adaptable_prediction(data_loader, model, train_params, device, adaptor, adapt_step=1):
def adaptable_prediction(data_loader, model, train_params, device, adaptor, adapt_step=1, reset_after_rollout=True):
'''adaptation hyper param'''
adapt_params = adapt_hyper_parameters(adaptor=adaptor, adapt_step=adapt_step, log_dir=train_params['log_dir'])
adapt_params._save_parameters()
Expand All @@ -45,14 +45,15 @@ def adaptable_prediction(data_loader, model, train_params, device, adaptor, adap
if train_params['encoder'] == 'rnn':
adapt_layers = rnn_layer_name[8:]
else:
adapt_layers = fc_layer_name[8:]
adapt_layers = fc_layer_name[8:] # TODO
print('adapt_weights:')
print(adapt_layers)
for name, p in model.named_parameters():
if name in adapt_layers:
adapt_weights.append(p)
print(name, p.size())

# IPython.embed()
optim_param = adapt_params.adapt_param()
if adaptor == 'mekf' or adaptor=='mekf_ma':
optimizer = MEKF_MA(adapt_weights, dim_out=adapt_step * train_params['coordinate_dim'],
Expand Down Expand Up @@ -81,13 +82,13 @@ def adaptable_prediction(data_loader, model, train_params, device, adaptor, adap
pred_result = online_adaptation(data_loader, model, optimizer, train_params, device,
adapt_step=adapt_step,
use_multi_epoch=st_param['use_multi_epoch'],
multiepoch_thresh=st_param['multiepoch_thresh'])
multiepoch_thresh=st_param['multiepoch_thresh'], reset_after_rollout=reset_after_rollout)


return pred_result


def test(params, adaptor='none', adapt_step=1):
def test(params, adaptor='none', adapt_step=1, reset_after_rollout=True):
train_params = params.train_param()
train_params['data_mean'] = torch.tensor(train_params['data_stats']['speed_mean'], dtype=torch.float).unsqueeze(
0).to(device)
Expand All @@ -105,18 +106,11 @@ def test(params, adaptor='none', adapt_step=1):
with torch.no_grad():
pred_result = get_predictions(data_loader, model, device)
else:
pred_result = adaptable_prediction(data_loader, model, train_params, device, adaptor, adapt_step)

traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos = pred_result

pred_result = adaptable_prediction(data_loader, model, train_params, device, adaptor, adapt_step, reset_after_rollout=reset_after_rollout)

# IPython.embed()
# TODO: what happened to the multiple rollouts in the test set? only just 1
# Note: traj_preds is 1 rollout's worth, w/ shape (len_rollout, ydim, 2)
# true_mse = torch.nn.MSELoss()(traj_preds * data_stats["data_std"] + data_stats["data_mean"],
# traj_labels * data_stats["data_std"] + data_stats["data_mean"])
# true_mse = true_mse.cpu().detach().numpy()

traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos = pred_result
traj_preds = get_position(traj_preds, pred_start_pos, data_stats) # NOTE: converted these to position first!
traj_labels = get_position(traj_labels, pred_start_pos, data_stats) # NOTE!!
intent_preds_prob = intent_preds.detach().clone()
Expand All @@ -131,27 +125,41 @@ def test(params, adaptor='none', adapt_step=1):

out_str = 'Evaluation Result: \n'

# out_str += "trajectory_mse: %.5f, \n" % (true_mse)

num, time_step = result['traj_labels'].shape[:2]
mse = np.power(result['traj_labels'] - result['traj_preds'], 2).sum() / (num * time_step)
out_str += "trajectory_mse: %.4f, \n" % (mse)
# TODO: calling this trajectory loss instead
# out_str += "trajectory_loss: %.4f, \n" % (mse)

# IPython.embed()
windows_per_rollout = 400 - (train_params["output_time_step"] + train_params["input_time_step"]) + 1
if reset_after_rollout:
# IPython.embed()
mse_list = []
for i in range(6): # TODO: set to 10
mse = np.power(result['traj_labels'][i*windows_per_rollout: (i+1)*windows_per_rollout] - result['traj_preds'][i*windows_per_rollout: (i+1)*windows_per_rollout], 2).sum() / (windows_per_rollout * time_step)
mse_list.append(mse)

result["mse_list"] = mse_list
result["mse_mean"] = np.mean(mse_list)
result["mse_std"] = np.std(mse_list)
print("******************************************************")
print("Per rollout stats")
print(mse_list)
print(result["mse_mean"])
print(result["mse_std"])
print("******************************************************")


acc = (result['intent_labels'] == result['intent_preds']).sum() / len(result['intent_labels'])
out_str += "action_acc: %.4f, \n" % (acc)

print(out_str)
# TODO: modified save path to be more specific
save_path = train_params['log_dir'] + adaptor + str(adapt_step) + '_pred.pkl'
joblib.dump(result, save_path)
print('save result to', save_path)
return result


def main(dataset='vehicle_ngsim', model_type='rnn', adaptor='mekf',adapt_step=1, epoch=1):
def main(dataset='vehicle_ngsim', model_type='rnn', adaptor='mekf',adapt_step=1, epoch=1, reset_after_rollout=True):
save_dir = 'output/' + dataset + '/' + model_type + '/'
# TODO: default, load model_1 (product of first epoch), but should instead specify best epoch
# model_path = save_dir + 'model_1.pkl'
Expand All @@ -160,11 +168,11 @@ def main(dataset='vehicle_ngsim', model_type='rnn', adaptor='mekf',adapt_step=1,
params._load_parameters(save_dir + 'log/')
params.params_dict['train_param']['init_model'] = model_path
params.print_params()
test(params, adaptor=adaptor, adapt_step=adapt_step)
test(params, adaptor=adaptor, adapt_step=adapt_step, reset_after_rollout=reset_after_rollout)


if __name__ == '__main__':
    # Earlier experiment configurations, kept for reference:
    # main(adapt_step=50, model_type="fc", epoch=20)
    # main(adapt_step=5)
    main(model_type="fc", epoch=18, adapt_step=50, reset_after_rollout=True)

11 changes: 2 additions & 9 deletions dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,13 @@ def __init__(self, params, mode='train',data_stats={}):
self.mode = mode
print(mode,'data preprocessing')
cache_dir = params['log_dir']+mode+'.cache'
# if os.path.exists(cache_dir):
if False:
if os.path.exists(cache_dir):
print('loading data from cache',cache_dir)
self.data = joblib.load(cache_dir)

# print("Just loaded data from saved cache")
# IPython.embed()
else:
raw_data = joblib.load(params['data_path'])[mode]
self.data = data_time_split(raw_data,params) # This just does windowing

# print("Just loaded data to create anew")
# IPython.embed()

if mode=='train':
data_stats['traj_mean'] = np.mean(self.data['x_traj'],axis=(0,1))
data_stats['traj_std'] = np.std(self.data['x_traj'], axis=(0, 1))
Expand Down Expand Up @@ -158,7 +151,7 @@ def __init__(self, params, mode='train',data_stats={}):
print(mode + '_data size:', len(self.data['x_encoder']))
print('each category counts:')
print(Counter(self.data['y_intent']))
print("In dataset.py")
# print("In dataset.py")
# IPython.embed()

def __getitem__(self, index):
Expand Down
68 changes: 47 additions & 21 deletions utils/adapt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
import numpy as np
import torch
from .pred_utils import get_prediction_on_batch
import IPython
import IPython, copy

# TODO: this selects how much data to use
# data_size=100
# test set size
# 16*128 = 2048
# data_size=800
# test set size: 1986
data_size=3000
# data_size = 400
def batch2iter_data(dataloader, device='cpu',data_size=data_size):
traj_hist, traj_labels, intent_labels, start_decodes, pred_start_pos, x_mask = None, None, None, None, None, None
traj_hist, traj_labels, intent_labels, start_decodes, pred_start_pos, x_mask, rollout_start_inds = None, None, None, None, None, None, [0]

for i, data in enumerate(dataloader, 0):
x, y_traj, y_intent, start_decode, start_pos, mask = data
Expand All @@ -30,10 +29,11 @@ def batch2iter_data(dataloader, device='cpu',data_size=data_size):
pred_start_pos = torch.cat([pred_start_pos, start_pos], dim=0)
x_mask = torch.cat([x_mask, mask], dim=0)

# rollout_start_inds.append(traj_hist.shape[0])
if data_size>0 and traj_hist.size(0)>data_size:
break

print(traj_hist.shape)
# print(traj_hist.shape)
traj_hist = traj_hist.float().to(device)
traj_labels = traj_labels.float().to(device)
intent_labels = intent_labels.float().to(device)
Expand All @@ -42,48 +42,72 @@ def batch2iter_data(dataloader, device='cpu',data_size=data_size):
x_mask = x_mask.byte().to(device)
data = [traj_hist, traj_labels, intent_labels, start_decodes, pred_start_pos, x_mask]

print("at the end of batch2iterdata")
# print("at the end of batch2iterdata")
# IPython.embed()
return data


def online_adaptation(dataloader, model, optimizer, params, device,
                      adapt_step=1, use_multi_epoch=False, multiepoch_thresh=(0, 0),
                      reset_after_rollout=True):
    """Run online adaptation over the test stream and collect predictions.

    Materializes the whole dataloader into per-window size-1 batches, then
    adapts the model online, either once over the full stream or restarting
    per rollout.

    Args:
        dataloader: test-set loader consumed by ``batch2iter_data``.
        model: prediction model to adapt online.
        optimizer: adaptation optimizer (e.g. MEKF_MA, Lookahead, LBFGS, ...).
        params: train-param dict; reads ``output_time_step`` and
            ``input_time_step`` here.
        device: torch device for the data tensors.
        adapt_step: number of prediction steps per adaptation update.
        use_multi_epoch: forwarded flag controlling multi-epoch updates.
        multiepoch_thresh: thresholds for the multi-epoch heuristic.
        reset_after_rollout: if True, adapt each rollout independently
            (weights are re-adapted from the start of every rollout) and
            concatenate the per-rollout predictions; if False, adapt once
            over the entire stream.

    Returns:
        Tuple ``(traj_hist, traj_preds, traj_labels, intent_preds,
        intent_labels, pred_start_pos)`` of torch tensors.
    """
    optim_name = optimizer.__class__.__name__
    if optim_name == 'Lookahead':
        # Lookahead wraps an inner optimizer; report both names.
        optim_name = optim_name + '_' + optimizer.optimizer.__class__.__name__
    print('optimizer:', optim_name)
    # Fix: the original line ended with a stray trailing comma, building and
    # discarding a (None,) tuple.
    print('adapt_step:', adapt_step, ', use_multi_epoch:', use_multi_epoch,
          ', multiepoch_thresh:', multiepoch_thresh)

    data = batch2iter_data(dataloader, device)
    traj_hist, traj_labels, intent_labels, start_decodes, pred_start_pos, x_mask = data

    # Re-slice the stacked tensors into one size-1 "batch" per window,
    # preserving stream order. Indexing with [[ii]] keeps the batch dim.
    batches = [[item[[ii]] for item in data] for ii in range(len(pred_start_pos))]

    # Number of sliding windows contained in one rollout.
    # NOTE(review): the rollout length 400 and the rollout count 6 below are
    # hard-coded for the current dataset (TODO in source says 10) -- confirm
    # against the test split before reuse.
    windows_per_rollout = 400 - (params["output_time_step"] + params["input_time_step"]) + 1
    if reset_after_rollout:
        # Collect per-rollout chunks and concatenate once at the end instead
        # of torch.cat inside the loop (which is quadratic in total length).
        traj_chunks, intent_chunks = [], []
        for i in range(6):  # TODO: set to 10 (number of test rollouts)
            rollout_batch = batches[i * windows_per_rollout: (i + 1) * windows_per_rollout]
            rollout_traj_preds, rollout_intent_preds = online_adaptation_single_rollout(
                model, rollout_batch, optimizer, optim_name, adapt_step,
                multiepoch_thresh, device)
            traj_chunks.append(rollout_traj_preds)
            intent_chunks.append(rollout_intent_preds)
        traj_preds = torch.cat(traj_chunks, dim=0)
        intent_preds = torch.cat(intent_chunks, dim=0)
    else:
        # Single pass over the entire stream; weights are never reset.
        traj_preds, intent_preds = online_adaptation_single_rollout(
            model, batches, optimizer, optim_name, adapt_step,
            multiepoch_thresh, device)

    return traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos


def online_adaptation_single_rollout(model, batches, optimizer, optim_name, adapt_step, multiepoch_thresh, device):
"""
Returns a 3D Tensor
"""
traj_preds = []
intent_preds = []

temp_pred_list = []
temp_label_list = []
temp_data_list = []
cnt = [0, 0, 0]
cost_list = []
post_cost_list=[]

cost_diff_list = []
print("In online_adaptation, ln 69")
# IPython.embed()
for t in range(len(pred_start_pos)):

t1 = time()

for t in range(len(batches)):
batch_data = batches[t]
_, pred_traj, y_traj, pred_intent, _, _ = get_prediction_on_batch(batch_data, model, device)
# IPython.embed()
traj_preds[t] = pred_traj[0].detach()
intent_preds[t] = pred_intent[0].detach()

traj_preds.append(pred_traj[0].detach()[None])
intent_preds.append(pred_intent[0].detach()[None])

temp_pred_list += [pred_traj]
temp_label_list += [y_traj]
Expand Down Expand Up @@ -155,8 +179,10 @@ def lbfgs_closure():
print('finished pred {}, time:{}, partial cost before adapt:{}, partial cost after adapt:{}'.format(t, time() - t1, full_loss,post_loss))
t1 = time()

# IPython.embed()
print("avg cost improvement (should be +): %f +/- %f" % (np.mean(cost_diff_list), np.std(cost_diff_list)))
print('avg_cost:', np.mean(cost_list))
print('number of update epoch', cnt)
return traj_hist, traj_preds,traj_labels, intent_preds, intent_labels, pred_start_pos

traj_preds = torch.cat(traj_preds, axis=0)
intent_preds = torch.cat(intent_preds, axis=0)
return traj_preds, intent_preds

0 comments on commit 5fded47

Please sign in to comment.