diff --git a/README.md b/README.md
index 18eccd7..f834438 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,14 @@
-# MEKF_MAME
-**M**odified **E**xtended **K**alman **F**ilter with generalized exponential **M**oving **A**verage and dynamic **M**ulti-**E**poch update strategy (MEKF_MAME)
+# MEKFEMA-DME
+**M**odified **E**xtended **K**alman **F**ilter with generalized **E**xponential **M**oving **A**verage and **D**ynamic **M**ulti-**E**poch update strategy (MEKFEMA-DME)
-Pytorch implementation source coder for paper "Robust Nonlinear Adaptation Algorithms for Multi-TaskPrediction Networks".
+PyTorch implementation of the paper [Robust Online Model Adaptation by Extended Kalman Filter with Exponential Moving Average and Dynamic Multi-Epoch Strategy](https://arxiv.org/abs/1912.01790).
-**We will release the code, once the paper is published.**
-In the paper, EKF based adaptation algorithm MEKF_λ was introduced as an effective base algorithm for online adaptation. In order to improve the convergence property of MEKF_λ, generalized exponential moving average filtering was investigated. Then this paper introduced a dynamic multi-epoch update strategy, which can be compatible with any optimizers. By combining all extensions with base MEKF_λ algorithm, robust online adaptation algorithm MEKF_MAME was created.
+Inspired by the Extended Kalman Filter (EKF), a base adaptation algorithm, Modified EKF with forgetting
+factor (MEKF_$\lambda$), is introduced first. This paper then applies generalized exponential moving
+average (EMA) filtering to the base MEKF_$\lambda$ in order to increase its convergence rate.
+Then, in order to effectively utilize the samples in online adaptation, this paper proposes a dynamic
+multi-epoch update strategy to discriminate the “hard” samples from the “easy” samples, and sets
+different weights for them. With all these extensions, we propose a robust online adaptation algorithm:
+MEKF with Exponential Moving Average and Dynamic Multi-Epoch strategy (MEKFEMA-DME).
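+
+## Usage
+
+Run `python adapt.py` to evaluate online adaptation with the defaults in `main()`
+(dataset `vehicle_ngsim`, `rnn` model, `mekf` adaptor). The minimal sketch below,
+adapted from the docstring example in `adaptation/mekf.py`, shows how the
+`MEKF_MA` optimizer is driven with a per-output-dimension closure; `model`, `x`,
+and `y` are placeholders for your own network and data.
+
+```python
+from adaptation.mekf import MEKF_MA
+
+# adapt only the output projection layers, as adapt.py does
+weights = [p for name, p in model.named_parameters() if 'output_projection' in name]
+optimizer = MEKF_MA(weights, dim_out=y.numel(), p0=1e-2, lbd=1 - 1e-6, sigma_r=1)
+
+y_hat = model(x)                        # forward pass (placeholder)
+y = y.contiguous().view((-1, 1))        # shape (dim_out, 1)
+y_hat = y_hat.contiguous().view((-1, 1))
+err = (y - y_hat).detach()
+
+def mekf_closure(index=0):
+    # backpropagate one output dimension at a time to build the H matrix
+    optimizer.zero_grad()
+    retain = index < optimizer.state['dim_out'] - 1
+    y_hat[index].backward(retain_graph=retain)
+    return err
+
+optimizer.step(mekf_closure)
+```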
diff --git a/adapt.py b/adapt.py
new file mode 100644
index 0000000..9965955
--- /dev/null
+++ b/adapt.py
@@ -0,0 +1,150 @@
+# coding=utf-8
+
+import os
+import warnings
+
+import joblib
+import torch
+import numpy as np
+from dataset.dataset import get_data_loader
+from adaptation.lookahead import Lookahead
+from adaptation.mekf import MEKF_MA
+from parameters import hyper_parameters, adapt_hyper_parameters
+from utils.adapt_utils import online_adaptation
+from utils.pred_utils import get_predictions,get_position
+
+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+warnings.filterwarnings("ignore")
+
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")  # NOTE: forces CPU, overriding the CUDA device selected above
+print('testing with device:', device)
+
+rnn_layer_name = ['encoder.rnn.weight_ih_l0', 'encoder.rnn.bias_ih_l0',
+ 'encoder.rnn.weight_hh_l0','encoder.rnn.bias_hh_l0',
+ 'decoder.rnn.weight_ih_l0', 'decoder.rnn.bias_ih_l0',
+ 'decoder.rnn.weight_hh_l0','decoder.rnn.bias_hh_l0',
+ 'decoder.output_projection.weight', 'decoder.output_projection.bias']
+fc_layer_name = ['encoder.layers.0.weight', 'encoder.layers.0.bias',
+ 'encoder.layers.3.weight', 'encoder.layers.3.bias',
+ 'decoder.layers.0.weight', 'decoder.layers.0.bias',
+ 'decoder.layers.3.weight', 'decoder.layers.3.bias',
+ 'decoder.output_projection.weight', 'decoder.output_projection.bias', ]
+
+
+def adaptable_prediction(data_loader, model, train_params, device, adaptor, adapt_step=1):
+ '''adaptation hyper param'''
+ adapt_params = adapt_hyper_parameters(adaptor=adaptor, adapt_step=adapt_step, log_dir=train_params['log_dir'])
+ adapt_params._save_parameters()
+ adapt_params.print_params()
+
+ adapt_weights = []
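+    # adapt only the final layers: index [8:] selects
+    # 'decoder.output_projection.weight' and 'decoder.output_projection.bias'
+    # from the layer-name lists defined above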
+ if train_params['encoder'] == 'rnn':
+ adapt_layers = rnn_layer_name[8:]
+ else:
+ adapt_layers = fc_layer_name[8:]
+ print('adapt_weights:')
+ print(adapt_layers)
+ for name, p in model.named_parameters():
+ if name in adapt_layers:
+ adapt_weights.append(p)
+ print(name, p.size())
+
+ optim_param = adapt_params.adapt_param()
+ if adaptor == 'mekf' or adaptor=='mekf_ma':
+ optimizer = MEKF_MA(adapt_weights, dim_out=adapt_step * train_params['coordinate_dim'],
+ p0=optim_param['p0'], lbd=optim_param['lbd'], sigma_r=optim_param['sigma_r'],
+ sigma_q=optim_param['sigma_q'], lr=optim_param['lr'],
+ miu_v=optim_param['miu_v'], miu_p=optim_param['miu_p'],
+ k_p=optim_param['k_p'])
+ elif adaptor == 'sgd':
+ optimizer = torch.optim.SGD(adapt_weights, lr=optim_param['lr'], momentum=optim_param['momentum'],
+ nesterov=optim_param['nesterov'])
+
+ elif adaptor == 'adam':
+ optimizer = torch.optim.Adam(adapt_weights, lr=optim_param['lr'], betas=optim_param['betas'],
+ amsgrad=optim_param['amsgrad'])
+
+ elif adaptor == 'lbfgs':
+ optimizer = torch.optim.LBFGS(adapt_weights, lr=optim_param['lr'], max_iter=optim_param['max_iter'],
+ history_size=optim_param['history_size'])
+ else:
+ raise NotImplementedError
+ print('base optimizer configs:', optimizer.defaults)
+ if optim_param['use_lookahead']:
+ optimizer = Lookahead(optimizer, k=optim_param['la_k'], alpha=optim_param['la_alpha'])
+
+ st_param = adapt_params.strategy_param()
+ pred_result = online_adaptation(data_loader, model, optimizer, train_params, device,
+ adapt_step=adapt_step,
+ use_multi_epoch=st_param['use_multi_epoch'],
+ multiepoch_thresh=st_param['multiepoch_thresh'])
+
+
+ return pred_result
+
+
+def test(params, adaptor='none', adapt_step=1):
+ train_params = params.train_param()
+ train_params['data_mean'] = torch.tensor(train_params['data_stats']['speed_mean'], dtype=torch.float).unsqueeze(
+ 0).to(device)
+ train_params['data_std'] = torch.tensor(train_params['data_stats']['speed_std'], dtype=torch.float).unsqueeze(0).to(
+ device)
+ data_stats = {'data_mean': train_params['data_mean'], 'data_std': train_params['data_std']}
+
+ model = torch.load(train_params['init_model'])
+ model = model.to(device)
+ print('load model', train_params['init_model'])
+
+ data_loader = get_data_loader(train_params, mode='test')
+ print('begin to test')
+ if adaptor == 'none':
+ with torch.no_grad():
+ pred_result = get_predictions(data_loader, model, device)
+ else:
+ pred_result = adaptable_prediction(data_loader, model, train_params, device, adaptor, adapt_step)
+
+ traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos = pred_result
+ traj_preds = get_position(traj_preds, pred_start_pos, data_stats)
+ traj_labels = get_position(traj_labels, pred_start_pos, data_stats)
+ intent_preds_prob = intent_preds.detach().clone()
+ _, intent_preds = intent_preds.max(1)
+
+ result = {'traj_hist': traj_hist, 'traj_preds': traj_preds, 'traj_labels': traj_labels,
+ 'intent_preds': intent_preds,'intent_preds_prob':intent_preds_prob,
+ 'intent_labels': intent_labels, 'pred_start_pos': pred_start_pos}
+
+ for k, v in result.items():
+ result[k] = v.cpu().detach().numpy()
+
+ out_str = 'Evaluation Result: \n'
+
+ num, time_step = result['traj_labels'].shape[:2]
+ mse = np.power(result['traj_labels'] - result['traj_preds'], 2).sum() / (num * time_step)
+ out_str += "trajectory_mse: %.4f, \n" % (mse)
+
+ acc = (result['intent_labels'] == result['intent_preds']).sum() / len(result['intent_labels'])
+ out_str += "action_acc: %.4f, \n" % (acc)
+
+ print(out_str)
+ save_path = train_params['log_dir'] + adaptor + str(adapt_step) + '_pred.pkl'
+ joblib.dump(result, save_path)
+ print('save result to', save_path)
+ return result
+
+
+def main(dataset='vehicle_ngsim', model_type='rnn', adaptor='mekf',adapt_step=1):
+ save_dir = 'output/' + dataset + '/' + model_type + '/'
+ model_path = save_dir + 'model_1.pkl'
+ params = hyper_parameters()
+ params._load_parameters(save_dir + 'log/')
+ params.params_dict['train_param']['init_model'] = model_path
+ params.print_params()
+ test(params, adaptor=adaptor, adapt_step=adapt_step)
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/adaptation/__init__.py b/adaptation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/adaptation/lookahead.py b/adaptation/lookahead.py
new file mode 100644
index 0000000..bb98809
--- /dev/null
+++ b/adaptation/lookahead.py
@@ -0,0 +1,70 @@
+from collections import defaultdict
+from torch.optim import Optimizer
+import torch
+
+
+class Lookahead(Optimizer):
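+    """Lookahead optimizer wrapper ("Lookahead Optimizer: k steps forward,
+    1 step back", Zhang et al., 2019). A slow copy of each parameter is kept;
+    after every k steps of the wrapped (fast) optimizer, the slow weights move
+    toward the fast weights by a factor alpha and the fast weights are reset
+    to the slow ones.
+    """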
+ def __init__(self, optimizer, k=5, alpha=0.5):
+ self.optimizer = optimizer
+ self.k = k
+ self.alpha = alpha
+ self.param_groups = self.optimizer.param_groups
+ self.state = defaultdict(dict)
+ self.fast_state = self.optimizer.state
+ for group in self.param_groups:
+ group["counter"] = 0
+
+ def update(self, group):
+ for fast in group["params"]:
+ param_state = self.state[fast]
+ if "slow_param" not in param_state:
+ param_state["slow_param"] = torch.zeros_like(fast.data)
+ param_state["slow_param"].copy_(fast.data)
+ slow = param_state["slow_param"]
+ slow += (fast.data - slow) * self.alpha
+ fast.data.copy_(slow)
+
+ def update_lookahead(self):
+ for group in self.param_groups:
+ self.update(group)
+
+ def step(self, closure=None):
+ loss = self.optimizer.step(closure)
+ for group in self.param_groups:
+ if group["counter"] == 0:
+ self.update(group)
+ group["counter"] += 1
+ if group["counter"] >= self.k:
+ group["counter"] = 0
+ return loss
+
+ def state_dict(self):
+ fast_state_dict = self.optimizer.state_dict()
+ slow_state = {
+ (id(k) if isinstance(k, torch.Tensor) else k): v
+ for k, v in self.state.items()
+ }
+ fast_state = fast_state_dict["state"]
+ param_groups = fast_state_dict["param_groups"]
+ return {
+ "fast_state": fast_state,
+ "slow_state": slow_state,
+ "param_groups": param_groups,
+ }
+
+ def load_state_dict(self, state_dict):
+ slow_state_dict = {
+ "state": state_dict["slow_state"],
+ "param_groups": state_dict["param_groups"],
+ }
+ fast_state_dict = {
+ "state": state_dict["fast_state"],
+ "param_groups": state_dict["param_groups"],
+ }
+ super(Lookahead, self).load_state_dict(slow_state_dict)
+ self.optimizer.load_state_dict(fast_state_dict)
+ self.fast_state = self.optimizer.state
+
+ def add_param_group(self, param_group):
+ param_group["counter"] = 0
+ self.optimizer.add_param_group(param_group)
\ No newline at end of file
diff --git a/adaptation/mekf.py b/adaptation/mekf.py
new file mode 100644
index 0000000..1fff483
--- /dev/null
+++ b/adaptation/mekf.py
@@ -0,0 +1,340 @@
+
+
+import torch
+from torch.optim import Optimizer
+
+
+class MEKF_MA(Optimizer):
+ """
+ Modified Extended Kalman Filter with generalized exponential Moving Average
+ """
+
+ def __init__(self, params, dim_out, p0=1e-2, lbd=1, sigma_r=None, sigma_q=0, lr=1,
+ miu_v=0, miu_p=0, k_p=1,
+ R_decay=False,R_decay_step=1000000):
+
+ if sigma_r is None:
+ sigma_r = max(lbd,0)
+ self._check_format(dim_out, p0, lbd, sigma_r, sigma_q, lr,miu_v,miu_p,k_p,R_decay,R_decay_step)
+ defaults = dict(p0=p0, lbd=lbd, sigma_r=sigma_r, sigma_q=sigma_q,
+ lr=lr,miu_v=miu_v,miu_p=miu_p,k_p=k_p,
+ R_decay=R_decay,R_decay_step=R_decay_step)
+ super(MEKF_MA, self).__init__(params, defaults)
+
+ self.state['dim_out'] = dim_out
+ with torch.no_grad():
+ self._init_mekf_matrix()
+
+ def _check_format(self, dim_out, p0, lbd, sigma_r, sigma_q, lr, miu_v, miu_p, k_p, R_decay, R_decay_step):
+        if not (isinstance(dim_out, int) and dim_out > 0):
+ raise ValueError("Invalid output dimension: {}".format(dim_out))
+ if not 0.0 < p0:
+ raise ValueError("Invalid initial P value: {}".format(p0))
+ if not 0.0 < lbd:
+ raise ValueError("Invalid forgetting factor: {}".format(lbd))
+ if not 0.0 < sigma_r:
+ raise ValueError("Invalid covariance matrix value for R: {}".format(sigma_r))
+ if not 0.0 <= sigma_q:
+ raise ValueError("Invalid covariance matrix value for Q: {}".format(sigma_q))
+ if not 0.0 < lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+
+ if not 0.0 <= miu_v < 1.0:
+ raise ValueError("Invalid EMA decaying factor for V matrix: {}".format(miu_v))
+ if not 0.0 <= miu_p < 1.0:
+ raise ValueError("Invalid EMA decaying factor for P matrix: {}".format(miu_p))
+        if not (isinstance(k_p, int) and k_p >= 0):
+ raise ValueError("Invalid delayed step size of Lookahead P: {}".format(k_p))
+
+ if not isinstance(R_decay, int) and not isinstance(R_decay, bool):
+ raise ValueError("Invalid R decay flag: {}".format(R_decay))
+ if not isinstance(R_decay_step, int):
+ raise ValueError("Invalid max step for R decaying: {}".format(R_decay_step))
+
+ def _init_mekf_matrix(self):
+ self.state['step']=0
+ self.state['mekf_groups']=[]
+ dim_out = self.state['dim_out']
+ for group in self.param_groups:
+ mekf_mat=[]
+ for p in group['params']:
+ matrix = {}
+ size = p.size()
+ dim_w=1
+ for dim in size:
+ dim_w*=dim
+ device= p.device
+ matrix['P'] = group['p0']*torch.eye(dim_w,dtype=torch.float,device=device)
+ matrix['R'] = group['sigma_r']*torch.eye(dim_out,dtype=torch.float,device=device)
+ matrix['Q'] = group['sigma_q'] * torch.eye(dim_w, dtype=torch.float, device=device)
+ matrix['H'] = None
+ mekf_mat.append(matrix)
+ self.state['mekf_groups'].append(mekf_mat)
+
+ def step(self,closure=None, H_groups=None, err=None):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ H_groups: groups of gradient matrix
+ err: error value
+ example 1 (optimization step with closure):
+ # optimizer -> MEKF_MA optimizer
+ # y -> observed value, y_hat -> predicted value
+ y = y.contiguous().view((-1, 1)) # shape of (dim_out,1)
+ y_hat = y_hat.contiguous().view((-1, 1)) # shape of (dim_out,1)
+ err = (y - y_hat).detach()
+
+ def mekf_closure(index=0):
+ optimizer.zero_grad()
+ dim_out = optimizer.state['dim_out']
+ retain = index < dim_out - 1
+ y_hat[index].backward(retain_graph=retain)
+ return err
+
+ optimizer.step(mekf_closure)
+
+ example 2 (optimization step with H_groups):
+ # y -> observed value, y_hat -> predicted value
+ # H -> gradient matrix that need to be specified
+ y = y.contiguous().view((-1, 1)) # shape of (dim_out,1)
+ y_hat = y_hat.contiguous().view((-1, 1)) # shape of (dim_out,1)
+ err = (y - y_hat).detach()
+ optimizer.step(H_groups=H_groups,err=err)
+ """
+ self.state['step'] += 1
+
+ if closure is not None:
+ for y_ind in range(self.state['dim_out']):
+ err = closure(y_ind)
+ for group_ind in range(len(self.param_groups)):
+ group = self.param_groups[group_ind]
+ mekf_mat = self.state['mekf_groups'][group_ind]
+ for ii, w in enumerate(group['params']):
+ if w.grad is None:
+ continue
+ H_n = mekf_mat[ii]['H']
+ grad = w.grad.data.detach()
+ if len(w.size())>1:
+ grad = grad.transpose(1, 0)
+ grad = grad.contiguous().view((1,-1))
+ if y_ind ==0:
+ H_n=grad
+ else:
+ H_n = torch.cat([H_n,grad],dim=0)
+ self.state['mekf_groups'][group_ind][ii]['H'] = H_n
+ else:
+ for group_ind in range(len(self.param_groups)):
+ H_mats = H_groups[group_ind]
+ for ii, H_n in enumerate(H_mats):
+ self.state['mekf_groups'][group_ind][ii]['H'] = H_n
+
+ err_T = err.transpose(0,1)
+
+ for group_ind in range(len(self.param_groups)):
+ group = self.param_groups[group_ind]
+ mekf_mat = self.state['mekf_groups'][group_ind]
+
+ miu_v = group['miu_v']
+ miu_p = group['miu_p']
+ k_p = group['k_p']
+ lr = group['lr']
+ lbd = group['lbd']
+
+ for ii,w in enumerate(group['params']):
+ if w.grad is None:
+ continue
+
+ P_n_1 = mekf_mat[ii]['P']
+ R_n = mekf_mat[ii]['R']
+ Q_n = mekf_mat[ii]['Q']
+ H_n = mekf_mat[ii]['H']
+ H_n_T = H_n.transpose(0, 1)
+
+ if group['R_decay']:
+ miu = 1.0 / min(self.state['step'],group['R_decay_step'])
+ R_n = R_n + miu * (err.mm(err_T) + H_n.mm(P_n_1).mm(H_n_T) - R_n)
+ self.state['mekf_groups'][group_ind][ii]['R']= R_n
+
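+                # Kalman update: gain K_n = P_{n-1} H_n^T (H_n P_{n-1} H_n^T + R_n)^{-1};
+                # weight update V_n = lr * K_n * err, optionally EMA-smoothed by miu_v;
+                # covariance update P_n = (1/lbd) * (P_{n-1} - K_n H_n P_{n-1} + Q_n),
+                # optionally EMA-smoothed by miu_p every k_p steps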
+ g_n = H_n.mm(P_n_1).mm(H_n_T) + R_n
+ g_n = g_n.inverse()
+ K_n = P_n_1.mm(H_n_T).mm(g_n)
+ V_n = lr * K_n.mm(err)
+ if len(w.size()) > 1:
+ V_n = V_n.view((w.size(1),w.size(0))).transpose(1,0)
+ else:
+ V_n = V_n.view(w.size())
+ if miu_v>0:
+ param_state = self.state[w]
+ if 'buffer_V' not in param_state:
+ V_ema = param_state['buffer_V'] = torch.clone(V_n).detach()
+ else:
+ V_ema = param_state['buffer_V']
+ V_ema.mul_(miu_v).add_(V_n.mul(1-miu_v).detach())
+ V_n=V_ema
+ w.data.add_(V_n)
+
+ P_n = (1/lbd) * (P_n_1 - K_n.mm(H_n).mm(P_n_1) + Q_n)
+ if miu_p>0 and k_p>0:
+ if self.state['step'] % k_p==0:
+ param_state = self.state[w]
+ if 'buffer_P' not in param_state:
+ P_ema = param_state['buffer_P'] = torch.clone(P_n).detach()
+ else:
+ P_ema = param_state['buffer_P']
+ P_ema.mul_(miu_p).add_(P_n.mul(1 - miu_p).detach())
+ P_n = P_ema
+ self.state['mekf_groups'][group_ind][ii]['P'] =P_n
+
+ return err
+
+class MEKF(Optimizer):
+ """
+ Modified Extended Kalman Filter
+ """
+
+ def __init__(self, params, dim_out, p0=1e-2, lbd=1, sigma_r=None, sigma_q=0,
+ R_decay=False,R_decay_step=1000000):
+
+ if sigma_r is None:
+ sigma_r = max(lbd,0)
+ self._check_format(dim_out, p0, lbd, sigma_r, sigma_q, R_decay,R_decay_step)
+ defaults = dict(p0=p0, lbd=lbd, sigma_r=sigma_r, sigma_q=sigma_q,
+ R_decay=R_decay,R_decay_step=R_decay_step)
+ super(MEKF, self).__init__(params, defaults)
+
+ self.state['dim_out'] = dim_out
+ with torch.no_grad():
+ self._init_mekf_matrix()
+
+ def _check_format(self, dim_out, p0, lbd, sigma_r, sigma_q, R_decay, R_decay_step):
+        if not (isinstance(dim_out, int) and dim_out > 0):
+ raise ValueError("Invalid output dimension: {}".format(dim_out))
+ if not 0.0 < p0:
+ raise ValueError("Invalid initial P value: {}".format(p0))
+ if not 0.0 < lbd:
+ raise ValueError("Invalid forgetting factor: {}".format(lbd))
+ if not 0.0 < sigma_r:
+ raise ValueError("Invalid covariance matrix value for R: {}".format(sigma_r))
+ if not 0.0 <= sigma_q:
+ raise ValueError("Invalid covariance matrix value for Q: {}".format(sigma_q))
+
+ if not isinstance(R_decay, int) and not isinstance(R_decay, bool):
+ raise ValueError("Invalid R decay flag: {}".format(R_decay))
+ if not isinstance(R_decay_step, int):
+ raise ValueError("Invalid max step for R decaying: {}".format(R_decay_step))
+
+ def _init_mekf_matrix(self):
+ self.state['step']=0
+ self.state['mekf_groups']=[]
+ dim_out = self.state['dim_out']
+ for group in self.param_groups:
+ mekf_mat=[]
+ for p in group['params']:
+ matrix = {}
+ size = p.size()
+ dim_w=1
+ for dim in size:
+ dim_w*=dim
+ device= p.device
+ matrix['P'] = group['p0']*torch.eye(dim_w,dtype=torch.float,device=device)
+ matrix['R'] = group['sigma_r']*torch.eye(dim_out,dtype=torch.float,device=device)
+ matrix['Q'] = group['sigma_q'] * torch.eye(dim_w, dtype=torch.float, device=device)
+ matrix['H'] = None
+ mekf_mat.append(matrix)
+ self.state['mekf_groups'].append(mekf_mat)
+
+ def step(self,closure=None, H_groups=None, err=None):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ H_groups: groups of gradient matrix
+ err: error value
+ example 1 (optimization step with closure):
+            # optimizer -> MEKF optimizer
+ # y -> observed value, y_hat -> predicted value
+ y = y.contiguous().view((-1, 1)) # shape of (dim_out,1)
+ y_hat = y_hat.contiguous().view((-1, 1)) # shape of (dim_out,1)
+ err = (y - y_hat).detach()
+
+ def mekf_closure(index=0):
+ optimizer.zero_grad()
+ dim_out = optimizer.state['dim_out']
+ retain = index < dim_out - 1
+ y_hat[index].backward(retain_graph=retain)
+ return err
+
+ optimizer.step(mekf_closure)
+
+ example 2 (optimization step with H_groups):
+ # y -> observed value, y_hat -> predicted value
+ # H -> gradient matrix that need to be specified
+ y = y.contiguous().view((-1, 1)) # shape of (dim_out,1)
+ y_hat = y_hat.contiguous().view((-1, 1)) # shape of (dim_out,1)
+ err = (y - y_hat).detach()
+ optimizer.step(H_groups=H_groups,err=err)
+ """
+ self.state['step'] += 1
+
+ if closure is not None:
+ for y_ind in range(self.state['dim_out']):
+ err = closure(y_ind)
+ for group_ind in range(len(self.param_groups)):
+ group = self.param_groups[group_ind]
+ mekf_mat = self.state['mekf_groups'][group_ind]
+ for ii, w in enumerate(group['params']):
+ if w.grad is None:
+ continue
+ H_n = mekf_mat[ii]['H']
+ grad = w.grad.data.detach()
+ if len(w.size())>1:
+ grad = grad.transpose(1, 0)
+ grad = grad.contiguous().view((1,-1))
+ if y_ind ==0:
+ H_n=grad
+ else:
+ H_n = torch.cat([H_n,grad],dim=0)
+ self.state['mekf_groups'][group_ind][ii]['H'] = H_n
+ else:
+ for group_ind in range(len(self.param_groups)):
+ H_mats = H_groups[group_ind]
+ for ii, H_n in enumerate(H_mats):
+ self.state['mekf_groups'][group_ind][ii]['H'] = H_n
+
+ err_T = err.transpose(0,1)
+
+ for group_ind in range(len(self.param_groups)):
+ group = self.param_groups[group_ind]
+ mekf_mat = self.state['mekf_groups'][group_ind]
+ lbd = group['lbd']
+
+ for ii,w in enumerate(group['params']):
+ if w.grad is None:
+ continue
+
+ P_n_1 = mekf_mat[ii]['P']
+ R_n = mekf_mat[ii]['R']
+ Q_n = mekf_mat[ii]['Q']
+ H_n = mekf_mat[ii]['H']
+ H_n_T = H_n.transpose(0, 1)
+
+ if group['R_decay']:
+ miu = 1.0 / min(self.state['step'],group['R_decay_step'])
+ R_n = R_n + miu * (err.mm(err_T) + H_n.mm(P_n_1).mm(H_n_T) - R_n)
+ self.state['mekf_groups'][group_ind][ii]['R']= R_n
+
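+                # same Kalman gain and covariance update as MEKF_MA, but with
+                # lr = 1 and no EMA smoothing of V or P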
+ g_n = H_n.mm(P_n_1).mm(H_n_T) + R_n
+ g_n = g_n.inverse()
+ K_n = P_n_1.mm(H_n_T).mm(g_n)
+ V_n = K_n.mm(err)
+ if len(w.size()) > 1:
+ V_n = V_n.view((w.size(1),w.size(0))).transpose(1,0)
+ else:
+ V_n = V_n.view(w.size())
+ w.data.add_(V_n)
+
+ P_n = (1/lbd) * (P_n_1 - K_n.mm(H_n).mm(P_n_1) + Q_n)
+ self.state['mekf_groups'][group_ind][ii]['P'] =P_n
+
+ return err
\ No newline at end of file
diff --git a/data/vehicle_ngsim.pkl b/data/vehicle_ngsim.pkl
new file mode 100644
index 0000000..16a7ecb
Binary files /dev/null and b/data/vehicle_ngsim.pkl differ
diff --git a/dataset/dataset.py b/dataset/dataset.py
new file mode 100644
index 0000000..8293ef9
--- /dev/null
+++ b/dataset/dataset.py
@@ -0,0 +1,217 @@
+import os
+from collections import Counter
+
+import joblib
+import numpy as np
+import torch
+from torch.utils.data.dataset import Dataset
+
+def data_time_split(data_list,params):
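+    """Slice each raw sequence into sliding windows: the first `input_time_step`
+    frames form the model input and the next `output_time_step` frames form the
+    prediction target. Sequences shorter than one full window are zero-padded on
+    the left, with the true source length recorded in `x_traj_len` for masking."""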
+ input_time_step = params['input_time_step']
+ output_time_step = params['output_time_step']
+ trajs = data_list['traj']
+ speeds = data_list['speed']
+ features = data_list['feature']
+ actions = data_list['action']
+ x_traj,y_traj,x_traj_len=[],[],[]
+ y_intent = []
+ x_speed, y_speed = [],[]
+ x_feature, y_feature = [],[]
+ data_ids = []
+ inds=np.arange(0,len(trajs))
+
+ for ind in inds:
+ traj = trajs[ind]
+ speed = speeds[ind]
+ feature = features[ind]
+ action = actions[ind]
+ begin=0
+ end=input_time_step+output_time_step
+ steps=len(traj)
+ src_len = steps - output_time_step
+ mask_now=False
+ if src_len < input_time_step/4:
+ continue
+
+ if steps< end:
+ mask_now = True
+ pad_traj = np.array([traj[0]*0]*(end-steps))
+ traj = np.concatenate([pad_traj,traj])
+ pad_speed = np.array([speed[0] * 0] * (end - steps))
+ speed = np.concatenate([pad_speed, speed])
+ pad_feature= np.array([feature[0] * 0] * (end - steps))
+            feature = np.concatenate([pad_feature, feature])
+ pad_actions= np.array([action[0] * 0] * (end - steps))
+ action = np.concatenate([pad_actions, action])
+ steps = len(traj)
+
+ while end<=steps:
+ # input
+ inp_traj = traj[begin:begin+input_time_step].reshape((input_time_step, -1))
+ x_traj.append(inp_traj)
+ data_ids.append(ind)
+ if mask_now:
+ x_traj_len.append(src_len)
+ else:
+ x_traj_len.append(len(inp_traj))
+
+ inp_sp= speed[begin:begin+input_time_step].reshape((input_time_step, -1))
+ x_speed.append(inp_sp)
+
+ inp_feat= feature[begin:begin+input_time_step].reshape((input_time_step, -1))
+ x_feature.append(inp_feat)
+
+ # output
+ out_traj = traj[begin+input_time_step:end].reshape((output_time_step, -1))
+ y_traj.append(out_traj)
+
+ out_sp= speed[begin+input_time_step:end].reshape((output_time_step, -1))
+ y_speed.append(out_sp)
+
+ out_feat= feature[begin+input_time_step:end].reshape((output_time_step, -1))
+ y_feature.append(out_feat)
+
+ y_intent.append(action[begin+input_time_step-1])
+
+ begin += 1
+ end += 1
+
+ x_traj=np.array(x_traj)
+ x_speed = np.array(x_speed)
+ x_feature = np.array(x_feature)
+ y_traj=np.array(y_traj)
+ y_speed = np.array(y_speed)
+ y_feature = np.array(y_feature)
+ y_intent=np.array(y_intent)
+ x_traj_len = np.array(x_traj_len)
+ data_ids=np.array(data_ids)
+ pred_start_pos = x_traj[:,-1]
+ data ={'x_traj':x_traj,'x_speed':x_speed,'x_feature':x_feature,
+ 'y_traj':y_traj,'y_speed':y_speed,'y_feature':y_feature,
+ 'y_intent':y_intent,'pred_start_pos':pred_start_pos,
+ 'x_traj_len':x_traj_len,'data_ids':data_ids}
+ return data
+
+def normalize_data(data, data_stats):
+ new_data={}
+ for k,v in data.items():
+ if k in ['x_traj','x_speed','y_traj','y_speed','x_feature','y_feature']:
+ mark = k.split('_')[-1]
+ data_mean,data_std=data_stats[mark+'_mean'],data_stats[mark+'_std']
+ new_data[k] = (v-data_mean)/data_std
+ else:
+ new_data[k] = v
+ return new_data
+
+class Trajectory_Data(Dataset):
+ def __init__(self, params, mode='train',data_stats={}):
+ self.mode = mode
+ print(mode,'data preprocessing')
+ cache_dir = params['log_dir']+mode+'.cache'
+ if os.path.exists(cache_dir):
+ print('loading data from cache',cache_dir)
+ self.data = joblib.load(cache_dir)
+ else:
+ raw_data = joblib.load(params['data_path'])[mode]
+ self.data = data_time_split(raw_data,params)
+
+ if mode=='train':
+ data_stats['traj_mean'] = np.mean(self.data['x_traj'],axis=(0,1))
+ data_stats['traj_std'] = np.std(self.data['x_traj'], axis=(0, 1))
+ data_stats['speed_mean'] = np.mean(self.data['x_speed'],axis=(0,1))
+ data_stats['speed_std'] = np.std(self.data['x_speed'], axis=(0, 1))
+ data_stats['feature_mean'] = np.mean(self.data['x_feature'],axis=(0,1))
+ data_stats['feature_std'] = np.std(self.data['x_feature'], axis=(0, 1))
+ self.data['data_stats'] = data_stats
+ if params['normalize_data']:
+ if mode=='train':
+ print('data statistics:')
+ print(data_stats)
+ self.data = normalize_data(self.data, data_stats)
+ joblib.dump(self.data,cache_dir)
+
+ enc_inp= None
+ for feat in params['inp_feat']:
+ dat = self.data['x_'+feat]
+ if enc_inp is None:
+ enc_inp = dat
+ else:
+ enc_inp = np.concatenate([enc_inp,dat],axis=-1)
+
+
+ self.data['x_encoder'] = enc_inp
+ self.data['y_decoder'] = self.data['y_speed']
+ self.data['start_decode'] = self.data['x_speed'][:,-1]
+
+
+ self.input_time_step = params['input_time_step']
+ self.input_feat_dim = self.data['x_encoder'].shape[2]
+
+ print(mode + '_data size:', len(self.data['x_encoder']))
+ print('each category counts:')
+ print(Counter(self.data['y_intent']))
+
+ def __getitem__(self, index):
+ x = self.data['x_encoder'][index]
+ y_traj = self.data['y_decoder'][index]
+ y_inten = self.data['y_intent'][index]
+ start_decode = self.data['start_decode'][index] # this is for incremental decoding
+ pred_start_pos = self.data['pred_start_pos'][index]
+ x_len = self.data['x_traj_len'][index]
+        x_mask = np.zeros(shape=x.shape[0], dtype=int)
+ bias = self.input_time_step - x_len
+ # left pad
+ if bias>0:
+ x_mask[:bias] = 1
+ x[:bias] = 0
+
+ return (x, y_traj, y_inten, start_decode, pred_start_pos, x_mask)
+
+ def __len__(self):
+ return len(self.data['x_encoder'])
+
+def get_data_loader(params, mode='train',pin_memory=False):
+ if mode == 'train':
+ train_data = Trajectory_Data(params, mode='train')
+ data_stats = train_data.data['data_stats']
+ train_loader = torch.utils.data.DataLoader(
+ train_data, batch_size=params['batch_size'], shuffle=True,
+ drop_last=True, pin_memory=pin_memory)
+
+ valid_data = Trajectory_Data(params, mode='valid',data_stats=data_stats)
+ valid_loader = torch.utils.data.DataLoader(
+ valid_data, batch_size=params['batch_size'], shuffle=False,
+ drop_last=False, pin_memory=pin_memory)
+
+ test_data = Trajectory_Data(params, mode='test',data_stats=data_stats)
+ test_loader = torch.utils.data.DataLoader(
+ test_data, batch_size=params['batch_size'], shuffle=False,
+ drop_last=False, pin_memory=pin_memory)
+
+ for k, v in data_stats.items():
+ data_stats[k] = [float(x) for x in v]
+ params['data_stats'] = data_stats
+ params['print_step'] = max(1,len(train_loader) // 10)
+ return train_loader, valid_loader, test_loader, params
+
+ elif mode == 'test':
+ data_stats = params['data_stats']
+ for k,v in data_stats.items():
+ data_stats[k] = np.array(v)
+ test_data = Trajectory_Data(params, mode='test',data_stats=data_stats)
+ test_loader = torch.utils.data.DataLoader(
+ test_data, batch_size=params['batch_size'], shuffle=False,
+ num_workers=1,drop_last=False, pin_memory=pin_memory)
+ return test_loader
+ elif mode == 'valid':
+ data_stats = params['data_stats']
+ for k,v in data_stats.items():
+ data_stats[k] = np.array(v)
+ test_data = Trajectory_Data(params, mode='valid',data_stats=data_stats)
+ test_loader = torch.utils.data.DataLoader(
+ test_data, batch_size=params['batch_size'], shuffle=False,
+ num_workers=1,drop_last=False, pin_memory=pin_memory)
+ return test_loader
+ else:
+ return None
+
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/models/fc_model.py b/models/fc_model.py
new file mode 100644
index 0000000..b12bc18
--- /dev/null
+++ b/models/fc_model.py
@@ -0,0 +1,102 @@
+import torch
+import torch.nn as nn
+
+
+class FCEncoder(nn.Module):
+ """Fully connected encoder."""
+
+ def __init__(self, feat_dim, max_seq_len, hidden_dims, dropout=0.0, act_fn=nn.ReLU, **kwargs):
+ super(FCEncoder, self).__init__()
+ self.feat_dim = feat_dim
+
+ self.layers = nn.ModuleList([])
+ input_dim = feat_dim * max_seq_len
+ for hd in hidden_dims:
+ self.layers.append(Linear(input_dim, hd))
+ self.layers.append(nn.Dropout(dropout))
+ self.layers.append(act_fn())
+ input_dim = hd
+
+ self.embed_inputs = None
+ self.output_units = hidden_dims[-1]
+
+ def forward(self, src_seq):
+ batch_size = src_seq.size(0)
+ x = src_seq.view(batch_size, -1)
+ out_stack = []
+ for ii, layer in enumerate(self.layers):
+ x = layer(x)
+ if (ii + 1) % 3 == 0: # fc, drop out, relu
+ out_stack.append(x)
+ out_stack = torch.cat(out_stack, dim=1)
+ return x, out_stack
+
+
+class FCDecoder(nn.Module):
+ """FC decoder."""
+
+ def __init__(self, feat_dim, max_seq_len, hidden_dims, dropout=0.0, act_fn=nn.ReLU,
+ encoder_output_units=10, traj_attn_intent_dim=0, **kwargs):
+ super(FCDecoder, self).__init__()
+ self.feat_dim = feat_dim
+ self.max_seq_len = max_seq_len
+ self.traj_attn_intent_dim = traj_attn_intent_dim
+
+ self.layers = nn.ModuleList([])
+ input_dim = encoder_output_units
+ for hd in hidden_dims:
+ self.layers.append(Linear(input_dim, hd))
+ self.layers.append(nn.Dropout(dropout))
+ self.layers.append(act_fn())
+ input_dim = hd
+
+ self.output_projection = Linear(input_dim, max_seq_len * feat_dim)
+
+ if traj_attn_intent_dim > 0:
+ self.traj_attn_fc = Linear(input_dim, traj_attn_intent_dim)
+
+ def forward(self, encoder_outs):
+ x = encoder_outs
+ for layer in self.layers:
+ x = layer(x)
+ hidden_out = x
+ if self.traj_attn_intent_dim > 0:
+ hidden_out = self.traj_attn_fc(hidden_out)
+ x = self.output_projection(x)
+ x = x.view(-1, self.max_seq_len, self.feat_dim)
+ return x, hidden_out
+
+
+class FCClassifier(nn.Module):
+ """FC classifier."""
+
+ def __init__(self, encoder_output_units, hidden_dims, dropout=0.0, act_fn=nn.ReLU,
+ num_class=11, traj_attn_intent_dim=0, **kwargs):
+ super(FCClassifier, self).__init__()
+
+ self.layers = nn.ModuleList([])
+ input_dim = encoder_output_units + traj_attn_intent_dim
+ for hd in hidden_dims:
+ self.layers.append(Linear(input_dim, hd))
+ self.layers.append(nn.Dropout(dropout))
+ self.layers.append(act_fn())
+ input_dim = hd
+
+ self.output_projection = Linear(input_dim, num_class, bias=False)
+
+ def forward(self, encoder_outs):
+ x = encoder_outs
+ for layer in self.layers:
+ x = layer(x)
+ hidden=x
+ x = self.output_projection(x)
+ return x,hidden
+
+
+def Linear(in_features, out_features, bias=True):
+ """Linear layer (input: N x T x C)"""
+ m = nn.Linear(in_features, out_features, bias=bias)
+ m.weight.data.uniform_(-0.1, 0.1)
+ if bias:
+ m.bias.data.uniform_(-0.1, 0.1)
+ return m
diff --git a/models/model_factory.py b/models/model_factory.py
new file mode 100644
index 0000000..5dc76bc
--- /dev/null
+++ b/models/model_factory.py
@@ -0,0 +1,106 @@
+import torch
+from torch import nn
+
+from .fc_model import FCEncoder, FCDecoder, FCClassifier
+from .pooling_layer import AvgPooling, LastPooling, LinearSeqAttnPooling, NoPooling
+from .rnn_model import RNNEncoder, RNNDecoder
+
+
+class MultiTask_Model(nn.Module):
+ def __init__(self, encoder_type, decoder_type,pool_type, params):
+ super(MultiTask_Model, self).__init__()
+
+ self.encoder_type = encoder_type
+ self.pool_type = pool_type
+ self.decoder_type = decoder_type
+ self.params = params
+ self.train_param=self.params.train_param()
+ self.traj_attn_intent_dim = self.train_param['traj_attn_intent_dim']
+
+ self.encoder = self._create_encoder(self.encoder_type)
+ self.enc_out_units = self.encoder.output_units
+
+ self.decoder = self._create_decoder(self.decoder_type)
+
+ self.clf_pool = self._create_pooling(self.pool_type)
+ self.classifier = self._create_decoder(decoder_type='classifier')
+
+ if self.traj_attn_intent_dim>0:
+ self.attn_pool = self._create_pooling(self.pool_type,input_size=self.traj_attn_intent_dim)
+
+
+ def _create_encoder(self, encoder_type):
+ # create encoder
+ if encoder_type == 'rnn':
+ rnn_param = self.params.encode_rnn_param()
+ encoder = RNNEncoder(**rnn_param)
+ else:
+ fc_param = self.params.encode_fc_param()
+ encoder = FCEncoder(**fc_param)
+ return encoder
+
+ def _create_pooling(self, pool_type,input_size=None):
+ if input_size is None:
+ input_size=self.enc_out_units
+ if pool_type == 'mean' or pool_type == 'avg':
+ pool = AvgPooling()
+ elif pool_type == 'last':
+ pool = LastPooling()
+ elif pool_type == 'linear_attn':
+ pool = LinearSeqAttnPooling(input_size=input_size)
+ else:
+ pool = NoPooling()
+ return pool
+
+ def _create_decoder(self, decoder_type):
+ if decoder_type == 'rnn':
+ rnn_params = self.params.decode_rnn_param()
+ decoder = RNNDecoder(encoder_output_units=self.enc_out_units,traj_attn_intent_dim=self.traj_attn_intent_dim,
+ **rnn_params)
+ elif decoder_type == 'classifier':
+ clf_params = self.params.classifier_fc_param()
+ decoder = FCClassifier(encoder_output_units=self.enc_out_units,traj_attn_intent_dim=self.traj_attn_intent_dim,
+ **clf_params)
+ else:
+ fc_param = self.params.decode_fc_param()
+ decoder = FCDecoder(encoder_output_units=self.enc_out_units,traj_attn_intent_dim=self.traj_attn_intent_dim,
+ **fc_param)
+
+ return decoder
+
+ def forward(self, src_seq,start_decode=None,encoder_mask=None):
+
+ enc = self.encoder(src_seq)
+ encoder_out, encoder_state= enc
+
+ if self.decoder_type == 'rnn':
+ out_traj, hidden_out_traj = self.decoder(enc, start_decode,encoder_mask=encoder_mask)
+ else:
+ out_traj,hidden_out_traj = self.decoder(encoder_out)
+
+
+ clf_inp = self.clf_pool(encoder_out,x_mask=encoder_mask)
+ if self.traj_attn_intent_dim>0:
+ hidden_out_traj = self.attn_pool(hidden_out_traj)
+ clf_inp = torch.cat([clf_inp, hidden_out_traj], dim=1)
+ out_intent,_ = self.classifier(clf_inp)
+
+ return out_traj, out_intent
+
+def create_model(params):
+ train_params = params.train_param()
+ if train_params['init_model'] is not None:
+ model = torch.load(train_params['init_model'])
+ print('load model', train_params['init_model'])
+ else:
+ model = MultiTask_Model(
+ encoder_type=train_params['encoder'],
+ pool_type=train_params['pool_type'],
+ decoder_type=train_params['decoder'],
+ params=params)
+
+ param_num = sum([p.data.nelement() for p in model.parameters()])
+ print("Number of model parameters: {} M".format(param_num / 1024. / 1024.))
+ model.train()
+
+ return model
diff --git a/models/pooling_layer.py b/models/pooling_layer.py
new file mode 100644
index 0000000..f9bad41
--- /dev/null
+++ b/models/pooling_layer.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class MaxPooling(nn.Module):
+ def forward(self, x, x_mask=None):
+ if x_mask is None or x_mask.data.sum() == 0:
+ return torch.max(x, 1)[0]
+ else:
+ lengths = (1 - x_mask).sum(1)
+ return torch.cat([torch.max(i[:l], dim=0)[0].view(1, -1) for i, l in zip(x, lengths)], dim=0)
+
+class AvgPooling(nn.Module):
+ def forward(self, x, x_mask=None):
+ if x_mask is None or x_mask.data.sum() == 0:
+ return torch.mean(x, 1)
+ else:
+ lengths = (1 - x_mask).sum(1)
+            return torch.cat([torch.mean(i[:l], dim=0).view(1, -1) for i, l in zip(x, lengths)], dim=0)
+
+class LastPooling(nn.Module):
+ def __init__(self):
+ super(LastPooling, self).__init__()
+
+ def forward(self, x, x_mask=None):
+ if x_mask is None or x_mask.data.sum() == 0:
+ return x[:, -1, :]
+ else:
+ lengths = (1 - x_mask).sum(1)
+ return torch.cat([i[l - 1, :] for i, l in zip(x, lengths)], dim=0).view(x.size(0), -1)
+
+class LinearSeqAttnPooling(nn.Module):
+ """Self attention over a sequence:
+
+ * o_i = softmax(Wx_i) for x_i in X.
+ """
+
+ def __init__(self, input_size,bias=False):
+ super(LinearSeqAttnPooling, self).__init__()
+ self.linear = nn.Linear(input_size, 1,bias=bias)
+
+ def forward(self, x, x_mask=None):
+ """
+ Args:
+ x: batch * len * hdim
+ x_mask: batch * len (1 for padding, 0 for true)
+ Output:
+ alpha: batch * len
+ """
+        # view() below requires contiguous memory after the upstream transposes
+        x = x.contiguous()
+ x_flat = x.view(-1, x.size(-1))
+ scores = self.linear(x_flat).view(x.size(0), x.size(1))
+ if x_mask is not None:
+ scores.data.masked_fill_(x_mask.data, -float('inf'))
+ alpha = F.softmax(scores, dim=-1)
+ self.alpha = alpha
+ return alpha.unsqueeze(1).bmm(x).squeeze(1)
+
+class NoPooling(nn.Module):
+ # placeholder for identity mapping
+ def forward(self, x, x_mask=None):
+ return x
+
diff --git a/models/rnn_model.py b/models/rnn_model.py
new file mode 100644
index 0000000..77bf71c
--- /dev/null
+++ b/models/rnn_model.py
@@ -0,0 +1,268 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class RNNEncoder(nn.Module):
+ """RNN encoder."""
+
+ def __init__(
+ self, cell_type='lstm', feat_dim=3, hidden_size=128, num_layers=1,
+ dropout_fc=0.1, dropout_rnn=0.0, bidirectional=False,**kwargs):
+ super(RNNEncoder, self).__init__()
+ self.cell_type = cell_type.lower()
+ self.feat_dim = feat_dim
+ self.num_layers = num_layers
+ self.dropout_fc = dropout_fc
+ self.dropout_rnn = dropout_rnn
+ self.bidirectional = bidirectional
+ self.hidden_size = hidden_size
+ input_size = feat_dim
+
+ if self.cell_type == 'lstm':
+ self.rnn = LSTM(input_size=input_size, hidden_size=hidden_size,
+ num_layers=num_layers, bidirectional=bidirectional,
+ dropout=self.dropout_rnn)
+ elif self.cell_type == 'gru':
+ self.rnn = GRU(input_size=input_size, hidden_size=hidden_size,
+ num_layers=num_layers, bidirectional=bidirectional,
+ dropout=self.dropout_rnn)
+ else:
+ self.rnn = RNN(input_size=input_size, hidden_size=hidden_size,
+ num_layers=num_layers, bidirectional=bidirectional,
+ dropout=self.dropout_rnn)
+
+ self.output_units = hidden_size
+ if bidirectional:
+ self.output_units *= 2
+
+ def forward(self, src_seq):
+ x = src_seq
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ encoder_out, h_t = self.rnn(x)
+
+ if self.cell_type == 'lstm':
+ final_hiddens, final_cells = h_t
+ else:
+ final_hiddens, final_cells = h_t, None
+
+ if self.dropout_fc>0:
+ encoder_out = F.dropout(encoder_out, p=self.dropout_fc, training=self.training)
+
+ if self.bidirectional:
+ batch_size = src_seq.size(0)
+
+ def combine_bidir(outs):
+ out = outs.view(self.num_layers, 2, batch_size, -1).transpose(1, 2).contiguous()
+ return out.view(self.num_layers, batch_size, -1)
+
+ final_hiddens = combine_bidir(final_hiddens)
+ if self.cell_type == 'lstm':
+ final_cells = combine_bidir(final_cells)
+
+ # T x B x C -> B x T x C
+ encoder_out = encoder_out.transpose(0, 1)
+ final_hiddens = final_hiddens.transpose(0, 1)
+ if self.cell_type == 'lstm':
+ final_cells = final_cells.transpose(0, 1)
+
+ return encoder_out, (final_hiddens, final_cells)
+
+
+class AttentionLayer(nn.Module):
+ def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False):
+ super(AttentionLayer, self).__init__()
+
+ self.input_proj = Linear(input_embed_dim, source_embed_dim, bias=bias)
+ self.output_proj = Linear(input_embed_dim + source_embed_dim, output_embed_dim, bias=bias)
+
+ def forward(self, input, source_hids, encoder_padding_mask=None):
+ # input: bsz x input_embed_dim
+ # source_hids: srclen x bsz x output_embed_dim
+
+ # x: bsz x output_embed_dim
+ x = self.input_proj(input)
+
+ # compute attention
+ attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2)
+
+ # don't attend over padding
+ if encoder_padding_mask is not None:
+ attn_scores = attn_scores.float().masked_fill_(
+ encoder_padding_mask,
+ float('-inf')
+ ).type_as(attn_scores) # FP16 support: cast to float and back
+
+ attn_scores = F.softmax(attn_scores, dim=0) # srclen x bsz
+
+ # sum weighted sources
+ x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0)
+ out = torch.cat((x, input), dim=1)
+        x = torch.tanh(self.output_proj(out))
+ return x, attn_scores
+
+
+class RNNDecoder(nn.Module):
+ """RNN decoder."""
+
+ def __init__(
+ self, cell_type='lstm', feat_dim=3, hidden_size=128, num_layers=1,
+ dropout_fc=0.1, dropout_rnn=0.0, encoder_output_units=128, max_seq_len=10,
+ attention=True, traj_attn_intent_dim=0,**kwargs):
+ super(RNNDecoder, self).__init__()
+ self.cell_type = cell_type.lower()
+ self.dropout_fc = dropout_fc
+ self.dropout_rnn = dropout_rnn
+ self.hidden_size = hidden_size
+ self.encoder_output_units = encoder_output_units
+ self.num_layers = num_layers
+ self.max_seq_len = max_seq_len
+ self.feat_dim = feat_dim
+ self.traj_attn_intent_dim =traj_attn_intent_dim
+ input_size = feat_dim
+
+ if encoder_output_units != hidden_size:
+ self.encoder_hidden_proj = Linear(encoder_output_units, hidden_size)
+ if self.cell_type == 'lstm':
+ self.encoder_cell_proj = Linear(encoder_output_units, hidden_size)
+ else:
+ self.encoder_cell_proj = None
+ else:
+ self.encoder_hidden_proj = self.encoder_cell_proj = None
+
+ if self.cell_type == 'lstm':
+ self.cell = LSTM
+ elif self.cell_type == 'gru':
+ self.cell = GRU
+ else:
+ self.cell = RNN
+
+ self.rnn = self.cell(input_size=input_size,
+ hidden_size=hidden_size, bidirectional=False,
+ dropout=self.dropout_rnn,
+ num_layers=num_layers)
+
+ if attention:
+ self.attention = AttentionLayer(hidden_size, encoder_output_units, hidden_size, bias=False)
+ else:
+ self.attention = None
+
+ self.output_projection = Linear(hidden_size, feat_dim)
+
+ if traj_attn_intent_dim>0:
+ self.traj_attn_fc = Linear(hidden_size, traj_attn_intent_dim)
+
+ def forward(self, encoder_out_list, start_decode=None, encoder_mask=None):
+
+ x = start_decode.unsqueeze(1)
+ bsz = x.size(0)
+
+ # get outputs from encoder
+ encoder_outs, (encoder_hiddens, encoder_cells) = encoder_out_list
+ # B x T x C -> T x B x C
+ encoder_outs = encoder_outs.transpose(0, 1)
+ encoder_hiddens = encoder_hiddens.transpose(0, 1)
+ if encoder_mask is not None:
+ encoder_mask = encoder_mask.transpose(0, 1)
+ prev_hiddens = [encoder_hiddens[i] for i in range(self.num_layers)]
+ if self.cell_type == 'lstm':
+ encoder_cells = encoder_cells.transpose(0, 1)
+ prev_cells = [encoder_cells[i] for i in range(self.num_layers)]
+
+ x = x.transpose(0, 1)
+ srclen = encoder_outs.size(0)
+
+ # initialize previous states
+
+ if self.encoder_hidden_proj is not None:
+ prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens]
+ prev_hiddens = torch.stack(prev_hiddens, dim=0)
+ if self.encoder_cell_proj is not None:
+ prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
+ if self.cell_type == 'lstm':
+ prev_cells = torch.stack(prev_cells, dim=0)
+
+ attn_scores = x.new_zeros(srclen, self.max_seq_len, bsz)
+ inp = x
+ outs = []
+ hidden_outs=[]
+ for j in range(self.max_seq_len):
+ if self.cell_type == 'lstm':
+ output, (prev_hiddens, prev_cells) = self.rnn(inp, (prev_hiddens, prev_cells))
+ else:
+ output, prev_hiddens = self.rnn(inp, prev_hiddens)
+ output = output.view(bsz, -1)
+ # apply attention using the last layer's hidden state
+ if self.attention is not None:
+ out, attn_scores[:, j, :] = self.attention(output, encoder_outs, encoder_mask)
+ else:
+ out = output
+ if self.dropout_fc>0:
+ out = F.dropout(out, p=self.dropout_fc, training=self.training)
+ hid_out = out
+ if self.traj_attn_intent_dim > 0:
+ hid_out= self.traj_attn_fc(hid_out)
+ hid_out = F.selu(hid_out)
+ hidden_outs.append(hid_out)
+
+ out = self.output_projection(out)
+ # save final output
+ outs.append(out)
+
+ inp = out.unsqueeze(0)
+
+ # collect outputs across time steps
+ x = torch.cat(outs, dim=0).view(self.max_seq_len, bsz, self.feat_dim)
+ hidden_outs = torch.cat(hidden_outs, dim=0).view(self.max_seq_len, bsz, -1)
+ # T x B x C -> B x T x C
+ x = x.transpose(1, 0)
+ hidden_outs=hidden_outs.transpose(1, 0)
+ # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
+ attn_scores = attn_scores.transpose(0, 2)
+
+ # project back to input space
+
+ return x, hidden_outs
+
+
+def LSTM(input_size, hidden_size, **kwargs):
+ m = nn.LSTM(input_size, hidden_size, **kwargs)
+ for name, param in m.named_parameters():
+ if 'weight' in name or 'bias' in name:
+ param.data.uniform_(-0.1, 0.1)
+ return m
+
+
+def GRU(input_size, hidden_size, **kwargs):
+ m = nn.GRU(input_size, hidden_size, **kwargs)
+ for name, param in m.named_parameters():
+ if 'weight' in name or 'bias' in name:
+ param.data.uniform_(-0.1, 0.1)
+ return m
+
+
+def RNN(input_size, hidden_size, **kwargs):
+ m = nn.RNN(input_size, hidden_size, **kwargs)
+ for name, param in m.named_parameters():
+ if 'weight' in name or 'bias' in name:
+ param.data.uniform_(-0.1, 0.1)
+ return m
+
+
+def Linear(in_features, out_features, bias=True):
+ """Linear layer (input: N x T x C)"""
+ m = nn.Linear(in_features, out_features, bias=bias)
+ m.weight.data.uniform_(-0.1, 0.1)
+ if bias:
+ m.bias.data.uniform_(-0.1, 0.1)
+ return m
+
diff --git a/parameters.py b/parameters.py
new file mode 100644
index 0000000..b955096
--- /dev/null
+++ b/parameters.py
@@ -0,0 +1,369 @@
+import json
+import os
+
+class hyper_parameters(object):
+ def __init__(self, input_time_step=20, output_time_step=50,
+ data_dir='data/',dataset='vehicle_ngsim',model_type='rnn',
+ num_class=3,coordinate_dim=2,inp_feat=('traj', 'speed')): #32
+
+ self.input_time_step = input_time_step
+ self.output_time_step = output_time_step
+ self.coordinate_dim = coordinate_dim
+ self.inten_num_class = num_class
+
+ self.inp_feat = inp_feat
+ self.out_feat = ('speed',)
+ self.encoder_feat_dim = self.coordinate_dim * len(self.inp_feat) # x, v
+ self.decoder_feat_dim = self.coordinate_dim * 1 # v
+
+ self.data_dir = data_dir
+ self.dataset = dataset
+ self.model_type = model_type
+ self.params_dict = {}
+
+
+ def _set_default_dataset_params(self):
+ if self.dataset=='human_kinect':
+ self.inp_feat = ('traj', 'speed')
+ self.input_time_step = 20
+ self.output_time_step = 10
+ self.inten_num_class = 12
+ self.coordinate_dim = 3
+ self.encoder_feat_dim = self.coordinate_dim * 2 # x,v
+ self.decoder_feat_dim =self.coordinate_dim
+
+ if self.dataset=='human_mocap':
+ self.inp_feat = ('traj', 'speed')
+ self.input_time_step = 20
+ self.output_time_step = 10
+ self.inten_num_class = 3
+ self.coordinate_dim = 3
+ self.encoder_feat_dim = self.coordinate_dim * 2
+ self.decoder_feat_dim =self.coordinate_dim
+
+ if self.dataset=='vehicle_holomatic':
+ self.inp_feat = ('feature', 'speed')
+ self.input_time_step = 20
+ self.output_time_step = 50
+ self.inten_num_class = 5
+ self.coordinate_dim = 2
+ self.traj_feature_dim = 8
+ self.encoder_feat_dim =self.coordinate_dim + self.traj_feature_dim
+ self.decoder_feat_dim = self.coordinate_dim
+
+ if self.dataset=='vehicle_ngsim':
+ self.inp_feat = ('feature', 'speed')
+ self.input_time_step = 20
+ self.output_time_step = 50
+ self.inten_num_class = 3
+ self.coordinate_dim = 2
+ self.traj_feature_dim = 4
+ self.encoder_feat_dim =self.coordinate_dim + self.traj_feature_dim
+ self.decoder_feat_dim = self.coordinate_dim
+
+ def train_param(self, param_dict=None):
+ default_train_params = dict(
+ dataset=self.dataset,
+ data_path=self.data_dir + self.dataset + '.pkl',
+ save_dir='output/'+self.dataset+'/'+self.model_type +'/',
+ init_model=None,
+ normalize_data=True,
+ input_time_step=self.input_time_step,
+ output_time_step=self.output_time_step,
+ inp_feat = self.inp_feat,
+
+ traj_intent_loss_ratio=[1, 0.1], # traj loss : intent loss
+ lr=0.01,
+ lr_schedule='multistep', # multistep
+ lr_decay_epochs=[7, 14],
+ lr_decay=0.1,
+ epochs=20,
+ batch_size=128,
+
+ coordinate_dim=self.coordinate_dim,
+ encoder=self.model_type,
+ encoder_feat_dim=self.encoder_feat_dim,
+
+ decoder=self.model_type,
+ decoder_feat_dim = self.decoder_feat_dim,
+
+
+ class_num=self.inten_num_class,
+ pool_type='linear_attn',
+ label_smooth=0.1,
+ traj_attn_intent_dim=64,
+ )
+ if param_dict is None and 'train_param' in self.params_dict:
+ param_dict = self.params_dict['train_param']
+ params = self._overwrite_params(default_train_params, param_dict)
+ params['log_dir'] = params['save_dir'] + 'log/'
+ dir_split = params['log_dir'].replace('\\','/').split('/')
+ base_dir=''
+ for _path in dir_split:
+ base_dir =os.path.join(base_dir,_path)
+ if not os.path.exists(base_dir):
+ os.mkdir(base_dir)
+
+ if params['encoder'] == 'fc':
+ params['pool_type'] = 'none'
+
+ return params
+
+ def encode_rnn_param(self, param_dict=None):
+ default_rnn_params = dict(
+ cell_type='gru',
+ feat_dim=self.encoder_feat_dim,
+ max_seq_len=self.input_time_step,
+ hidden_size=64,
+ num_layers=1,
+ dropout_fc=0.,
+ dropout_rnn=0.,
+ bidirectional=False,
+ )
+ if param_dict is None and 'encode_rnn_param' in self.params_dict:
+ param_dict = self.params_dict['encode_rnn_param']
+ param = self._overwrite_params(default_rnn_params, param_dict)
+ return param
+
+ def encode_fc_param(self, param_dict=None):
+ default_fc_params = dict(
+ feat_dim=self.encoder_feat_dim,
+ max_seq_len=self.input_time_step,
+ hidden_dims=[64,64],#[128,128],#[128,128,64]
+ dropout=0.,
+ )
+ if param_dict is None and 'encode_fc_param' in self.params_dict:
+ param_dict = self.params_dict['encode_fc_param']
+ param = self._overwrite_params(default_fc_params, param_dict)
+ return param
+
+ def decode_rnn_param(self, param_dict=None):
+ default_rnn_params = dict(
+ cell_type='gru',
+ feat_dim = self.decoder_feat_dim,
+ max_seq_len=self.output_time_step,
+ hidden_size=64,
+ num_layers=1,
+ dropout_fc=0.,
+ dropout_rnn=0.,
+ attention=True,
+ )
+ if param_dict is None and 'decode_rnn_param' in self.params_dict:
+ param_dict = self.params_dict['decode_rnn_param']
+ param = self._overwrite_params(default_rnn_params, param_dict)
+ return param
+
+ def decode_fc_param(self, param_dict=None):
+ default_fc_params = dict(
+ feat_dim=self.decoder_feat_dim,
+ max_seq_len=self.output_time_step,
+ hidden_dims=[64,64],
+ dropout=0.,
+ )
+ if param_dict is None and 'decode_fc_param' in self.params_dict:
+ param_dict = self.params_dict['decode_fc_param']
+ param = self._overwrite_params(default_fc_params, param_dict)
+ return param
+
+ def classifier_fc_param(self, param_dict=None):
+ default_fc_params = dict(
+ hidden_dims=[64],
+ dropout=0.,
+ num_class=self.inten_num_class,
+ )
+ if param_dict is None and 'classifier_fc_param' in self.params_dict:
+ param_dict = self.params_dict['classifier_fc_param']
+ param = self._overwrite_params(default_fc_params, param_dict)
+ return param
+
+ def print_params(self):
+ print('train parameters:')
+ t_param = self.train_param()
+ print(t_param)
+ print('encode_param:')
+ encode_param = self.encode_fc_param() if t_param['encoder'] == 'fc' else self.encode_rnn_param()
+ print(encode_param)
+ print('decode_param:')
+ decode_param = self.decode_fc_param() if t_param['decoder'] == 'fc' else self.decode_rnn_param()
+ print(decode_param)
+ print('classifier_fc_param:')
+ print(self.classifier_fc_param())
+
+ def _overwrite_params(self, old_param, new_param):
+ if new_param is None:
+ return old_param
+ for k, v in new_param.items():
+ old_param[k] = v
+ return old_param
+
+ def _save_parameters(self, log_dir=None):
+ params_dict = {}
+ params_dict['train_param'] = self.train_param()
+ params_dict['encode_rnn_param'] = self.encode_rnn_param()
+ params_dict['encode_fc_param'] = self.encode_fc_param()
+ params_dict['decode_rnn_param'] = self.decode_rnn_param()
+ params_dict['decode_fc_param'] = self.decode_fc_param()
+ params_dict['classifier_fc_param'] = self.classifier_fc_param()
+
+ if log_dir is None:
+ log_dir = params_dict['train_param']['log_dir']
+
+ with open(log_dir + 'hyper_parameters.json', 'w') as f:
+ json.dump(params_dict, f)
+
+ def _save_overwrite_parameters(self, params_key, params_value, log_dir=None):
+ params_dict = {}
+ params_dict['train_param'] = self.train_param()
+ params_dict['encode_rnn_param'] = self.encode_rnn_param()
+ params_dict['encode_fc_param'] = self.encode_fc_param()
+ params_dict['decode_rnn_param'] = self.decode_rnn_param()
+ params_dict['decode_fc_param'] = self.decode_fc_param()
+ params_dict['classifier_fc_param'] = self.classifier_fc_param()
+
+ params_dict[params_key] = params_value
+
+ if log_dir is None:
+ log_dir = params_dict['train_param']['log_dir']
+
+ with open(log_dir + 'hyper_parameters.json', 'w') as f:
+ json.dump(params_dict, f)
+
+ def _load_parameters(self, log_dir=None):
+ if log_dir is None:
+ log_dir = self.train_param()['log_dir']
+
+ with open(log_dir + 'hyper_parameters.json', 'r') as f:
+ self.params_dict = json.load(f)
+
+
+class adapt_hyper_parameters(object):
+ def __init__(self, adaptor='none',adapt_step=1,log_dir=None):
+ self.adaptor = adaptor
+ self.adapt_step = adapt_step
+ self.log_dir = log_dir
+ self.params_dict = {}
+
+ adaptor=adaptor.lower()
+ if adaptor=='nrls' or adaptor=='mekf' or adaptor=='mekf_ma':
+ self.adapt_param=self.mekf_param
+ elif adaptor=='sgd':
+ self.adapt_param=self.sgd_param
+ elif adaptor=='adam':
+ self.adapt_param=self.adam_param
+ elif adaptor=='lbfgs':
+ self.adapt_param=self.lbfgs_param
+
+ def strategy_param(self,param_dict=None):
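+        # Dynamic multi-epoch strategy settings. `multiepoch_thresh` holds the two
+        # thresholds used by utils.adapt_utils.online_adaptation to tell "easy"
+        # samples from "hard" ones; (-1, -1) presumably leaves the default
+        # behavior (the exact semantics live in online_adaptation).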
+ default_params = dict(
+ adapt_step=self.adapt_step,
+ use_multi_epoch=True,
+ multiepoch_thresh=(-1, -1),
+ )
+ if param_dict is None and 'strategy_param' in self.params_dict:
+ param_dict = self.params_dict['strategy_param']
+ params = self._overwrite_params(default_params, param_dict)
+ return params
+
+ def mekf_param(self, param_dict=None):
+ default_params = dict(
+ p0=1e-2, # 1e-2
+ lbd=1-1e-6, # 1
+ sigma_r=1,
+ sigma_q=0,
+ lr=1, # 1
+
+ miu_v=0, #momentum
+ miu_p=0, # EMA of P
+ k_p=1, #look ahead of P
+
+ use_lookahead=False, # outer lookahead
+ la_k=1, # outer lookahead
+ la_alpha=1,
+ )
+ if param_dict is None and 'mekf_param' in self.params_dict:
+ param_dict = self.params_dict['mekf_param']
+ params = self._overwrite_params(default_params, param_dict)
+ return params
+
+ def sgd_param(self, param_dict=None):
+ default_params = dict(
+ lr=1e-6,
+ momentum=0.7,
+ nesterov=False,
+
+ use_lookahead=False, #look ahead
+ la_k=5,
+ la_alpha = 0.8,
+ )
+ if param_dict is None and 'sgd_param' in self.params_dict:
+ param_dict = self.params_dict['sgd_param']
+ param = self._overwrite_params(default_params, param_dict)
+ return param
+
+ def adam_param(self, param_dict=None):
+ default_params = dict(
+ lr=1e-6,
+ betas=(0.1, 0.99),
+ amsgrad=True,
+
+ use_lookahead=False, #look ahead
+ la_k=5,
+ la_alpha=0.8,
+ )
+ if param_dict is None and 'adam_param' in self.params_dict:
+ param_dict = self.params_dict['adam_param']
+ param = self._overwrite_params(default_params, param_dict)
+ return param
+
+ def lbfgs_param(self, param_dict=None):
+ default_params = dict(
+ lr=0.002,
+ max_iter=20,
+ history_size=100,
+
+ use_lookahead=False, #look ahead
+ la_k=1,
+ la_alpha=1.0,
+ )
+ if param_dict is None and 'lbfgs_param' in self.params_dict:
+ param_dict = self.params_dict['lbfgs_param']
+ param = self._overwrite_params(default_params, param_dict)
+ return param
+
+ def print_params(self):
+ print('adaptation optimizer:',self.adaptor)
+ print('adaptation strategy parameters:')
+ print(self.strategy_param())
+ if self.adaptor=='none':
+ print('no adaptation')
+ else:
+ print('adaptation optimizer parameters:')
+ print(self.adapt_param())
+
+ def _overwrite_params(self, old_param, new_param):
+ if new_param is None:
+ return old_param
+ for k, v in new_param.items():
+ old_param[k] = v
+ return old_param
+
+ def _save_parameters(self, log_dir=None):
+ params_dict = {}
+ params_dict['strategy_param'] = self.strategy_param()
+ params_dict['mekf_param'] = self.mekf_param()
+ params_dict['sgd_param'] = self.sgd_param()
+ params_dict['adam_param'] = self.adam_param()
+ params_dict['lbfgs_param'] = self.lbfgs_param()
+
+ if log_dir is None:
+ log_dir = self.log_dir
+
+ with open(log_dir + 'adapt_hyper_parameters.json', 'w') as f:
+ json.dump(params_dict, f)
+
+ def _load_parameters(self, log_dir=None):
+ if log_dir is None:
+ log_dir = self.log_dir
+
+ with open(log_dir + 'adapt_hyper_parameters.json', 'r') as f:
+ self.params_dict = json.load(f)
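+
+
+# Usage sketch (hypothetical values; assumes `log_dir` ends with a slash):
+#   adapt_params = adapt_hyper_parameters(adaptor='mekf_ma', adapt_step=1,
+#                                         log_dir='./logs/')
+#   optim_kwargs = adapt_params.adapt_param()        # optimizer defaults
+#   strategy_kwargs = adapt_params.strategy_param()  # multi-epoch strategy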
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..99f1bf9
--- /dev/null
+++ b/train.py
@@ -0,0 +1,168 @@
+# coding=utf-8
+
+import os
+import warnings
+
+import torch
+from parameters import hyper_parameters
+from dataset.dataset import get_data_loader
+from models.model_factory import create_model
+from utils.pred_utils import get_prediction_on_batch, get_predictions,get_position
+from utils.train_utils import CrossEntropyLoss, get_lr_schedule
+
+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+warnings.filterwarnings("ignore")
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print('training with device:', device)
+
+
+def evaluate(model, data_loader, criterion_traj, criterion_intend, params, epoch=0, mark='valid'):
+ data_stats = {'data_mean': params['data_mean'], 'data_std': params['data_std']}
+ print('[Evaluation %s Set] -------------------------------' % mark)
+ out_str = "epoch: %d, " % (epoch)
+ with torch.no_grad():
+ dat = get_predictions(data_loader, model, device)
+ traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos = dat
+
+ loss_traj = criterion_traj(traj_preds, traj_labels)
+ loss_traj = loss_traj.cpu().detach().numpy()
+ traj_preds = get_position(traj_preds, pred_start_pos, data_stats)
+ traj_labels = get_position(traj_labels, pred_start_pos, data_stats)
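+        # mean squared Euclidean error per predicted step: squared errors are
+        # summed over batch, horizon and x/y, then divided by batch * horizon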
+ mse = (traj_preds - traj_labels).pow(2).sum().float() / (traj_preds.size(0) * traj_preds.size(1))
+ mse = mse.cpu().detach().numpy()
+ out_str += "trajectory_loss: %.4f, trajectory_mse: %.4f, " % (loss_traj, mse)
+
+ loss_intent = criterion_intend(intent_preds, intent_labels)
+ loss_intent = loss_intent.cpu().detach().numpy()
+ _, pred_intent_cls = intent_preds.max(1)
+ label_cls = intent_labels
+ acc = (pred_intent_cls == label_cls).sum().float() / label_cls.size(0)
+ acc = acc.cpu().detach().numpy()
+ out_str += "intent_loss: %.4f, intent_acc: %.4f, " % (loss_intent, acc)
+
+ print(out_str)
+ print('-------------------------------')
+
+ log_dir = params['log_dir']
+ if not os.path.exists(log_dir + '%s.tsv' % mark):
+        with open(log_dir + '%s.tsv' % mark, 'a') as f:
+ f.write('epoch\ttraj_loss\tintent_loss\tmse\tacc\n')
+
+ with open(log_dir + '%s.tsv' % mark, 'a') as f:
+ f.write('%05d\t%f\t%f\t%f\t%f\n' % (epoch, loss_traj, loss_intent, mse, acc))
+ return acc, mse
+
+
+def train_on_batch(data, model, optimizer, criterion_traj, criterion_intend, params, print_result=False, epoch=0,
+ iter=0):
+ optimizer.zero_grad()
+ x, pred_traj, y_traj, pred_intent, y_intent, pred_start_pos = get_prediction_on_batch(data, model, device)
+
+ loss_traj = criterion_traj(pred_traj, y_traj)
+ loss_intent = criterion_intend(pred_intent, y_intent)
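+    # weighted multi-task objective: trajectory regression + intent classification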
+ loss = params['traj_intent_loss_ratio'][0] * loss_traj + params['traj_intent_loss_ratio'][1] * loss_intent
+
+ loss.backward()
+ _ = torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
+ optimizer.step()
+
+ if print_result:
+ data_stats = {'data_mean': params['data_mean'], 'data_std': params['data_std']}
+ out_str = "epoch: %d, iter: %d, loss: %.4f " % (epoch, iter,loss.detach().cpu().numpy())
+
+ pred_traj = get_position(pred_traj, pred_start_pos, data_stats)
+ y_traj = get_position(y_traj, pred_start_pos, data_stats)
+ mse = (pred_traj - y_traj).pow(2).sum().float() / (pred_traj.size(0) * pred_traj.size(1))
+ mse = mse.cpu().detach().numpy()
+ loss_traj_val = loss_traj.cpu().detach().numpy()
+ out_str += "trajectory_loss: %.4f, trajectory_mse: %.4f, " % (loss_traj_val, mse)
+
+ _, pred_intent_cls = pred_intent.max(1)
+ label_cls = y_intent
+ acc = (pred_intent_cls == label_cls).sum().float() / label_cls.size(0)
+ acc = acc.cpu().detach().numpy()
+ loss_intent_val = loss_intent.cpu().detach().numpy()
+ out_str += "intent_loss: %.4f, intent_acc: %.4f, " % (loss_intent_val, acc)
+
+ print(out_str)
+ log_path = params['log_dir'] + 'train.tsv'
+ if not os.path.exists(log_path):
+ with open(log_path, 'a') as f:
+ f.write('epoch\titer\ttraj_loss\tintent_loss\tmse\tacc\n')
+
+ with open(log_path, 'a') as f:
+ f.write('%05d\t%05d\t%f\t%f\t%f\t%f\n' % (epoch, iter, loss_traj_val, loss_intent_val, mse, acc))
+
+ return loss
+
+
+def train(params):
+ train_params = params.train_param()
+
+ train_loader, valid_loader, test_loader, train_params = get_data_loader(train_params, mode='train')
+ params._save_overwrite_parameters(params_key='train_param', params_value=train_params)
+
+    train_params['data_mean'] = torch.tensor(train_params['data_stats']['speed_mean'],
+                                             dtype=torch.float).unsqueeze(0).to(device)
+    train_params['data_std'] = torch.tensor(train_params['data_stats']['speed_std'],
+                                            dtype=torch.float).unsqueeze(0).to(device)
+
+ model = create_model(params)
+ model = model.to(device)
+
+ criterion_traj = torch.nn.MSELoss(reduction='mean').to(device)
+ criterion_intend = CrossEntropyLoss(class_num=train_params['class_num'],
+ label_smooth=train_params['label_smooth']).to(device)
+ optimizer = torch.optim.Adam(model.parameters(), lr=train_params['lr'])
+
+ scheduler = get_lr_schedule(train_params['lr_schedule'], train_params, optimizer)
+
+ best_result = {'valid_acc': 0, 'valid_mse': 99999, 'test_acc': 0, 'test_mse': 99999, 'epoch': 0}
+ print('begin to train')
+ for epoch in range(1, train_params['epochs'] + 1):
+ for i, data in enumerate(train_loader, 0):
+            print_result = (i % train_params['print_step'] == 0)
+ train_on_batch(data, model, optimizer, criterion_traj, criterion_intend, params=train_params,
+ print_result=print_result, epoch=epoch, iter=i)
+
+ save_model_path = os.path.join(train_params['save_dir'], 'model_%d.pkl' % (epoch))
+ torch.save(model, save_model_path)
+ print('save model to', save_model_path)
+
+ model.eval()
+ valid_acc, valid_mse = evaluate(model, valid_loader, criterion_traj, criterion_intend, params=train_params,
+ epoch=epoch,
+ mark='valid')
+ test_acc, test_mse = evaluate(model, test_loader, criterion_traj, criterion_intend, params=train_params,
+ epoch=epoch,
+ mark='test')
+ model.train()
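+        # keep the test metrics from the epoch with the best validation MSE or accuracy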
+ if valid_mse < best_result['valid_mse'] or valid_acc > best_result['valid_acc']:
+ best_result['valid_mse'] = valid_mse
+ best_result['valid_acc'] = valid_acc
+ best_result['test_mse'] = test_mse
+ best_result['test_acc'] = test_acc
+ best_result['epoch'] = epoch
+
+ if scheduler is not None:
+ scheduler.step(epoch)
+
+ print('Best Results (epoch %d):' % best_result['epoch'])
+ print('validation_acc = %f, validation_mse = %f, test_acc = %f, test_mse = %f'
+ % (best_result['valid_acc'], best_result['valid_mse'], best_result['test_acc'], best_result['test_mse']))
+ return model
+
+
+def main():
+ params = hyper_parameters(dataset='vehicle_ngsim', model_type='rnn')
+ params._set_default_dataset_params()
+ params.print_params()
+ train(params)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/adapt_utils.py b/utils/adapt_utils.py
new file mode 100644
index 0000000..428d978
--- /dev/null
+++ b/utils/adapt_utils.py
@@ -0,0 +1,143 @@
+from time import time
+import numpy as np
+import torch
+from .pred_utils import get_prediction_on_batch
+
+# maximum number of samples collected for online adaptation (<= 0 removes the cap)
+data_size = 100
+
+
+def batch2iter_data(dataloader, device='cpu', data_size=data_size):
+ traj_hist, traj_labels, intent_labels, start_decodes, pred_start_pos, x_mask = None, None, None, None, None, None
+
+ for i, data in enumerate(dataloader, 0):
+ x, y_traj, y_intent, start_decode, start_pos, mask = data
+ if traj_hist is None:
+ traj_hist = x
+ traj_labels = y_traj
+ intent_labels = y_intent
+ start_decodes = start_decode
+ pred_start_pos = start_pos
+ x_mask = mask
+ else:
+ traj_hist = torch.cat([traj_hist, x], dim=0)
+ traj_labels = torch.cat([traj_labels, y_traj], dim=0)
+ intent_labels = torch.cat([intent_labels, y_intent], dim=0)
+ start_decodes = torch.cat([start_decodes, start_decode], dim=0)
+ pred_start_pos = torch.cat([pred_start_pos, start_pos], dim=0)
+ x_mask = torch.cat([x_mask, mask], dim=0)
+
+        if data_size > 0 and traj_hist.size(0) > data_size:
+ break
+
+ traj_hist = traj_hist.float().to(device)
+ traj_labels = traj_labels.float().to(device)
+ intent_labels = intent_labels.float().to(device)
+ start_decodes = start_decodes.float().to(device)
+ pred_start_pos = pred_start_pos.float().to(device)
+ x_mask = x_mask.byte().to(device)
+ data = [traj_hist, traj_labels, intent_labels, start_decodes, pred_start_pos, x_mask]
+
+ return data
+
+
+def online_adaptation(dataloader, model, optimizer, params, device,
+                      adapt_step=1, use_multi_epoch=False, multiepoch_thresh=(-1, -1)):
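+    """Sequential online adaptation.
+
+    For each incoming sample the current prediction is recorded first; the
+    model is then updated on the first `adapt_step` observed steps. With
+    `use_multi_epoch`, `multiepoch_thresh` = (t_low, t_high) maps the partial
+    cost to 1 update epoch (< t_low), 2 epochs (< t_high) or 0 epochs (the
+    sample is skipped as an outlier).
+    """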
+ optim_name = optimizer.__class__.__name__
+ if optim_name == 'Lookahead':
+ optim_name = optim_name + '_' + optimizer.optimizer.__class__.__name__
+ print('optimizer:', optim_name)
+    print('adapt_step:', adapt_step, ', use_multi_epoch:', use_multi_epoch, ', multiepoch_thresh:', multiepoch_thresh)
+ t1 = time()
+
+ data = batch2iter_data(dataloader, device)
+ traj_hist, traj_labels, intent_labels, _, pred_start_pos, _ = data
+ batches = []
+ for ii in range(len(pred_start_pos)):
+ temp_batch=[]
+ for item in data:
+ temp_batch.append(item[[ii]])
+ batches.append(temp_batch)
+
+ traj_preds = torch.zeros_like(traj_labels)
+    intent_preds = torch.zeros(size=(len(intent_labels), params['class_num']),
+                               dtype=torch.float, device=intent_labels.device)
+
+ temp_pred_list = []
+ temp_label_list = []
+ temp_data_list = []
+ cnt = [0, 0, 0]
+ cost_list = []
+ post_cost_list=[]
+ for t in range(len(pred_start_pos)):
+ batch_data = batches[t]
+ _, pred_traj, y_traj, pred_intent, _, _ = get_prediction_on_batch(batch_data, model, device)
+ traj_preds[t] = pred_traj[0].detach()
+ intent_preds[t] = pred_intent[0].detach()
+
+ temp_pred_list += [pred_traj]
+ temp_label_list += [y_traj]
+ temp_data_list += [batch_data]
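+        # sliding window of the last `adapt_step` samples; adaptation targets
+        # the oldest entry, whose first `adapt_step` steps are now observed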
+ if len(temp_pred_list) > adapt_step:
+ temp_pred_list = temp_pred_list[1:]
+ temp_label_list = temp_label_list[1:]
+ temp_data_list = temp_data_list[1:]
+
+ if t < adapt_step - 1:
+ continue
+
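+        # full-horizon cost of the oldest buffered prediction (before this update)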
+ Y = temp_label_list[0]
+ Y_hat = temp_pred_list[0]
+        full_loss = (Y - Y_hat).detach().pow(2).mean().cpu().numpy().round(6)
+ cost_list.append(full_loss)
+
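+        # restrict the update signal to the first `adapt_step` observed steps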
+ Y_tau = Y[:, :adapt_step].contiguous().view((-1, 1))
+ Y_hat_tau = Y_hat[:, :adapt_step].contiguous().view((-1, 1))
+ err = (Y_tau - Y_hat_tau).detach()
+ curr_cost = err.pow(2).mean().cpu().numpy()
+ update_epoch = 1
+        if use_multi_epoch and 0 <= multiepoch_thresh[0] <= multiepoch_thresh[1]:
+            if curr_cost < multiepoch_thresh[0]:
+                update_epoch = 1
+            elif curr_cost < multiepoch_thresh[1]:
+                update_epoch = 2
+            else:
+                update_epoch = 0
+ cnt[update_epoch] += 1
+ for cycle in range(update_epoch):
+            def mekf_closure(index=0):
+                # MEKF needs the Jacobian of each output component: every call
+                # back-propagates one element of Y_hat_tau and returns the
+                # innovation (prediction error) vector
+                optimizer.zero_grad()
+                dim_out = optimizer.optimizer.state['dim_out'] if 'Lookahead' in optim_name else optimizer.state['dim_out']
+                # keep the graph alive while more output dims (or further
+                # update cycles) still need to back-propagate through it
+                retain = (index < dim_out - 1) or (cycle < update_epoch - 1)
+                Y_hat_tau[index].backward(retain_graph=retain)
+                return err
+
+ def lbfgs_closure():
+ optimizer.zero_grad()
+ temp_data = temp_data_list[0]
+ _, temp_pred_traj, temp_y_traj, _, _, _ = get_prediction_on_batch(temp_data, model, device)
+ y_tau = temp_y_traj[:, :adapt_step].contiguous().view((-1, 1))
+ y_hat_tau = temp_pred_traj[:, :adapt_step].contiguous().view((-1, 1))
+ loss = (y_tau - y_hat_tau).pow(2).mean()
+ loss.backward()
+ return loss
+
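+            # MEKF and LBFGS step via closures; other optimizers (SGD/Adam,
+            # optionally wrapped in Lookahead) also reuse lbfgs_closure so
+            # every update epoch gets a fresh forward/backward pass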
+ if 'MEKF' in optim_name:
+ optimizer.step(mekf_closure)
+ elif 'LBFGS' in optim_name:
+ optimizer.step(lbfgs_closure)
+            else:
+                # the closure zero-grads and recomputes forward + backward on
+                # a fresh graph; back-propagating the cached Y_hat_tau again
+                # would fail on repeated update epochs
+                loss = lbfgs_closure()
+                optimizer.step()
+
+ temp_data = temp_data_list[0]
+ _, post_pred_traj, post_y_traj, _, _, _ = get_prediction_on_batch(temp_data, model, device)
+ post_loss = (post_pred_traj - post_y_traj).detach().pow(2).mean().cpu().numpy().round(6)
+ post_cost_list.append(post_loss)
+
+
+ if t % 10 == 0:
+            print('finished pred {}, time:{}, full cost before adapt:{}, after adapt:{}'.format(
+                t, time() - t1, full_loss, post_loss))
+ t1 = time()
+
+ print('avg_cost:', np.mean(cost_list))
+    print('update epoch counts [skipped, 1 epoch, 2 epochs]:', cnt)
+    return traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos
+
diff --git a/utils/pred_utils.py b/utils/pred_utils.py
new file mode 100644
index 0000000..4431caa
--- /dev/null
+++ b/utils/pred_utils.py
@@ -0,0 +1,46 @@
+
+import torch
+
+def get_prediction_on_batch(data, model, device='cpu'):
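+    """Move one batch to `device`, run the model and return inputs,
+    predictions and labels."""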
+    x, y_traj, y_intent, start_decode, pred_start_pos, x_mask = data
+ x = x.float().to(device)
+ y_traj = y_traj.float().to(device)
+ y_intent = y_intent.long().to(device)
+ start_decode = start_decode.float().to(device)
+ x_mask = x_mask.byte().to(device)
+ pred_start_pos = pred_start_pos.float().to(device)
+
+ pred_traj, pred_intent = model(src_seq=x, start_decode=start_decode, encoder_mask=x_mask)
+ return x, pred_traj, y_traj, pred_intent, y_intent, pred_start_pos
+
+def get_predictions(dataloader, model, device, data_size=-1):
+    traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos = None, None, None, None, None, None
+
+ for i, data in enumerate(dataloader, 0):
+ x, pred_traj, y_traj, pred_intent, y_intent, start_pos = get_prediction_on_batch(data, model, device)
+ if traj_hist is None:
+ traj_hist = x
+ traj_preds = pred_traj
+ traj_labels = y_traj
+ intent_preds = pred_intent
+ intent_labels = y_intent
+ pred_start_pos = start_pos
+ else:
+            traj_hist = torch.cat([traj_hist, x], dim=0)
+ traj_labels = torch.cat([traj_labels, y_traj], dim=0)
+ intent_labels = torch.cat([intent_labels, y_intent], dim=0)
+ pred_start_pos = torch.cat([pred_start_pos, start_pos], dim=0)
+ traj_preds = torch.cat([traj_preds, pred_traj], dim=0)
+ intent_preds = torch.cat([intent_preds, pred_intent], dim=0)
+        if data_size > 0 and traj_hist.size(0) > data_size:
+ break
+ return traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos
+
+def get_position(speed, start_pose=None, data_stats=None):
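+    """Convert speed predictions to absolute positions: undo normalization,
+    integrate over time (cumulative sum) and add the start position."""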
+ if data_stats is None or start_pose is None:
+ return speed
+ speed = speed * data_stats['data_std'] + data_stats['data_mean']
+ displacement = torch.cumsum(speed, dim=1)
+ start_pose = torch.unsqueeze(start_pose, dim=1)
+ position = displacement + start_pose
+ return position
\ No newline at end of file
diff --git a/utils/train_utils.py b/utils/train_utils.py
new file mode 100644
index 0000000..339b7a7
--- /dev/null
+++ b/utils/train_utils.py
@@ -0,0 +1,36 @@
+import torch
+
+
+class CrossEntropyLoss(torch.nn.Module):
+ """ Cross entropy that accepts label smoothing"""
+
+ def __init__(self, class_num=12, label_smooth=0, size_average=True):
+ super(CrossEntropyLoss, self).__init__()
+ self.class_num = class_num
+ self.size_average = size_average
+ self.label_smooth = label_smooth
+
+ def forward(self, input, target):
+        logsoftmax = torch.nn.LogSoftmax(dim=1)
+ one_hot_target = torch.zeros(target.size()[0], self.class_num,device=target.device)
+ one_hot_target = one_hot_target.scatter_(1, target.unsqueeze(1), 1)
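+        # smoothing keeps (1 - eps) on the target class and puts eps on each
+        # other class individually (not divided by the number of classes)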
+ if self.label_smooth > 0:
+ one_hot_target = (1 - self.label_smooth) * one_hot_target + self.label_smooth * (1 - one_hot_target)
+ if self.size_average:
+ return torch.mean(torch.sum(-one_hot_target * logsoftmax(input), dim=1))
+ else:
+ return torch.sum(torch.sum(-one_hot_target * logsoftmax(input), dim=1))
+
+
+def get_lr_schedule(lr_schedule, params, optimizer):
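+    # 'multistep' decays lr by `lr_decay` at `lr_decay_epochs`; 'cyclic'
+    # oscillates between lr/5 and lr with period `period`; anything else
+    # keeps a constant learning rate (scheduler = None)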
+ if lr_schedule == 'multistep':
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=params['lr_decay_epochs'],
+ gamma=params['lr_decay'], )
+ elif lr_schedule == 'cyclic':
+ scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, params['lr'] / 5, params['lr'],
+ step_size_up=params['period'],
+ mode='triangular', gamma=1.0, )
+ else:
+ scheduler = None
+
+ return scheduler