diff --git a/README.md b/README.md index 18eccd7..f834438 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,14 @@ -# MEKF_MAME -**M**odified **E**xtended **K**alman **F**ilter with generalized exponential **M**oving **A**verage and dynamic **M**ulti-**E**poch update strategy (MEKF_MAME) +# MEKFEMA-DME +**M**odified **E**xtended **K**alman **F**ilter with generalized **E**xponential **M**oving **A**verage and **D**ynamic **M**ulti-**E**poch update strategy (MEKFEMA-DME) -Pytorch implementation source coder for paper "Robust Nonlinear Adaptation Algorithms for Multi-TaskPrediction Networks". +Pytorch implementation source coder for paper [Robust Online Model Adaptation by Extended Kalman Filter with Exponential Moving Average and Dynamic Multi-Epoch Strategy](https://arxiv.org/abs/1912.01790). -**We will release the code, once the paper is published.** -In the paper, EKF based adaptation algorithm MEKF_λ was introduced as an effective base algorithm for online adaptation. In order to improve the convergence property of MEKF_λ, generalized exponential moving average filtering was investigated. Then this paper introduced a dynamic multi-epoch update strategy, which can be compatible with any optimizers. By combining all extensions with base MEKF_λ algorithm, robust online adaptation algorithm MEKF_MAME was created. +Inspired by Extended Kalman Filter (EKF), a base adaptation algorithm Modified EKF with forgetting +factor (MEKF_$\lambda$) is introduced first. Using exponential moving average (EMA) methods, this +paper proposes EMA filtering to the base EKFλ in order to increase the convergence rate. followed by exponential moving average filtering techniques. +Then in order to effectively utilize the samples in online +adaptation, this paper proposes a dynamic multi-epoch update strategy to discriminate the “hard” +samples from “easy” samples, and sets different weights for them. With all these extensions, we propose a robust online adaptation algorithm: +MEKF with Exponential Moving Average and Dynamic Multi-Epoch strategy (MEKFEMA-DME). 
diff --git a/adapt.py b/adapt.py new file mode 100644 index 0000000..9965955 --- /dev/null +++ b/adapt.py @@ -0,0 +1,150 @@ +# coding=utf-8 + +import os +import warnings + +import joblib +import torch +import numpy as np +from dataset.dataset import get_data_loader +from adaptation.lookahead import Lookahead +from adaptation.mekf import MEKF_MA +from parameters import hyper_parameters, adapt_hyper_parameters +from utils.adapt_utils import online_adaptation +from utils.pred_utils import get_predictions,get_position + +os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' +os.environ['CUDA_VISIBLE_DEVICES'] = "0" +warnings.filterwarnings("ignore") + + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device =torch.device("cpu") +print('testing with device:', device) + +rnn_layer_name = ['encoder.rnn.weight_ih_l0', 'encoder.rnn.bias_ih_l0', + 'encoder.rnn.weight_hh_l0','encoder.rnn.bias_hh_l0', + 'decoder.rnn.weight_ih_l0', 'decoder.rnn.bias_ih_l0', + 'decoder.rnn.weight_hh_l0','decoder.rnn.bias_hh_l0', + 'decoder.output_projection.weight', 'decoder.output_projection.bias'] +fc_layer_name = ['encoder.layers.0.weight', 'encoder.layers.0.bias', + 'encoder.layers.3.weight', 'encoder.layers.3.bias', + 'decoder.layers.0.weight', 'decoder.layers.0.bias', + 'decoder.layers.3.weight', 'decoder.layers.3.bias', + 'decoder.output_projection.weight', 'decoder.output_projection.bias', ] + + +def adaptable_prediction(data_loader, model, train_params, device, adaptor, adapt_step=1): + '''adaptation hyper param''' + adapt_params = adapt_hyper_parameters(adaptor=adaptor, adapt_step=adapt_step, log_dir=train_params['log_dir']) + adapt_params._save_parameters() + adapt_params.print_params() + + adapt_weights = [] + if train_params['encoder'] == 'rnn': + adapt_layers = rnn_layer_name[8:] + else: + adapt_layers = fc_layer_name[8:] + print('adapt_weights:') + print(adapt_layers) + for name, p in model.named_parameters(): + if name in adapt_layers: + adapt_weights.append(p) + print(name, p.size()) + + optim_param = adapt_params.adapt_param() + if adaptor == 'mekf' or adaptor=='mekf_ma': + optimizer = MEKF_MA(adapt_weights, dim_out=adapt_step * train_params['coordinate_dim'], + p0=optim_param['p0'], lbd=optim_param['lbd'], sigma_r=optim_param['sigma_r'], + sigma_q=optim_param['sigma_q'], lr=optim_param['lr'], + miu_v=optim_param['miu_v'], miu_p=optim_param['miu_p'], + k_p=optim_param['k_p']) + elif adaptor == 'sgd': + optimizer = torch.optim.SGD(adapt_weights, lr=optim_param['lr'], momentum=optim_param['momentum'], + nesterov=optim_param['nesterov']) + + elif adaptor == 'adam': + optimizer = torch.optim.Adam(adapt_weights, lr=optim_param['lr'], betas=optim_param['betas'], + amsgrad=optim_param['amsgrad']) + + elif adaptor == 'lbfgs': + optimizer = torch.optim.LBFGS(adapt_weights, lr=optim_param['lr'], max_iter=optim_param['max_iter'], + history_size=optim_param['history_size']) + else: + raise NotImplementedError + print('base optimizer configs:', optimizer.defaults) + if optim_param['use_lookahead']: + optimizer = Lookahead(optimizer, k=optim_param['la_k'], alpha=optim_param['la_alpha']) + + st_param = adapt_params.strategy_param() + pred_result = online_adaptation(data_loader, model, optimizer, train_params, device, + adapt_step=adapt_step, + use_multi_epoch=st_param['use_multi_epoch'], + multiepoch_thresh=st_param['multiepoch_thresh']) + + + return pred_result + + +def test(params, adaptor='none', adapt_step=1): + train_params = params.train_param() + train_params['data_mean'] = 
torch.tensor(train_params['data_stats']['speed_mean'], dtype=torch.float).unsqueeze( + 0).to(device) + train_params['data_std'] = torch.tensor(train_params['data_stats']['speed_std'], dtype=torch.float).unsqueeze(0).to( + device) + data_stats = {'data_mean': train_params['data_mean'], 'data_std': train_params['data_std']} + + model = torch.load(train_params['init_model']) + model = model.to(device) + print('load model', train_params['init_model']) + + data_loader = get_data_loader(train_params, mode='test') + print('begin to test') + if adaptor == 'none': + with torch.no_grad(): + pred_result = get_predictions(data_loader, model, device) + else: + pred_result = adaptable_prediction(data_loader, model, train_params, device, adaptor, adapt_step) + + traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos = pred_result + traj_preds = get_position(traj_preds, pred_start_pos, data_stats) + traj_labels = get_position(traj_labels, pred_start_pos, data_stats) + intent_preds_prob = intent_preds.detach().clone() + _, intent_preds = intent_preds.max(1) + + result = {'traj_hist': traj_hist, 'traj_preds': traj_preds, 'traj_labels': traj_labels, + 'intent_preds': intent_preds,'intent_preds_prob':intent_preds_prob, + 'intent_labels': intent_labels, 'pred_start_pos': pred_start_pos} + + for k, v in result.items(): + result[k] = v.cpu().detach().numpy() + + out_str = 'Evaluation Result: \n' + + num, time_step = result['traj_labels'].shape[:2] + mse = np.power(result['traj_labels'] - result['traj_preds'], 2).sum() / (num * time_step) + out_str += "trajectory_mse: %.4f, \n" % (mse) + + acc = (result['intent_labels'] == result['intent_preds']).sum() / len(result['intent_labels']) + out_str += "action_acc: %.4f, \n" % (acc) + + print(out_str) + save_path = train_params['log_dir'] + adaptor + str(adapt_step) + '_pred.pkl' + joblib.dump(result, save_path) + print('save result to', save_path) + return result + + +def main(dataset='vehicle_ngsim', model_type='rnn', adaptor='mekf',adapt_step=1): + save_dir = 'output/' + dataset + '/' + model_type + '/' + model_path = save_dir + 'model_1.pkl' + params = hyper_parameters() + params._load_parameters(save_dir + 'log/') + params.params_dict['train_param']['init_model'] = model_path + params.print_params() + test(params, adaptor=adaptor, adapt_step=adapt_step) + + +if __name__ == '__main__': + main() + diff --git a/adaptation/__init__.py b/adaptation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/adaptation/lookahead.py b/adaptation/lookahead.py new file mode 100644 index 0000000..bb98809 --- /dev/null +++ b/adaptation/lookahead.py @@ -0,0 +1,70 @@ +from collections import defaultdict +from torch.optim import Optimizer +import torch + + +class Lookahead(Optimizer): + def __init__(self, optimizer, k=5, alpha=0.5): + self.optimizer = optimizer + self.k = k + self.alpha = alpha + self.param_groups = self.optimizer.param_groups + self.state = defaultdict(dict) + self.fast_state = self.optimizer.state + for group in self.param_groups: + group["counter"] = 0 + + def update(self, group): + for fast in group["params"]: + param_state = self.state[fast] + if "slow_param" not in param_state: + param_state["slow_param"] = torch.zeros_like(fast.data) + param_state["slow_param"].copy_(fast.data) + slow = param_state["slow_param"] + slow += (fast.data - slow) * self.alpha + fast.data.copy_(slow) + + def update_lookahead(self): + for group in self.param_groups: + self.update(group) + + def step(self, closure=None): + loss = 
self.optimizer.step(closure) + for group in self.param_groups: + if group["counter"] == 0: + self.update(group) + group["counter"] += 1 + if group["counter"] >= self.k: + group["counter"] = 0 + return loss + + def state_dict(self): + fast_state_dict = self.optimizer.state_dict() + slow_state = { + (id(k) if isinstance(k, torch.Tensor) else k): v + for k, v in self.state.items() + } + fast_state = fast_state_dict["state"] + param_groups = fast_state_dict["param_groups"] + return { + "fast_state": fast_state, + "slow_state": slow_state, + "param_groups": param_groups, + } + + def load_state_dict(self, state_dict): + slow_state_dict = { + "state": state_dict["slow_state"], + "param_groups": state_dict["param_groups"], + } + fast_state_dict = { + "state": state_dict["fast_state"], + "param_groups": state_dict["param_groups"], + } + super(Lookahead, self).load_state_dict(slow_state_dict) + self.optimizer.load_state_dict(fast_state_dict) + self.fast_state = self.optimizer.state + + def add_param_group(self, param_group): + param_group["counter"] = 0 + self.optimizer.add_param_group(param_group) \ No newline at end of file diff --git a/adaptation/mekf.py b/adaptation/mekf.py new file mode 100644 index 0000000..1fff483 --- /dev/null +++ b/adaptation/mekf.py @@ -0,0 +1,340 @@ + + +import torch +from torch.optim import Optimizer + + +class MEKF_MA(Optimizer): + """ + Modified Extended Kalman Filter with generalized exponential Moving Average + """ + + def __init__(self, params, dim_out, p0=1e-2, lbd=1, sigma_r=None, sigma_q=0, lr=1, + miu_v=0, miu_p=0, k_p=1, + R_decay=False,R_decay_step=1000000): + + if sigma_r is None: + sigma_r = max(lbd,0) + self._check_format(dim_out, p0, lbd, sigma_r, sigma_q, lr,miu_v,miu_p,k_p,R_decay,R_decay_step) + defaults = dict(p0=p0, lbd=lbd, sigma_r=sigma_r, sigma_q=sigma_q, + lr=lr,miu_v=miu_v,miu_p=miu_p,k_p=k_p, + R_decay=R_decay,R_decay_step=R_decay_step) + super(MEKF_MA, self).__init__(params, defaults) + + self.state['dim_out'] = dim_out + with torch.no_grad(): + self._init_mekf_matrix() + + def _check_format(self, dim_out, p0, lbd, sigma_r, sigma_q, lr, miu_v, miu_p, k_p, R_decay, R_decay_step): + if not isinstance(dim_out, int) and dim_out > 0: + raise ValueError("Invalid output dimension: {}".format(dim_out)) + if not 0.0 < p0: + raise ValueError("Invalid initial P value: {}".format(p0)) + if not 0.0 < lbd: + raise ValueError("Invalid forgetting factor: {}".format(lbd)) + if not 0.0 < sigma_r: + raise ValueError("Invalid covariance matrix value for R: {}".format(sigma_r)) + if not 0.0 <= sigma_q: + raise ValueError("Invalid covariance matrix value for Q: {}".format(sigma_q)) + if not 0.0 < lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + + if not 0.0 <= miu_v < 1.0: + raise ValueError("Invalid EMA decaying factor for V matrix: {}".format(miu_v)) + if not 0.0 <= miu_p < 1.0: + raise ValueError("Invalid EMA decaying factor for P matrix: {}".format(miu_p)) + if not isinstance(k_p, int) and k_p >= 0: + raise ValueError("Invalid delayed step size of Lookahead P: {}".format(k_p)) + + if not isinstance(R_decay, int) and not isinstance(R_decay, bool): + raise ValueError("Invalid R decay flag: {}".format(R_decay)) + if not isinstance(R_decay_step, int): + raise ValueError("Invalid max step for R decaying: {}".format(R_decay_step)) + + def _init_mekf_matrix(self): + self.state['step']=0 + self.state['mekf_groups']=[] + dim_out = self.state['dim_out'] + for group in self.param_groups: + mekf_mat=[] + for p in group['params']: + matrix = {} + size = 
p.size() + dim_w=1 + for dim in size: + dim_w*=dim + device= p.device + matrix['P'] = group['p0']*torch.eye(dim_w,dtype=torch.float,device=device) + matrix['R'] = group['sigma_r']*torch.eye(dim_out,dtype=torch.float,device=device) + matrix['Q'] = group['sigma_q'] * torch.eye(dim_w, dtype=torch.float, device=device) + matrix['H'] = None + mekf_mat.append(matrix) + self.state['mekf_groups'].append(mekf_mat) + + def step(self,closure=None, H_groups=None, err=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + H_groups: groups of gradient matrix + err: error value + example 1 (optimization step with closure): + # optimizer -> MEKF_MA optimizer + # y -> observed value, y_hat -> predicted value + y = y.contiguous().view((-1, 1)) # shape of (dim_out,1) + y_hat = y_hat.contiguous().view((-1, 1)) # shape of (dim_out,1) + err = (y - y_hat).detach() + + def mekf_closure(index=0): + optimizer.zero_grad() + dim_out = optimizer.state['dim_out'] + retain = index < dim_out - 1 + y_hat[index].backward(retain_graph=retain) + return err + + optimizer.step(mekf_closure) + + example 2 (optimization step with H_groups): + # y -> observed value, y_hat -> predicted value + # H -> gradient matrix that need to be specified + y = y.contiguous().view((-1, 1)) # shape of (dim_out,1) + y_hat = y_hat.contiguous().view((-1, 1)) # shape of (dim_out,1) + err = (y - y_hat).detach() + optimizer.step(H_groups=H_groups,err=err) + """ + self.state['step'] += 1 + + if closure is not None: + for y_ind in range(self.state['dim_out']): + err = closure(y_ind) + for group_ind in range(len(self.param_groups)): + group = self.param_groups[group_ind] + mekf_mat = self.state['mekf_groups'][group_ind] + for ii, w in enumerate(group['params']): + if w.grad is None: + continue + H_n = mekf_mat[ii]['H'] + grad = w.grad.data.detach() + if len(w.size())>1: + grad = grad.transpose(1, 0) + grad = grad.contiguous().view((1,-1)) + if y_ind ==0: + H_n=grad + else: + H_n = torch.cat([H_n,grad],dim=0) + self.state['mekf_groups'][group_ind][ii]['H'] = H_n + else: + for group_ind in range(len(self.param_groups)): + H_mats = H_groups[group_ind] + for ii, H_n in enumerate(H_mats): + self.state['mekf_groups'][group_ind][ii]['H'] = H_n + + err_T = err.transpose(0,1) + + for group_ind in range(len(self.param_groups)): + group = self.param_groups[group_ind] + mekf_mat = self.state['mekf_groups'][group_ind] + + miu_v = group['miu_v'] + miu_p = group['miu_p'] + k_p = group['k_p'] + lr = group['lr'] + lbd = group['lbd'] + + for ii,w in enumerate(group['params']): + if w.grad is None: + continue + + P_n_1 = mekf_mat[ii]['P'] + R_n = mekf_mat[ii]['R'] + Q_n = mekf_mat[ii]['Q'] + H_n = mekf_mat[ii]['H'] + H_n_T = H_n.transpose(0, 1) + + if group['R_decay']: + miu = 1.0 / min(self.state['step'],group['R_decay_step']) + R_n = R_n + miu * (err.mm(err_T) + H_n.mm(P_n_1).mm(H_n_T) - R_n) + self.state['mekf_groups'][group_ind][ii]['R']= R_n + + g_n = H_n.mm(P_n_1).mm(H_n_T) + R_n + g_n = g_n.inverse() + K_n = P_n_1.mm(H_n_T).mm(g_n) + V_n = lr * K_n.mm(err) + if len(w.size()) > 1: + V_n = V_n.view((w.size(1),w.size(0))).transpose(1,0) + else: + V_n = V_n.view(w.size()) + if miu_v>0: + param_state = self.state[w] + if 'buffer_V' not in param_state: + V_ema = param_state['buffer_V'] = torch.clone(V_n).detach() + else: + V_ema = param_state['buffer_V'] + V_ema.mul_(miu_v).add_(V_n.mul(1-miu_v).detach()) + V_n=V_ema + w.data.add_(V_n) + + P_n = (1/lbd) * (P_n_1 - 
K_n.mm(H_n).mm(P_n_1) + Q_n) + if miu_p>0 and k_p>0: + if self.state['step'] % k_p==0: + param_state = self.state[w] + if 'buffer_P' not in param_state: + P_ema = param_state['buffer_P'] = torch.clone(P_n).detach() + else: + P_ema = param_state['buffer_P'] + P_ema.mul_(miu_p).add_(P_n.mul(1 - miu_p).detach()) + P_n = P_ema + self.state['mekf_groups'][group_ind][ii]['P'] =P_n + + return err + +class MEKF(Optimizer): + """ + Modified Extended Kalman Filter + """ + + def __init__(self, params, dim_out, p0=1e-2, lbd=1, sigma_r=None, sigma_q=0, + R_decay=False,R_decay_step=1000000): + + if sigma_r is None: + sigma_r = max(lbd,0) + self._check_format(dim_out, p0, lbd, sigma_r, sigma_q, R_decay,R_decay_step) + defaults = dict(p0=p0, lbd=lbd, sigma_r=sigma_r, sigma_q=sigma_q, + R_decay=R_decay,R_decay_step=R_decay_step) + super(MEKF, self).__init__(params, defaults) + + self.state['dim_out'] = dim_out + with torch.no_grad(): + self._init_mekf_matrix() + + def _check_format(self, dim_out, p0, lbd, sigma_r, sigma_q, R_decay, R_decay_step): + if not isinstance(dim_out, int) and dim_out > 0: + raise ValueError("Invalid output dimension: {}".format(dim_out)) + if not 0.0 < p0: + raise ValueError("Invalid initial P value: {}".format(p0)) + if not 0.0 < lbd: + raise ValueError("Invalid forgetting factor: {}".format(lbd)) + if not 0.0 < sigma_r: + raise ValueError("Invalid covariance matrix value for R: {}".format(sigma_r)) + if not 0.0 <= sigma_q: + raise ValueError("Invalid covariance matrix value for Q: {}".format(sigma_q)) + + if not isinstance(R_decay, int) and not isinstance(R_decay, bool): + raise ValueError("Invalid R decay flag: {}".format(R_decay)) + if not isinstance(R_decay_step, int): + raise ValueError("Invalid max step for R decaying: {}".format(R_decay_step)) + + def _init_mekf_matrix(self): + self.state['step']=0 + self.state['mekf_groups']=[] + dim_out = self.state['dim_out'] + for group in self.param_groups: + mekf_mat=[] + for p in group['params']: + matrix = {} + size = p.size() + dim_w=1 + for dim in size: + dim_w*=dim + device= p.device + matrix['P'] = group['p0']*torch.eye(dim_w,dtype=torch.float,device=device) + matrix['R'] = group['sigma_r']*torch.eye(dim_out,dtype=torch.float,device=device) + matrix['Q'] = group['sigma_q'] * torch.eye(dim_w, dtype=torch.float, device=device) + matrix['H'] = None + mekf_mat.append(matrix) + self.state['mekf_groups'].append(mekf_mat) + + def step(self,closure=None, H_groups=None, err=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ H_groups: groups of gradient matrix + err: error value + example 1 (optimization step with closure): + # optimizer -> MEKF_MA optimizer + # y -> observed value, y_hat -> predicted value + y = y.contiguous().view((-1, 1)) # shape of (dim_out,1) + y_hat = y_hat.contiguous().view((-1, 1)) # shape of (dim_out,1) + err = (y - y_hat).detach() + + def mekf_closure(index=0): + optimizer.zero_grad() + dim_out = optimizer.state['dim_out'] + retain = index < dim_out - 1 + y_hat[index].backward(retain_graph=retain) + return err + + optimizer.step(mekf_closure) + + example 2 (optimization step with H_groups): + # y -> observed value, y_hat -> predicted value + # H -> gradient matrix that need to be specified + y = y.contiguous().view((-1, 1)) # shape of (dim_out,1) + y_hat = y_hat.contiguous().view((-1, 1)) # shape of (dim_out,1) + err = (y - y_hat).detach() + optimizer.step(H_groups=H_groups,err=err) + """ + self.state['step'] += 1 + + if closure is not None: + for y_ind in range(self.state['dim_out']): + err = closure(y_ind) + for group_ind in range(len(self.param_groups)): + group = self.param_groups[group_ind] + mekf_mat = self.state['mekf_groups'][group_ind] + for ii, w in enumerate(group['params']): + if w.grad is None: + continue + H_n = mekf_mat[ii]['H'] + grad = w.grad.data.detach() + if len(w.size())>1: + grad = grad.transpose(1, 0) + grad = grad.contiguous().view((1,-1)) + if y_ind ==0: + H_n=grad + else: + H_n = torch.cat([H_n,grad],dim=0) + self.state['mekf_groups'][group_ind][ii]['H'] = H_n + else: + for group_ind in range(len(self.param_groups)): + H_mats = H_groups[group_ind] + for ii, H_n in enumerate(H_mats): + self.state['mekf_groups'][group_ind][ii]['H'] = H_n + + err_T = err.transpose(0,1) + + for group_ind in range(len(self.param_groups)): + group = self.param_groups[group_ind] + mekf_mat = self.state['mekf_groups'][group_ind] + lbd = group['lbd'] + + for ii,w in enumerate(group['params']): + if w.grad is None: + continue + + P_n_1 = mekf_mat[ii]['P'] + R_n = mekf_mat[ii]['R'] + Q_n = mekf_mat[ii]['Q'] + H_n = mekf_mat[ii]['H'] + H_n_T = H_n.transpose(0, 1) + + if group['R_decay']: + miu = 1.0 / min(self.state['step'],group['R_decay_step']) + R_n = R_n + miu * (err.mm(err_T) + H_n.mm(P_n_1).mm(H_n_T) - R_n) + self.state['mekf_groups'][group_ind][ii]['R']= R_n + + g_n = H_n.mm(P_n_1).mm(H_n_T) + R_n + g_n = g_n.inverse() + K_n = P_n_1.mm(H_n_T).mm(g_n) + V_n = K_n.mm(err) + if len(w.size()) > 1: + V_n = V_n.view((w.size(1),w.size(0))).transpose(1,0) + else: + V_n = V_n.view(w.size()) + w.data.add_(V_n) + + P_n = (1/lbd) * (P_n_1 - K_n.mm(H_n).mm(P_n_1) + Q_n) + self.state['mekf_groups'][group_ind][ii]['P'] =P_n + + return err \ No newline at end of file diff --git a/data/vehicle_ngsim.pkl b/data/vehicle_ngsim.pkl new file mode 100644 index 0000000..16a7ecb Binary files /dev/null and b/data/vehicle_ngsim.pkl differ diff --git a/dataset/dataset.py b/dataset/dataset.py new file mode 100644 index 0000000..8293ef9 --- /dev/null +++ b/dataset/dataset.py @@ -0,0 +1,217 @@ +import os +from collections import Counter + +import joblib +import numpy as np +import torch +from torch.utils.data.dataset import Dataset + +def data_time_split(data_list,params): + input_time_step = params['input_time_step'] + output_time_step = params['output_time_step'] + trajs = data_list['traj'] + speeds = data_list['speed'] + features = data_list['feature'] + actions = data_list['action'] + x_traj,y_traj,x_traj_len=[],[],[] + y_intent = [] + x_speed, y_speed = [],[] + x_feature, y_feature = [],[] + data_ids = 
[] + inds=np.arange(0,len(trajs)) + + for ind in inds: + traj = trajs[ind] + speed = speeds[ind] + feature = features[ind] + action = actions[ind] + begin=0 + end=input_time_step+output_time_step + steps=len(traj) + src_len = steps - output_time_step + mask_now=False + if src_len < input_time_step/4: + continue + + if steps< end: + mask_now = True + pad_traj = np.array([traj[0]*0]*(end-steps)) + traj = np.concatenate([pad_traj,traj]) + pad_speed = np.array([speed[0] * 0] * (end - steps)) + speed = np.concatenate([pad_speed, speed]) + pad_feature= np.array([feature[0] * 0] * (end - steps)) + feature = np.concatenate([pad_feature, speed]) + pad_actions= np.array([action[0] * 0] * (end - steps)) + action = np.concatenate([pad_actions, action]) + steps = len(traj) + + while end<=steps: + # input + inp_traj = traj[begin:begin+input_time_step].reshape((input_time_step, -1)) + x_traj.append(inp_traj) + data_ids.append(ind) + if mask_now: + x_traj_len.append(src_len) + else: + x_traj_len.append(len(inp_traj)) + + inp_sp= speed[begin:begin+input_time_step].reshape((input_time_step, -1)) + x_speed.append(inp_sp) + + inp_feat= feature[begin:begin+input_time_step].reshape((input_time_step, -1)) + x_feature.append(inp_feat) + + # output + out_traj = traj[begin+input_time_step:end].reshape((output_time_step, -1)) + y_traj.append(out_traj) + + out_sp= speed[begin+input_time_step:end].reshape((output_time_step, -1)) + y_speed.append(out_sp) + + out_feat= feature[begin+input_time_step:end].reshape((output_time_step, -1)) + y_feature.append(out_feat) + + y_intent.append(action[begin+input_time_step-1]) + + begin += 1 + end += 1 + + x_traj=np.array(x_traj) + x_speed = np.array(x_speed) + x_feature = np.array(x_feature) + y_traj=np.array(y_traj) + y_speed = np.array(y_speed) + y_feature = np.array(y_feature) + y_intent=np.array(y_intent) + x_traj_len = np.array(x_traj_len) + data_ids=np.array(data_ids) + pred_start_pos = x_traj[:,-1] + data ={'x_traj':x_traj,'x_speed':x_speed,'x_feature':x_feature, + 'y_traj':y_traj,'y_speed':y_speed,'y_feature':y_feature, + 'y_intent':y_intent,'pred_start_pos':pred_start_pos, + 'x_traj_len':x_traj_len,'data_ids':data_ids} + return data + +def normalize_data(data, data_stats): + new_data={} + for k,v in data.items(): + if k in ['x_traj','x_speed','y_traj','y_speed','x_feature','y_feature']: + mark = k.split('_')[-1] + data_mean,data_std=data_stats[mark+'_mean'],data_stats[mark+'_std'] + new_data[k] = (v-data_mean)/data_std + else: + new_data[k] = v + return new_data + +class Trajectory_Data(Dataset): + def __init__(self, params, mode='train',data_stats={}): + self.mode = mode + print(mode,'data preprocessing') + cache_dir = params['log_dir']+mode+'.cache' + if os.path.exists(cache_dir): + print('loading data from cache',cache_dir) + self.data = joblib.load(cache_dir) + else: + raw_data = joblib.load(params['data_path'])[mode] + self.data = data_time_split(raw_data,params) + + if mode=='train': + data_stats['traj_mean'] = np.mean(self.data['x_traj'],axis=(0,1)) + data_stats['traj_std'] = np.std(self.data['x_traj'], axis=(0, 1)) + data_stats['speed_mean'] = np.mean(self.data['x_speed'],axis=(0,1)) + data_stats['speed_std'] = np.std(self.data['x_speed'], axis=(0, 1)) + data_stats['feature_mean'] = np.mean(self.data['x_feature'],axis=(0,1)) + data_stats['feature_std'] = np.std(self.data['x_feature'], axis=(0, 1)) + self.data['data_stats'] = data_stats + if params['normalize_data']: + if mode=='train': + print('data statistics:') + print(data_stats) + self.data = 
normalize_data(self.data, data_stats) + joblib.dump(self.data,cache_dir) + + enc_inp= None + for feat in params['inp_feat']: + dat = self.data['x_'+feat] + if enc_inp is None: + enc_inp = dat + else: + enc_inp = np.concatenate([enc_inp,dat],axis=-1) + + + self.data['x_encoder'] = enc_inp + self.data['y_decoder'] = self.data['y_speed'] + self.data['start_decode'] = self.data['x_speed'][:,-1] + + + self.input_time_step = params['input_time_step'] + self.input_feat_dim = self.data['x_encoder'].shape[2] + + print(mode + '_data size:', len(self.data['x_encoder'])) + print('each category counts:') + print(Counter(self.data['y_intent'])) + + def __getitem__(self, index): + x = self.data['x_encoder'][index] + y_traj = self.data['y_decoder'][index] + y_inten = self.data['y_intent'][index] + start_decode = self.data['start_decode'][index] # this is for incremental decoding + pred_start_pos = self.data['pred_start_pos'][index] + x_len = self.data['x_traj_len'][index] + x_mask = np.zeros(shape=x.shape[0],dtype=np.int) + bias = self.input_time_step - x_len + # left pad + if bias>0: + x_mask[:bias] = 1 + x[:bias] = 0 + + return (x, y_traj, y_inten, start_decode, pred_start_pos, x_mask) + + def __len__(self): + return len(self.data['x_encoder']) + +def get_data_loader(params, mode='train',pin_memory=False): + if mode == 'train': + train_data = Trajectory_Data(params, mode='train') + data_stats = train_data.data['data_stats'] + train_loader = torch.utils.data.DataLoader( + train_data, batch_size=params['batch_size'], shuffle=True, + drop_last=True, pin_memory=pin_memory) + + valid_data = Trajectory_Data(params, mode='valid',data_stats=data_stats) + valid_loader = torch.utils.data.DataLoader( + valid_data, batch_size=params['batch_size'], shuffle=False, + drop_last=False, pin_memory=pin_memory) + + test_data = Trajectory_Data(params, mode='test',data_stats=data_stats) + test_loader = torch.utils.data.DataLoader( + test_data, batch_size=params['batch_size'], shuffle=False, + drop_last=False, pin_memory=pin_memory) + + for k, v in data_stats.items(): + data_stats[k] = [float(x) for x in v] + params['data_stats'] = data_stats + params['print_step'] = max(1,len(train_loader) // 10) + return train_loader, valid_loader, test_loader, params + + elif mode == 'test': + data_stats = params['data_stats'] + for k,v in data_stats.items(): + data_stats[k] = np.array(v) + test_data = Trajectory_Data(params, mode='test',data_stats=data_stats) + test_loader = torch.utils.data.DataLoader( + test_data, batch_size=params['batch_size'], shuffle=False, + num_workers=1,drop_last=False, pin_memory=pin_memory) + return test_loader + elif mode == 'valid': + data_stats = params['data_stats'] + for k,v in data_stats.items(): + data_stats[k] = np.array(v) + test_data = Trajectory_Data(params, mode='valid',data_stats=data_stats) + test_loader = torch.utils.data.DataLoader( + test_data, batch_size=params['batch_size'], shuffle=False, + num_workers=1,drop_last=False, pin_memory=pin_memory) + return test_loader + else: + return None + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/fc_model.py b/models/fc_model.py new file mode 100644 index 0000000..b12bc18 --- /dev/null +++ b/models/fc_model.py @@ -0,0 +1,102 @@ +import torch +import torch.nn as nn + + +class FCEncoder(nn.Module): + """Fully connected encoder.""" + + def __init__(self, feat_dim, max_seq_len, hidden_dims, dropout=0.0, act_fn=nn.ReLU, **kwargs): + super(FCEncoder, self).__init__() + self.feat_dim = 
feat_dim + + self.layers = nn.ModuleList([]) + input_dim = feat_dim * max_seq_len + for hd in hidden_dims: + self.layers.append(Linear(input_dim, hd)) + self.layers.append(nn.Dropout(dropout)) + self.layers.append(act_fn()) + input_dim = hd + + self.embed_inputs = None + self.output_units = hidden_dims[-1] + + def forward(self, src_seq): + batch_size = src_seq.size(0) + x = src_seq.view(batch_size, -1) + out_stack = [] + for ii, layer in enumerate(self.layers): + x = layer(x) + if (ii + 1) % 3 == 0: # fc, drop out, relu + out_stack.append(x) + out_stack = torch.cat(out_stack, dim=1) + return x, out_stack + + +class FCDecoder(nn.Module): + """FC decoder.""" + + def __init__(self, feat_dim, max_seq_len, hidden_dims, dropout=0.0, act_fn=nn.ReLU, + encoder_output_units=10, traj_attn_intent_dim=0, **kwargs): + super(FCDecoder, self).__init__() + self.feat_dim = feat_dim + self.max_seq_len = max_seq_len + self.traj_attn_intent_dim = traj_attn_intent_dim + + self.layers = nn.ModuleList([]) + input_dim = encoder_output_units + for hd in hidden_dims: + self.layers.append(Linear(input_dim, hd)) + self.layers.append(nn.Dropout(dropout)) + self.layers.append(act_fn()) + input_dim = hd + + self.output_projection = Linear(input_dim, max_seq_len * feat_dim) + + if traj_attn_intent_dim > 0: + self.traj_attn_fc = Linear(input_dim, traj_attn_intent_dim) + + def forward(self, encoder_outs): + x = encoder_outs + for layer in self.layers: + x = layer(x) + hidden_out = x + if self.traj_attn_intent_dim > 0: + hidden_out = self.traj_attn_fc(hidden_out) + x = self.output_projection(x) + x = x.view(-1, self.max_seq_len, self.feat_dim) + return x, hidden_out + + +class FCClassifier(nn.Module): + """FC classifier.""" + + def __init__(self, encoder_output_units, hidden_dims, dropout=0.0, act_fn=nn.ReLU, + num_class=11, traj_attn_intent_dim=0, **kwargs): + super(FCClassifier, self).__init__() + + self.layers = nn.ModuleList([]) + input_dim = encoder_output_units + traj_attn_intent_dim + for hd in hidden_dims: + self.layers.append(Linear(input_dim, hd)) + self.layers.append(nn.Dropout(dropout)) + self.layers.append(act_fn()) + input_dim = hd + + self.output_projection = Linear(input_dim, num_class, bias=False) + + def forward(self, encoder_outs): + x = encoder_outs + for layer in self.layers: + x = layer(x) + hidden=x + x = self.output_projection(x) + return x,hidden + + +def Linear(in_features, out_features, bias=True): + """Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + m.weight.data.uniform_(-0.1, 0.1) + if bias: + m.bias.data.uniform_(-0.1, 0.1) + return m diff --git a/models/model_factory.py b/models/model_factory.py new file mode 100644 index 0000000..5dc76bc --- /dev/null +++ b/models/model_factory.py @@ -0,0 +1,106 @@ +import torch +from torch import nn + +from .fc_model import FCEncoder, FCDecoder, FCClassifier +from .pooling_layer import AvgPooling, LastPooling, LinearSeqAttnPooling, NoPooling +from .rnn_model import RNNEncoder, RNNDecoder + + +class MultiTask_Model(nn.Module): + def __init__(self, encoder_type, decoder_type,pool_type, params): + super(MultiTask_Model, self).__init__() + + self.encoder_type = encoder_type + self.pool_type = pool_type + self.decoder_type = decoder_type + self.params = params + self.train_param=self.params.train_param() + self.traj_attn_intent_dim = self.train_param['traj_attn_intent_dim'] + + self.encoder = self._create_encoder(self.encoder_type) + self.enc_out_units = self.encoder.output_units + + self.decoder = 
self._create_decoder(self.decoder_type) + + self.clf_pool = self._create_pooling(self.pool_type) + self.classifier = self._create_decoder(decoder_type='classifier') + + if self.traj_attn_intent_dim>0: + self.attn_pool = self._create_pooling(self.pool_type,input_size=self.traj_attn_intent_dim) + + + def _create_encoder(self, encoder_type): + # create encoder + if encoder_type == 'rnn': + rnn_param = self.params.encode_rnn_param() + encoder = RNNEncoder(**rnn_param) + else: + fc_param = self.params.encode_fc_param() + encoder = FCEncoder(**fc_param) + return encoder + + def _create_pooling(self, pool_type,input_size=None): + if input_size is None: + input_size=self.enc_out_units + if pool_type == 'mean' or pool_type == 'avg': + pool = AvgPooling() + elif pool_type == 'last': + pool = LastPooling() + elif pool_type == 'linear_attn': + pool = LinearSeqAttnPooling(input_size=input_size) + else: + pool = NoPooling() + return pool + + def _create_decoder(self, decoder_type): + if decoder_type == 'rnn': + rnn_params = self.params.decode_rnn_param() + decoder = RNNDecoder(encoder_output_units=self.enc_out_units,traj_attn_intent_dim=self.traj_attn_intent_dim, + **rnn_params) + elif decoder_type == 'classifier': + clf_params = self.params.classifier_fc_param() + decoder = FCClassifier(encoder_output_units=self.enc_out_units,traj_attn_intent_dim=self.traj_attn_intent_dim, + **clf_params) + else: + fc_param = self.params.decode_fc_param() + decoder = FCDecoder(encoder_output_units=self.enc_out_units,traj_attn_intent_dim=self.traj_attn_intent_dim, + **fc_param) + + return decoder + + def forward(self, src_seq,start_decode=None,encoder_mask=None): + + enc = self.encoder(src_seq) + encoder_out, encoder_state= enc + + if self.decoder_type == 'rnn': + out_traj, hidden_out_traj = self.decoder(enc, start_decode,encoder_mask=encoder_mask) + else: + out_traj,hidden_out_traj = self.decoder(encoder_out) + + + clf_inp = self.clf_pool(encoder_out,x_mask=encoder_mask) + if self.traj_attn_intent_dim>0: + hidden_out_traj = self.attn_pool(hidden_out_traj) + clf_inp = torch.cat([clf_inp, hidden_out_traj], dim=1) + out_intent,_ = self.classifier(clf_inp) + + return out_traj, out_intent + +def create_model(params): + train_params = params.train_param() + if train_params['init_model'] is not None: + model = torch.load(train_params['init_model']) + print('load model', train_params['init_model']) + else: + model = MultiTask_Model( + encoder_type=train_params['encoder'], + pool_type=train_params['pool_type'], + decoder_type=train_params['decoder'], + params=params) + + param_num = sum([p.data.nelement() for p in model.parameters()]) + print("Number of model parameters: {} M".format(param_num / 1024. 
/ 1024.)) + model.train() + + return model diff --git a/models/pooling_layer.py b/models/pooling_layer.py new file mode 100644 index 0000000..f9bad41 --- /dev/null +++ b/models/pooling_layer.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MaxPooling(nn.Module): + def forward(self, x, x_mask=None): + if x_mask is None or x_mask.data.sum() == 0: + return torch.max(x, 1)[0] + else: + lengths = (1 - x_mask).sum(1) + return torch.cat([torch.max(i[:l], dim=0)[0].view(1, -1) for i, l in zip(x, lengths)], dim=0) + +class AvgPooling(nn.Module): + def forward(self, x, x_mask=None): + if x_mask is None or x_mask.data.sum() == 0: + return torch.mean(x, 1) + else: + lengths = (1 - x_mask).sum(1) + return torch.cat([torch.mean(i[:l], dim=0)[0].view(1, -1) for i, l in zip(x, lengths)], dim=0) + +class LastPooling(nn.Module): + def __init__(self): + super(LastPooling, self).__init__() + + def forward(self, x, x_mask=None): + if x_mask is None or x_mask.data.sum() == 0: + return x[:, -1, :] + else: + lengths = (1 - x_mask).sum(1) + return torch.cat([i[l - 1, :] for i, l in zip(x, lengths)], dim=0).view(x.size(0), -1) + +class LinearSeqAttnPooling(nn.Module): + """Self attention over a sequence: + + * o_i = softmax(Wx_i) for x_i in X. + """ + + def __init__(self, input_size,bias=False): + super(LinearSeqAttnPooling, self).__init__() + self.linear = nn.Linear(input_size, 1,bias=bias) + + def forward(self, x, x_mask=None): + """ + Args: + x: batch * len * hdim + x_mask: batch * len (1 for padding, 0 for true) + Output: + alpha: batch * len + """ + # TODO why need contiguous + x = x.contiguous() + x_flat = x.view(-1, x.size(-1)) + scores = self.linear(x_flat).view(x.size(0), x.size(1)) + if x_mask is not None: + scores.data.masked_fill_(x_mask.data, -float('inf')) + alpha = F.softmax(scores, dim=-1) + self.alpha = alpha + return alpha.unsqueeze(1).bmm(x).squeeze(1) + +class NoPooling(nn.Module): + # placeholder for identity mapping + def forward(self, x, x_mask=None): + return x + diff --git a/models/rnn_model.py b/models/rnn_model.py new file mode 100644 index 0000000..77bf71c --- /dev/null +++ b/models/rnn_model.py @@ -0,0 +1,268 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class RNNEncoder(nn.Module): + """RNN encoder.""" + + def __init__( + self, cell_type='lstm', feat_dim=3, hidden_size=128, num_layers=1, + dropout_fc=0.1, dropout_rnn=0.0, bidirectional=False,**kwargs): + super(RNNEncoder, self).__init__() + self.cell_type = cell_type.lower() + self.feat_dim = feat_dim + self.num_layers = num_layers + self.dropout_fc = dropout_fc + self.dropout_rnn = dropout_rnn + self.bidirectional = bidirectional + self.hidden_size = hidden_size + input_size = feat_dim + + if self.cell_type == 'lstm': + self.rnn = LSTM(input_size=input_size, hidden_size=hidden_size, + num_layers=num_layers, bidirectional=bidirectional, + dropout=self.dropout_rnn) + elif self.cell_type == 'gru': + self.rnn = GRU(input_size=input_size, hidden_size=hidden_size, + num_layers=num_layers, bidirectional=bidirectional, + dropout=self.dropout_rnn) + else: + self.rnn = RNN(input_size=input_size, hidden_size=hidden_size, + num_layers=num_layers, bidirectional=bidirectional, + dropout=self.dropout_rnn) + + self.output_units = hidden_size + if bidirectional: + self.output_units *= 2 + + def forward(self, src_seq): + x = src_seq + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + encoder_out, h_t = self.rnn(x) + + if self.cell_type == 'lstm': + final_hiddens, final_cells = h_t + else: + final_hiddens, final_cells = h_t, None + + if self.dropout_fc>0: + encoder_out = F.dropout(encoder_out, p=self.dropout_fc, training=self.training) + + if self.bidirectional: + batch_size = src_seq.size(0) + + def combine_bidir(outs): + out = outs.view(self.num_layers, 2, batch_size, -1).transpose(1, 2).contiguous() + return out.view(self.num_layers, batch_size, -1) + + final_hiddens = combine_bidir(final_hiddens) + if self.cell_type == 'lstm': + final_cells = combine_bidir(final_cells) + + # T x B x C -> B x T x C + encoder_out = encoder_out.transpose(0, 1) + final_hiddens = final_hiddens.transpose(0, 1) + if self.cell_type == 'lstm': + final_cells = final_cells.transpose(0, 1) + + return encoder_out, (final_hiddens, final_cells) + + +class AttentionLayer(nn.Module): + def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False): + super(AttentionLayer, self).__init__() + + self.input_proj = Linear(input_embed_dim, source_embed_dim, bias=bias) + self.output_proj = Linear(input_embed_dim + source_embed_dim, output_embed_dim, bias=bias) + + def forward(self, input, source_hids, encoder_padding_mask=None): + # input: bsz x input_embed_dim + # source_hids: srclen x bsz x output_embed_dim + + # x: bsz x output_embed_dim + x = self.input_proj(input) + + # compute attention + attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2) + + # don't attend over padding + if encoder_padding_mask is not None: + attn_scores = attn_scores.float().masked_fill_( + encoder_padding_mask, + float('-inf') + ).type_as(attn_scores) # FP16 support: cast to float and back + + attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz + + # sum weighted sources + x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0) + out = torch.cat((x, input), dim=1) + x = F.tanh(self.output_proj(out)) + return x, attn_scores + + +class RNNDecoder(nn.Module): + """RNN decoder.""" + + def __init__( + self, cell_type='lstm', feat_dim=3, hidden_size=128, num_layers=1, + dropout_fc=0.1, dropout_rnn=0.0, encoder_output_units=128, max_seq_len=10, + attention=True, traj_attn_intent_dim=0,**kwargs): + super(RNNDecoder, self).__init__() + self.cell_type = 
cell_type.lower() + self.dropout_fc = dropout_fc + self.dropout_rnn = dropout_rnn + self.hidden_size = hidden_size + self.encoder_output_units = encoder_output_units + self.num_layers = num_layers + self.max_seq_len = max_seq_len + self.feat_dim = feat_dim + self.traj_attn_intent_dim =traj_attn_intent_dim + input_size = feat_dim + + if encoder_output_units != hidden_size: + self.encoder_hidden_proj = Linear(encoder_output_units, hidden_size) + if self.cell_type == 'lstm': + self.encoder_cell_proj = Linear(encoder_output_units, hidden_size) + else: + self.encoder_cell_proj = None + else: + self.encoder_hidden_proj = self.encoder_cell_proj = None + + if self.cell_type == 'lstm': + self.cell = LSTM + elif self.cell_type == 'gru': + self.cell = GRU + else: + self.cell = RNN + + self.rnn = self.cell(input_size=input_size, + hidden_size=hidden_size, bidirectional=False, + dropout=self.dropout_rnn, + num_layers=num_layers) + + if attention: + self.attention = AttentionLayer(hidden_size, encoder_output_units, hidden_size, bias=False) + else: + self.attention = None + + self.output_projection = Linear(hidden_size, feat_dim) + + if traj_attn_intent_dim>0: + self.traj_attn_fc = Linear(hidden_size, traj_attn_intent_dim) + + def forward(self, encoder_out_list, start_decode=None, encoder_mask=None): + + x = start_decode.unsqueeze(1) + bsz = x.size(0) + + # get outputs from encoder + encoder_outs, (encoder_hiddens, encoder_cells) = encoder_out_list + # B x T x C -> T x B x C + encoder_outs = encoder_outs.transpose(0, 1) + encoder_hiddens = encoder_hiddens.transpose(0, 1) + if encoder_mask is not None: + encoder_mask = encoder_mask.transpose(0, 1) + prev_hiddens = [encoder_hiddens[i] for i in range(self.num_layers)] + if self.cell_type == 'lstm': + encoder_cells = encoder_cells.transpose(0, 1) + prev_cells = [encoder_cells[i] for i in range(self.num_layers)] + + x = x.transpose(0, 1) + srclen = encoder_outs.size(0) + + # initialize previous states + + if self.encoder_hidden_proj is not None: + prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens] + prev_hiddens = torch.stack(prev_hiddens, dim=0) + if self.encoder_cell_proj is not None: + prev_cells = [self.encoder_cell_proj(x) for x in prev_cells] + if self.cell_type == 'lstm': + prev_cells = torch.stack(prev_cells, dim=0) + + attn_scores = x.new_zeros(srclen, self.max_seq_len, bsz) + inp = x + outs = [] + hidden_outs=[] + for j in range(self.max_seq_len): + if self.cell_type == 'lstm': + output, (prev_hiddens, prev_cells) = self.rnn(inp, (prev_hiddens, prev_cells)) + else: + output, prev_hiddens = self.rnn(inp, prev_hiddens) + output = output.view(bsz, -1) + # apply attention using the last layer's hidden state + if self.attention is not None: + out, attn_scores[:, j, :] = self.attention(output, encoder_outs, encoder_mask) + else: + out = output + if self.dropout_fc>0: + out = F.dropout(out, p=self.dropout_fc, training=self.training) + hid_out = out + if self.traj_attn_intent_dim > 0: + hid_out= self.traj_attn_fc(hid_out) + hid_out = F.selu(hid_out) + hidden_outs.append(hid_out) + + out = self.output_projection(out) + # save final output + outs.append(out) + + inp = out.unsqueeze(0) + + # collect outputs across time steps + x = torch.cat(outs, dim=0).view(self.max_seq_len, bsz, self.feat_dim) + hidden_outs = torch.cat(hidden_outs, dim=0).view(self.max_seq_len, bsz, -1) + # T x B x C -> B x T x C + x = x.transpose(1, 0) + hidden_outs=hidden_outs.transpose(1, 0) + # srclen x tgtlen x bsz -> bsz x tgtlen x srclen + attn_scores = 
attn_scores.transpose(0, 2) + + # project back to input space + + return x, hidden_outs + + +def LSTM(input_size, hidden_size, **kwargs): + m = nn.LSTM(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if 'weight' in name or 'bias' in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def GRU(input_size, hidden_size, **kwargs): + m = nn.GRU(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if 'weight' in name or 'bias' in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def RNN(input_size, hidden_size, **kwargs): + m = nn.RNN(input_size, hidden_size, **kwargs) + for name, param in m.named_parameters(): + if 'weight' in name or 'bias' in name: + param.data.uniform_(-0.1, 0.1) + return m + + +def Linear(in_features, out_features, bias=True): + """Linear layer (input: N x T x C)""" + m = nn.Linear(in_features, out_features, bias=bias) + m.weight.data.uniform_(-0.1, 0.1) + if bias: + m.bias.data.uniform_(-0.1, 0.1) + return m + diff --git a/parameters.py b/parameters.py new file mode 100644 index 0000000..b955096 --- /dev/null +++ b/parameters.py @@ -0,0 +1,369 @@ +import json +import os + +class hyper_parameters(object): + def __init__(self, input_time_step=20, output_time_step=50, + data_dir='data/',dataset='vehicle_ngsim',model_type='rnn', + num_class=3,coordinate_dim=2,inp_feat=('traj', 'speed')): #32 + + self.input_time_step = input_time_step + self.output_time_step = output_time_step + self.coordinate_dim = coordinate_dim + self.inten_num_class = num_class + + self.inp_feat = inp_feat + self.out_feat = ('speed',) + self.encoder_feat_dim = self.coordinate_dim * len(self.inp_feat) # x, v + self.decoder_feat_dim = self.coordinate_dim * 1 # v + + self.data_dir = data_dir + self.dataset = dataset + self.model_type = model_type + self.params_dict = {} + + + def _set_default_dataset_params(self): + if self.dataset=='human_kinect': + self.inp_feat = ('traj', 'speed') + self.input_time_step = 20 + self.output_time_step = 10 + self.inten_num_class = 12 + self.coordinate_dim = 3 + self.encoder_feat_dim = self.coordinate_dim * 2 # x,v + self.decoder_feat_dim =self.coordinate_dim + + if self.dataset=='human_mocap': + self.inp_feat = ('traj', 'speed') + self.input_time_step = 20 + self.output_time_step = 10 + self.inten_num_class = 3 + self.coordinate_dim = 3 + self.encoder_feat_dim = self.coordinate_dim * 2 + self.decoder_feat_dim =self.coordinate_dim + + if self.dataset=='vehicle_holomatic': + self.inp_feat = ('feature', 'speed') + self.input_time_step = 20 + self.output_time_step = 50 + self.inten_num_class = 5 + self.coordinate_dim = 2 + self.traj_feature_dim = 8 + self.encoder_feat_dim =self.coordinate_dim + self.traj_feature_dim + self.decoder_feat_dim = self.coordinate_dim + + if self.dataset=='vehicle_ngsim': + self.inp_feat = ('feature', 'speed') + self.input_time_step = 20 + self.output_time_step = 50 + self.inten_num_class = 3 + self.coordinate_dim = 2 + self.traj_feature_dim = 4 + self.encoder_feat_dim =self.coordinate_dim + self.traj_feature_dim + self.decoder_feat_dim = self.coordinate_dim + + def train_param(self, param_dict=None): + default_train_params = dict( + dataset=self.dataset, + data_path=self.data_dir + self.dataset + '.pkl', + save_dir='output/'+self.dataset+'/'+self.model_type +'/', + init_model=None, + normalize_data=True, + input_time_step=self.input_time_step, + output_time_step=self.output_time_step, + inp_feat = self.inp_feat, + + traj_intent_loss_ratio=[1, 0.1], # traj loss : intent loss + lr=0.01, + 
lr_schedule='multistep', # multistep + lr_decay_epochs=[7, 14], + lr_decay=0.1, + epochs=20, + batch_size=128, + + coordinate_dim=self.coordinate_dim, + encoder=self.model_type, + encoder_feat_dim=self.encoder_feat_dim, + + decoder=self.model_type, + decoder_feat_dim = self.decoder_feat_dim, + + + class_num=self.inten_num_class, + pool_type='linear_attn', + label_smooth=0.1, + traj_attn_intent_dim=64, + ) + if param_dict is None and 'train_param' in self.params_dict: + param_dict = self.params_dict['train_param'] + params = self._overwrite_params(default_train_params, param_dict) + params['log_dir'] = params['save_dir'] + 'log/' + dir_split = params['log_dir'].replace('\\','/').split('/') + base_dir='' + for _path in dir_split: + base_dir =os.path.join(base_dir,_path) + if not os.path.exists(base_dir): + os.mkdir(base_dir) + + if params['encoder'] == 'fc': + params['pool_type'] = 'none' + + return params + + def encode_rnn_param(self, param_dict=None): + default_rnn_params = dict( + cell_type='gru', + feat_dim=self.encoder_feat_dim, + max_seq_len=self.input_time_step, + hidden_size=64, + num_layers=1, + dropout_fc=0., + dropout_rnn=0., + bidirectional=False, + ) + if param_dict is None and 'encode_rnn_param' in self.params_dict: + param_dict = self.params_dict['encode_rnn_param'] + param = self._overwrite_params(default_rnn_params, param_dict) + return param + + def encode_fc_param(self, param_dict=None): + default_fc_params = dict( + feat_dim=self.encoder_feat_dim, + max_seq_len=self.input_time_step, + hidden_dims=[64,64],#[128,128],#[128,128,64] + dropout=0., + ) + if param_dict is None and 'encode_fc_param' in self.params_dict: + param_dict = self.params_dict['encode_fc_param'] + param = self._overwrite_params(default_fc_params, param_dict) + return param + + def decode_rnn_param(self, param_dict=None): + default_rnn_params = dict( + cell_type='gru', + feat_dim = self.decoder_feat_dim, + max_seq_len=self.output_time_step, + hidden_size=64, + num_layers=1, + dropout_fc=0., + dropout_rnn=0., + attention=True, + ) + if param_dict is None and 'decode_rnn_param' in self.params_dict: + param_dict = self.params_dict['decode_rnn_param'] + param = self._overwrite_params(default_rnn_params, param_dict) + return param + + def decode_fc_param(self, param_dict=None): + default_fc_params = dict( + feat_dim=self.decoder_feat_dim, + max_seq_len=self.output_time_step, + hidden_dims=[64,64], + dropout=0., + ) + if param_dict is None and 'decode_fc_param' in self.params_dict: + param_dict = self.params_dict['decode_fc_param'] + param = self._overwrite_params(default_fc_params, param_dict) + return param + + def classifier_fc_param(self, param_dict=None): + default_fc_params = dict( + hidden_dims=[64], + dropout=0., + num_class=self.inten_num_class, + ) + if param_dict is None and 'classifier_fc_param' in self.params_dict: + param_dict = self.params_dict['classifier_fc_param'] + param = self._overwrite_params(default_fc_params, param_dict) + return param + + def print_params(self): + print('train parameters:') + t_param = self.train_param() + print(t_param) + print('encode_param:') + encode_param = self.encode_fc_param() if t_param['encoder'] == 'fc' else self.encode_rnn_param() + print(encode_param) + print('decode_param:') + decode_param = self.decode_fc_param() if t_param['decoder'] == 'fc' else self.decode_rnn_param() + print(decode_param) + print('classifier_fc_param:') + print(self.classifier_fc_param()) + + def _overwrite_params(self, old_param, new_param): + if new_param is None: + return old_param 
+ for k, v in new_param.items(): + old_param[k] = v + return old_param + + def _save_parameters(self, log_dir=None): + params_dict = {} + params_dict['train_param'] = self.train_param() + params_dict['encode_rnn_param'] = self.encode_rnn_param() + params_dict['encode_fc_param'] = self.encode_fc_param() + params_dict['decode_rnn_param'] = self.decode_rnn_param() + params_dict['decode_fc_param'] = self.decode_fc_param() + params_dict['classifier_fc_param'] = self.classifier_fc_param() + + if log_dir is None: + log_dir = params_dict['train_param']['log_dir'] + + with open(log_dir + 'hyper_parameters.json', 'w') as f: + json.dump(params_dict, f) + + def _save_overwrite_parameters(self, params_key, params_value, log_dir=None): + params_dict = {} + params_dict['train_param'] = self.train_param() + params_dict['encode_rnn_param'] = self.encode_rnn_param() + params_dict['encode_fc_param'] = self.encode_fc_param() + params_dict['decode_rnn_param'] = self.decode_rnn_param() + params_dict['decode_fc_param'] = self.decode_fc_param() + params_dict['classifier_fc_param'] = self.classifier_fc_param() + + params_dict[params_key] = params_value + + if log_dir is None: + log_dir = params_dict['train_param']['log_dir'] + + with open(log_dir + 'hyper_parameters.json', 'w') as f: + json.dump(params_dict, f) + + def _load_parameters(self, log_dir=None): + if log_dir is None: + log_dir = self.train_param()['log_dir'] + + with open(log_dir + 'hyper_parameters.json', 'r') as f: + self.params_dict = json.load(f) + + +class adapt_hyper_parameters(object): + def __init__(self, adaptor='none',adapt_step=1,log_dir=None): + self.adaptor = adaptor + self.adapt_step = adapt_step + self.log_dir = log_dir + self.params_dict = {} + + adaptor=adaptor.lower() + if adaptor=='nrls' or adaptor=='mekf' or adaptor=='mekf_ma': + self.adapt_param=self.mekf_param + elif adaptor=='sgd': + self.adapt_param=self.sgd_param + elif adaptor=='adam': + self.adapt_param=self.adam_param + elif adaptor=='lbfgs': + self.adapt_param=self.lbfgs_param + + def strategy_param(self,param_dict=None): + default_params = dict( + adapt_step=self.adapt_step, + use_multi_epoch=True, + multiepoch_thresh=(-1, -1), + ) + if param_dict is None and 'strategy_param' in self.params_dict: + param_dict = self.params_dict['strategy_param'] + params = self._overwrite_params(default_params, param_dict) + return params + + def mekf_param(self, param_dict=None): + default_params = dict( + p0=1e-2, # 1e-2 + lbd=1-1e-6, # 1 + sigma_r=1, + sigma_q=0, + lr=1, # 1 + + miu_v=0, #momentum + miu_p=0, # EMA of P + k_p=1, #look ahead of P + + use_lookahead=False, # outer lookahead + la_k=1, # outer lookahead + la_alpha=1, + ) + if param_dict is None and 'mekf_param' in self.params_dict: + param_dict = self.params_dict['mekf_param'] + params = self._overwrite_params(default_params, param_dict) + return params + + def sgd_param(self, param_dict=None): + default_params = dict( + lr=1e-6, + momentum=0.7, + nesterov=False, + + use_lookahead=False, #look ahead + la_k=5, + la_alpha = 0.8, + ) + if param_dict is None and 'sgd_param' in self.params_dict: + param_dict = self.params_dict['sgd_param'] + param = self._overwrite_params(default_params, param_dict) + return param + + def adam_param(self, param_dict=None): + default_params = dict( + lr=1e-6, + betas=(0.1, 0.99), + amsgrad=True, + + use_lookahead=False, #look ahead + la_k=5, + la_alpha=0.8, + ) + if param_dict is None and 'adam_param' in self.params_dict: + param_dict = self.params_dict['adam_param'] + param = 
self._overwrite_params(default_params, param_dict) + return param + + def lbfgs_param(self, param_dict=None): + default_params = dict( + lr=0.002, + max_iter=20, + history_size=100, + + use_lookahead=False, #look ahead + la_k=1, + la_alpha=1.0, + ) + if param_dict is None and 'lbfgs_param' in self.params_dict: + param_dict = self.params_dict['lbfgs_param'] + param = self._overwrite_params(default_params, param_dict) + return param + + def print_params(self): + print('adaptation optimizer:',self.adaptor) + print('adaptation strategy parameters:') + print(self.strategy_param()) + if self.adaptor=='none': + print('no adaptation') + else: + print('adaptation optimizer parameters:') + print(self.adapt_param()) + + def _overwrite_params(self, old_param, new_param): + if new_param is None: + return old_param + for k, v in new_param.items(): + old_param[k] = v + return old_param + + def _save_parameters(self, log_dir=None): + params_dict = {} + params_dict['strategy_param'] = self.strategy_param() + params_dict['mekf_param'] = self.mekf_param() + params_dict['sgd_param'] = self.sgd_param() + params_dict['adam_param'] = self.adam_param() + params_dict['lbfgs_param'] = self.lbfgs_param() + + if log_dir is None: + log_dir = self.log_dir + + with open(log_dir + 'adapt_hyper_parameters.json', 'w') as f: + json.dump(params_dict, f) + + def _load_parameters(self, log_dir=None): + if log_dir is None: + log_dir = self.log_dir + + with open(log_dir + 'adapt_hyper_parameters.json', 'r') as f: + self.params_dict = json.load(f) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..99f1bf9 --- /dev/null +++ b/train.py @@ -0,0 +1,168 @@ +# coding=utf-8 + +import os +import warnings + +import torch +from parameters import hyper_parameters +from dataset.dataset import get_data_loader +from models.model_factory import create_model +from utils.pred_utils import get_prediction_on_batch, get_predictions,get_position +from utils.train_utils import CrossEntropyLoss, get_lr_schedule + +os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' +os.environ['CUDA_VISIBLE_DEVICES'] = "0" +warnings.filterwarnings("ignore") + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print('training with device:', device) + + +def evaluate(model, data_loader, criterion_traj, criterion_intend, params, epoch=0, mark='valid'): + data_stats = {'data_mean': params['data_mean'], 'data_std': params['data_std']} + print('[Evaluation %s Set] -------------------------------' % mark) + out_str = "epoch: %d, " % (epoch) + with torch.no_grad(): + dat = get_predictions(data_loader, model, device) + traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos = dat + + loss_traj = criterion_traj(traj_preds, traj_labels) + loss_traj = loss_traj.cpu().detach().numpy() + traj_preds = get_position(traj_preds, pred_start_pos, data_stats) + traj_labels = get_position(traj_labels, pred_start_pos, data_stats) + mse = (traj_preds - traj_labels).pow(2).sum().float() / (traj_preds.size(0) * traj_preds.size(1)) + mse = mse.cpu().detach().numpy() + out_str += "trajectory_loss: %.4f, trajectory_mse: %.4f, " % (loss_traj, mse) + + loss_intent = criterion_intend(intent_preds, intent_labels) + loss_intent = loss_intent.cpu().detach().numpy() + _, pred_intent_cls = intent_preds.max(1) + label_cls = intent_labels + acc = (pred_intent_cls == label_cls).sum().float() / label_cls.size(0) + acc = acc.cpu().detach().numpy() + out_str += "intent_loss: %.4f, intent_acc: %.4f, " % (loss_intent, acc) + + 
+    print(out_str)
+    print('-------------------------------')
+
+    log_dir = params['log_dir']
+    if not os.path.exists(log_dir + '%s.tsv' % mark):
+        with open(log_dir + '%s.tsv' % mark, 'a') as f:
+            f.write('epoch\ttraj_loss\tintent_loss\tmse\tacc\n')
+
+    with open(log_dir + '%s.tsv' % mark, 'a') as f:
+        f.write('%05d\t%f\t%f\t%f\t%f\n' % (epoch, loss_traj, loss_intent, mse, acc))
+    return acc, mse
+
+
+def train_on_batch(data, model, optimizer, criterion_traj, criterion_intend, params, print_result=False, epoch=0,
+                   iter=0):
+    optimizer.zero_grad()
+    x, pred_traj, y_traj, pred_intent, y_intent, pred_start_pos = get_prediction_on_batch(data, model, device)
+
+    loss_traj = criterion_traj(pred_traj, y_traj)
+    loss_intent = criterion_intend(pred_intent, y_intent)
+    # Joint objective: trajectory regression and intent classification are
+    # mixed with the fixed weights in params['traj_intent_loss_ratio'].
+    loss = params['traj_intent_loss_ratio'][0] * loss_traj + params['traj_intent_loss_ratio'][1] * loss_intent
+
+    loss.backward()
+    # Clip the global gradient norm at 10 to keep the RNN updates stable.
+    _ = torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
+    optimizer.step()
+
+    if print_result:
+        data_stats = {'data_mean': params['data_mean'], 'data_std': params['data_std']}
+        out_str = "epoch: %d, iter: %d, loss: %.4f " % (epoch, iter, loss.detach().cpu().numpy())
+
+        pred_traj = get_position(pred_traj, pred_start_pos, data_stats)
+        y_traj = get_position(y_traj, pred_start_pos, data_stats)
+        mse = (pred_traj - y_traj).pow(2).sum().float() / (pred_traj.size(0) * pred_traj.size(1))
+        mse = mse.cpu().detach().numpy()
+        loss_traj_val = loss_traj.cpu().detach().numpy()
+        out_str += "trajectory_loss: %.4f, trajectory_mse: %.4f, " % (loss_traj_val, mse)
+
+        _, pred_intent_cls = pred_intent.max(1)
+        label_cls = y_intent
+        acc = (pred_intent_cls == label_cls).sum().float() / label_cls.size(0)
+        acc = acc.cpu().detach().numpy()
+        loss_intent_val = loss_intent.cpu().detach().numpy()
+        out_str += "intent_loss: %.4f, intent_acc: %.4f, " % (loss_intent_val, acc)
+
+        print(out_str)
+        log_path = params['log_dir'] + 'train.tsv'
+        if not os.path.exists(log_path):
+            with open(log_path, 'a') as f:
+                f.write('epoch\titer\ttraj_loss\tintent_loss\tmse\tacc\n')
+
+        with open(log_path, 'a') as f:
+            f.write('%05d\t%05d\t%f\t%f\t%f\t%f\n' % (epoch, iter, loss_traj_val, loss_intent_val, mse, acc))
+
+    return loss
+
+
+def train(params):
+    train_params = params.train_param()
+
+    train_loader, valid_loader, test_loader, train_params = get_data_loader(train_params, mode='train')
+    params._save_overwrite_parameters(params_key='train_param', params_value=train_params)
+
+    train_params['data_mean'] = torch.tensor(train_params['data_stats']['speed_mean'],
+                                             dtype=torch.float).unsqueeze(0).to(device)
+    train_params['data_std'] = torch.tensor(train_params['data_stats']['speed_std'],
+                                            dtype=torch.float).unsqueeze(0).to(device)
+
+    model = create_model(params)
+    model = model.to(device)
+
+    criterion_traj = torch.nn.MSELoss(reduction='mean').to(device)
+    criterion_intend = CrossEntropyLoss(class_num=train_params['class_num'],
+                                        label_smooth=train_params['label_smooth']).to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=train_params['lr'])
+
+    scheduler = get_lr_schedule(train_params['lr_schedule'], train_params, optimizer)
+
+    best_result = {'valid_acc': 0, 'valid_mse': 99999, 'test_acc': 0, 'test_mse': 99999, 'epoch': 0}
+    print('begin to train')
+    for epoch in range(1, train_params['epochs'] + 1):
+        for i, data in enumerate(train_loader, 0):
+            print_result = (i % train_params['print_step'] == 0)
+            train_on_batch(data, model, optimizer, criterion_traj, criterion_intend, params=train_params,
+                           print_result=print_result, epoch=epoch, iter=i)
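+
+        # Note: torch.save(model, ...) below pickles the full module, so the
+        # model class definition must be importable wherever the checkpoint
+        # is later restored.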
+
+        save_model_path = os.path.join(train_params['save_dir'], 'model_%d.pkl' % (epoch))
+        torch.save(model, save_model_path)
+        print('save model to', save_model_path)
+
+        model.eval()
+        valid_acc, valid_mse = evaluate(model, valid_loader, criterion_traj, criterion_intend,
+                                        params=train_params, epoch=epoch, mark='valid')
+        test_acc, test_mse = evaluate(model, test_loader, criterion_traj, criterion_intend,
+                                      params=train_params, epoch=epoch, mark='test')
+        model.train()
+        if valid_mse < best_result['valid_mse'] or valid_acc > best_result['valid_acc']:
+            best_result['valid_mse'] = valid_mse
+            best_result['valid_acc'] = valid_acc
+            best_result['test_mse'] = test_mse
+            best_result['test_acc'] = test_acc
+            best_result['epoch'] = epoch
+
+        if scheduler is not None:
+            scheduler.step()  # advance the LR schedule once per epoch
+
+    print('Best Results (epoch %d):' % best_result['epoch'])
+    print('validation_acc = %f, validation_mse = %f, test_acc = %f, test_mse = %f'
+          % (best_result['valid_acc'], best_result['valid_mse'], best_result['test_acc'], best_result['test_mse']))
+    return model
+
+
+def main():
+    params = hyper_parameters(dataset='vehicle_ngsim', model_type='rnn')
+    params._set_default_dataset_params()
+    params.print_params()
+    train(params)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/adapt_utils.py b/utils/adapt_utils.py
new file mode 100644
index 0000000..428d978
--- /dev/null
+++ b/utils/adapt_utils.py
@@ -0,0 +1,143 @@
+from time import time
+
+import numpy as np
+import torch
+
+from .pred_utils import get_prediction_on_batch
+
+# Default number of streaming samples used for online adaptation
+# (set data_size <= 0 to keep the whole split).
+data_size = 100
+
+
+def batch2iter_data(dataloader, device='cpu', data_size=data_size):
+    traj_hist, traj_labels, intent_labels, start_decodes, pred_start_pos, x_mask = None, None, None, None, None, None
+
+    for i, data in enumerate(dataloader, 0):
+        x, y_traj, y_intent, start_decode, start_pos, mask = data
+        if traj_hist is None:
+            traj_hist = x
+            traj_labels = y_traj
+            intent_labels = y_intent
+            start_decodes = start_decode
+            pred_start_pos = start_pos
+            x_mask = mask
+        else:
+            traj_hist = torch.cat([traj_hist, x], dim=0)
+            traj_labels = torch.cat([traj_labels, y_traj], dim=0)
+            intent_labels = torch.cat([intent_labels, y_intent], dim=0)
+            start_decodes = torch.cat([start_decodes, start_decode], dim=0)
+            pred_start_pos = torch.cat([pred_start_pos, start_pos], dim=0)
+            x_mask = torch.cat([x_mask, mask], dim=0)
+
+        if data_size > 0 and traj_hist.size(0) > data_size:
+            break
+
+    traj_hist = traj_hist.float().to(device)
+    traj_labels = traj_labels.float().to(device)
+    intent_labels = intent_labels.float().to(device)
+    start_decodes = start_decodes.float().to(device)
+    pred_start_pos = pred_start_pos.float().to(device)
+    x_mask = x_mask.byte().to(device)
+    data = [traj_hist, traj_labels, intent_labels, start_decodes, pred_start_pos, x_mask]
+
+    return data
+
+
+def online_adaptation(dataloader, model, optimizer, params, device,
+                      adapt_step=1, use_multi_epoch=False, multiepoch_thresh=(0, 0)):
+    optim_name = optimizer.__class__.__name__
+    if optim_name == 'Lookahead':
+        optim_name = optim_name + '_' + optimizer.optimizer.__class__.__name__
+    print('optimizer:', optim_name)
+    print('adapt_step:', adapt_step, ', use_multi_epoch:', use_multi_epoch, ', multiepoch_thresh:', multiepoch_thresh)
+    t1 = time()
+
+    data = batch2iter_data(dataloader, device)
+    traj_hist, traj_labels, intent_labels, _, pred_start_pos, _ = data
+    batches = []
+    for ii in range(len(pred_start_pos)):
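+        # Re-wrap time step ii as a batch of size one (item[[ii]] keeps the
+        # batch dimension); online adaptation consumes the stream sample by
+        # sample in arrival order.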
+        temp_batch = []
+        for item in data:
+            temp_batch.append(item[[ii]])
+        batches.append(temp_batch)
+
+    traj_preds = torch.zeros_like(traj_labels)
+    intent_preds = torch.zeros(size=(len(intent_labels), params['class_num']),
+                               dtype=torch.float, device=intent_labels.device)
+
+    temp_pred_list = []
+    temp_label_list = []
+    temp_data_list = []
+    cnt = [0, 0, 0]
+    cost_list = []
+    post_cost_list = []
+    for t in range(len(pred_start_pos)):
+        batch_data = batches[t]
+        _, pred_traj, y_traj, pred_intent, _, _ = get_prediction_on_batch(batch_data, model, device)
+        traj_preds[t] = pred_traj[0].detach()
+        intent_preds[t] = pred_intent[0].detach()
+
+        # Keep a sliding window of the last adapt_step predictions, so the
+        # model is corrected on the oldest sample whose label has arrived.
+        temp_pred_list += [pred_traj]
+        temp_label_list += [y_traj]
+        temp_data_list += [batch_data]
+        if len(temp_pred_list) > adapt_step:
+            temp_pred_list = temp_pred_list[1:]
+            temp_label_list = temp_label_list[1:]
+            temp_data_list = temp_data_list[1:]
+
+        if t < adapt_step - 1:
+            continue
+
+        Y = temp_label_list[0]
+        Y_hat = temp_pred_list[0]
+        full_loss = (Y - Y_hat).detach().pow(2).mean().cpu().numpy().round(6)
+        cost_list.append(full_loss)
+
+        Y_tau = Y[:, :adapt_step].contiguous().view((-1, 1))
+        Y_hat_tau = Y_hat[:, :adapt_step].contiguous().view((-1, 1))
+        err = (Y_tau - Y_hat_tau).detach()
+        curr_cost = err.pow(2).mean().cpu().numpy()
+        # Dynamic multi-epoch rule: with thresholds (a, b), an "easy" sample
+        # (partial cost < a) gets one update, a "hard" sample (cost in [a, b))
+        # gets two, and a likely outlier (cost >= b) is skipped entirely.
+        update_epoch = 1
+        if 0 <= multiepoch_thresh[0] <= multiepoch_thresh[1]:
+            if curr_cost < multiepoch_thresh[0]:
+                update_epoch = 1
+            elif curr_cost < multiepoch_thresh[1]:
+                update_epoch = 2
+            else:
+                update_epoch = 0
+        cnt[update_epoch] += 1
+        for cycle in range(update_epoch):
+            def mekf_closure(index=0):
+                # MEKF_MA calls this once per output dimension, so it can
+                # accumulate per-output gradients (one Jacobian row at a
+                # time); the graph is retained until the last output.
+                optimizer.zero_grad()
+                dim_out = optimizer.optimizer.state['dim_out'] if 'Lookahead' in optim_name else optimizer.state['dim_out']
+                retain = index < dim_out - 1
+                Y_hat_tau[index].backward(retain_graph=retain)
+                return err
+
+            def lbfgs_closure():
+                optimizer.zero_grad()
+                temp_data = temp_data_list[0]
+                _, temp_pred_traj, temp_y_traj, _, _, _ = get_prediction_on_batch(temp_data, model, device)
+                y_tau = temp_y_traj[:, :adapt_step].contiguous().view((-1, 1))
+                y_hat_tau = temp_pred_traj[:, :adapt_step].contiguous().view((-1, 1))
+                loss = (y_tau - y_hat_tau).pow(2).mean()
+                loss.backward()
+                return loss
+
+            if 'MEKF' in optim_name:
+                optimizer.step(mekf_closure)
+            elif 'LBFGS' in optim_name:
+                optimizer.step(lbfgs_closure)
+            else:
+                loss = (Y_tau - Y_hat_tau).pow(2).mean()
+                loss.backward()
+                optimizer.step()
+
+        # Re-evaluate the adapted model on the same sample to log the
+        # post-adaptation cost.
+        temp_data = temp_data_list[0]
+        _, post_pred_traj, post_y_traj, _, _, _ = get_prediction_on_batch(temp_data, model, device)
+        post_loss = (post_pred_traj - post_y_traj).detach().pow(2).mean().cpu().numpy().round(6)
+        post_cost_list.append(post_loss)
+
+        if t % 10 == 0:
+            print('finished pred {}, time:{}, partial cost before adapt:{}, partial cost after adapt:{}'.format(
+                t, time() - t1, full_loss, post_loss))
+            t1 = time()
+
+    print('avg_cost:', np.mean(cost_list))
+    print('avg_post_cost:', np.mean(post_cost_list))
+    print('number of update epoch', cnt)
+    return traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos
+
diff --git a/utils/pred_utils.py b/utils/pred_utils.py
new file mode 100644
index 0000000..4431caa
--- /dev/null
+++ b/utils/pred_utils.py
@@ -0,0 +1,46 @@
+import torch
+
+
+def get_prediction_on_batch(data, model, device='cpu'):
+    x, y_traj, y_intent, start_decode, pred_start_pos, x_mask = data
+    x = x.float().to(device)
+    y_traj = y_traj.float().to(device)
+    y_intent = y_intent.long().to(device)
+    start_decode = start_decode.float().to(device)
+    x_mask = x_mask.byte().to(device)
+    pred_start_pos = pred_start_pos.float().to(device)
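+    # All tensors are cast and moved to the target device up front, so the
+    # forward pass below never mixes CPU and GPU inputs.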
+
+    pred_traj, pred_intent = model(src_seq=x, start_decode=start_decode, encoder_mask=x_mask)
+    return x, pred_traj, y_traj, pred_intent, y_intent, pred_start_pos
+
+
+def get_predictions(dataloader, model, device, data_size=-1):
+    traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos = None, None, None, None, None, None
+
+    for i, data in enumerate(dataloader, 0):
+        x, pred_traj, y_traj, pred_intent, y_intent, start_pos = get_prediction_on_batch(data, model, device)
+        if traj_hist is None:
+            traj_hist = x
+            traj_preds = pred_traj
+            traj_labels = y_traj
+            intent_preds = pred_intent
+            intent_labels = y_intent
+            pred_start_pos = start_pos
+        else:
+            traj_hist = torch.cat([traj_hist, x], dim=0)
+            traj_labels = torch.cat([traj_labels, y_traj], dim=0)
+            intent_labels = torch.cat([intent_labels, y_intent], dim=0)
+            pred_start_pos = torch.cat([pred_start_pos, start_pos], dim=0)
+            traj_preds = torch.cat([traj_preds, pred_traj], dim=0)
+            intent_preds = torch.cat([intent_preds, pred_intent], dim=0)
+        if data_size > 0 and traj_hist.size(0) > data_size:
+            break
+    return traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos
+
+
+def get_position(speed, start_pose=None, data_stats=None):
+    # De-normalize predicted speeds and integrate them (cumulative sum over
+    # time) from the start pose to recover absolute positions.
+    if data_stats is None or start_pose is None:
+        return speed
+    speed = speed * data_stats['data_std'] + data_stats['data_mean']
+    displacement = torch.cumsum(speed, dim=1)
+    start_pose = torch.unsqueeze(start_pose, dim=1)
+    position = displacement + start_pose
+    return position
\ No newline at end of file
diff --git a/utils/train_utils.py b/utils/train_utils.py
new file mode 100644
index 0000000..339b7a7
--- /dev/null
+++ b/utils/train_utils.py
@@ -0,0 +1,36 @@
+import torch
+
+
+class CrossEntropyLoss(torch.nn.Module):
+    """Cross entropy that accepts label smoothing."""
+
+    def __init__(self, class_num=12, label_smooth=0, size_average=True):
+        super(CrossEntropyLoss, self).__init__()
+        self.class_num = class_num
+        self.size_average = size_average
+        self.label_smooth = label_smooth
+
+    def forward(self, input, target):
+        logsoftmax = torch.nn.LogSoftmax(dim=1)
+        one_hot_target = torch.zeros(target.size()[0], self.class_num, device=target.device)
+        one_hot_target = one_hot_target.scatter_(1, target.unsqueeze(1), 1)
+        if self.label_smooth > 0:
+            one_hot_target = (1 - self.label_smooth) * one_hot_target + self.label_smooth * (1 - one_hot_target)
+        if self.size_average:
+            return torch.mean(torch.sum(-one_hot_target * logsoftmax(input), dim=1))
+        else:
+            return torch.sum(torch.sum(-one_hot_target * logsoftmax(input), dim=1))
+
+
+def get_lr_schedule(lr_schedule, params, optimizer):
+    if lr_schedule == 'multistep':
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=params['lr_decay_epochs'],
+                                                         gamma=params['lr_decay'])
+    elif lr_schedule == 'cyclic':
+        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, params['lr'] / 5, params['lr'],
+                                                      step_size_up=params['period'],
+                                                      mode='triangular', gamma=1.0)
+    else:
+        scheduler = None
+
+    return scheduler
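+
+
+if __name__ == '__main__':
+    # Illustrative smoke test only, not part of the original pipeline: the
+    # class count, logits, and targets below are arbitrary assumptions, used
+    # to check that label smoothing redistributes the one-hot mass.
+    logits = torch.randn(4, 3)
+    targets = torch.tensor([0, 1, 2, 1])
+    criterion = CrossEntropyLoss(class_num=3, label_smooth=0.1)
+    print('smoothed cross-entropy:', criterion(logits, targets).item())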