Commit

update
Walleclipse committed Dec 29, 2019
1 parent eb8028e commit f456f28
Showing 18 changed files with 2,089 additions and 5 deletions.
15 changes: 10 additions & 5 deletions README.md
@@ -1,9 +1,14 @@
# MEKF_MAME
**M**odified **E**xtended **K**alman **F**ilter with generalized exponential **M**oving **A**verage and dynamic **M**ulti-**E**poch update strategy (MEKF_MAME)
# MEKF<sub>EMA-DME</sub>
**M**odified **E**xtended **K**alman **F**ilter with generalized **E**xponential **M**oving **A**verage and **D**ynamic **M**ulti-**E**poch update strategy (MEKF<sub>EMA-DME</sub>)

Pytorch implementation source coder for paper "Robust Nonlinear Adaptation Algorithms for Multi-TaskPrediction Networks".
Pytorch implementation source coder for paper [Robust Online Model Adaptation by Extended Kalman Filter with Exponential Moving Average and Dynamic Multi-Epoch Strategy](https://arxiv.org/abs/1912.01790).

**We will release the code, once the paper is published.**

In the paper, an EKF-based adaptation algorithm, MEKF_λ, was introduced as an effective base algorithm for online adaptation. In order to improve the convergence property of MEKF_λ, generalized exponential moving average filtering was investigated. The paper then introduced a dynamic multi-epoch update strategy, which is compatible with any optimizer. By combining all extensions with the base MEKF_λ algorithm, the robust online adaptation algorithm MEKF_MAME was created.
Inspired by the Extended Kalman Filter (EKF), a base adaptation algorithm, Modified EKF with forgetting
factor (MEKF<sub>λ</sub>), is introduced first. To improve its convergence rate, this paper applies
generalized exponential moving average (EMA) filtering to the base MEKF<sub>λ</sub>.
Then, in order to effectively utilize the samples in online adaptation, this paper proposes a dynamic
multi-epoch update strategy that discriminates "hard" samples from "easy" samples and sets different
weights for them. With all these extensions, we propose a robust online adaptation algorithm:
MEKF with Exponential Moving Average and Dynamic Multi-Epoch strategy (MEKF<sub>EMA-DME</sub>).
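
To make the strategy concrete, here is a minimal, illustrative sketch of the dynamic multi-epoch idea: samples whose online loss is above a threshold are treated as "hard" and receive extra update epochs. The function name, `loss_fn`, `multiepoch_thresh`, and `max_epochs` below are assumptions for illustration, not the exact implementation in `utils/adapt_utils.py`.

```python
# Illustrative sketch only; not the repository's exact adaptation loop.
def dynamic_multi_epoch_step(model, optimizer, loss_fn, x, y,
                             multiepoch_thresh=0.1, max_epochs=3):
    """Update once on every sample; give extra epochs only to 'hard' samples."""
    for epoch in range(max_epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        if epoch == 0 and loss.item() <= multiepoch_thresh:
            # 'easy' sample: a single update is enough
            break
    return loss.item()
```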

150 changes: 150 additions & 0 deletions adapt.py
@@ -0,0 +1,150 @@
# coding=utf-8

import os
import warnings

import joblib
import torch
import numpy as np
from dataset.dataset import get_data_loader
from adaptation.lookahead import Lookahead
from adaptation.mekf import MEKF_MA
from parameters import hyper_parameters, adapt_hyper_parameters
from utils.adapt_utils import online_adaptation
from utils.pred_utils import get_predictions,get_position

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
warnings.filterwarnings("ignore")


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")  # note: this overrides the line above and forces CPU
print('testing with device:', device)

rnn_layer_name = ['encoder.rnn.weight_ih_l0', 'encoder.rnn.bias_ih_l0',
                  'encoder.rnn.weight_hh_l0', 'encoder.rnn.bias_hh_l0',
                  'decoder.rnn.weight_ih_l0', 'decoder.rnn.bias_ih_l0',
                  'decoder.rnn.weight_hh_l0', 'decoder.rnn.bias_hh_l0',
                  'decoder.output_projection.weight', 'decoder.output_projection.bias']
fc_layer_name = ['encoder.layers.0.weight', 'encoder.layers.0.bias',
                 'encoder.layers.3.weight', 'encoder.layers.3.bias',
                 'decoder.layers.0.weight', 'decoder.layers.0.bias',
                 'decoder.layers.3.weight', 'decoder.layers.3.bias',
                 'decoder.output_projection.weight', 'decoder.output_projection.bias']


def adaptable_prediction(data_loader, model, train_params, device, adaptor, adapt_step=1):
    '''adaptation hyper param'''
    adapt_params = adapt_hyper_parameters(adaptor=adaptor, adapt_step=adapt_step, log_dir=train_params['log_dir'])
    adapt_params._save_parameters()
    adapt_params.print_params()

    # only the decoder output projection layer (last two parameter names) is adapted online
    adapt_weights = []
    if train_params['encoder'] == 'rnn':
        adapt_layers = rnn_layer_name[8:]
    else:
        adapt_layers = fc_layer_name[8:]
    print('adapt_weights:')
    print(adapt_layers)
    for name, p in model.named_parameters():
        if name in adapt_layers:
            adapt_weights.append(p)
            print(name, p.size())

    optim_param = adapt_params.adapt_param()
    if adaptor == 'mekf' or adaptor == 'mekf_ma':
        optimizer = MEKF_MA(adapt_weights, dim_out=adapt_step * train_params['coordinate_dim'],
                            p0=optim_param['p0'], lbd=optim_param['lbd'], sigma_r=optim_param['sigma_r'],
                            sigma_q=optim_param['sigma_q'], lr=optim_param['lr'],
                            miu_v=optim_param['miu_v'], miu_p=optim_param['miu_p'],
                            k_p=optim_param['k_p'])
    elif adaptor == 'sgd':
        optimizer = torch.optim.SGD(adapt_weights, lr=optim_param['lr'], momentum=optim_param['momentum'],
                                    nesterov=optim_param['nesterov'])
    elif adaptor == 'adam':
        optimizer = torch.optim.Adam(adapt_weights, lr=optim_param['lr'], betas=optim_param['betas'],
                                     amsgrad=optim_param['amsgrad'])
    elif adaptor == 'lbfgs':
        optimizer = torch.optim.LBFGS(adapt_weights, lr=optim_param['lr'], max_iter=optim_param['max_iter'],
                                      history_size=optim_param['history_size'])
    else:
        raise NotImplementedError
    print('base optimizer configs:', optimizer.defaults)
    if optim_param['use_lookahead']:
        optimizer = Lookahead(optimizer, k=optim_param['la_k'], alpha=optim_param['la_alpha'])

    st_param = adapt_params.strategy_param()
    pred_result = online_adaptation(data_loader, model, optimizer, train_params, device,
                                    adapt_step=adapt_step,
                                    use_multi_epoch=st_param['use_multi_epoch'],
                                    multiepoch_thresh=st_param['multiepoch_thresh'])

    return pred_result


def test(params, adaptor='none', adapt_step=1):
    train_params = params.train_param()
    train_params['data_mean'] = torch.tensor(train_params['data_stats']['speed_mean'],
                                             dtype=torch.float).unsqueeze(0).to(device)
    train_params['data_std'] = torch.tensor(train_params['data_stats']['speed_std'],
                                            dtype=torch.float).unsqueeze(0).to(device)
    data_stats = {'data_mean': train_params['data_mean'], 'data_std': train_params['data_std']}

    model = torch.load(train_params['init_model'])
    model = model.to(device)
    print('load model', train_params['init_model'])

    data_loader = get_data_loader(train_params, mode='test')
    print('begin to test')
    if adaptor == 'none':
        # plain evaluation without online adaptation
        with torch.no_grad():
            pred_result = get_predictions(data_loader, model, device)
    else:
        pred_result = adaptable_prediction(data_loader, model, train_params, device, adaptor, adapt_step)

    traj_hist, traj_preds, traj_labels, intent_preds, intent_labels, pred_start_pos = pred_result
    traj_preds = get_position(traj_preds, pred_start_pos, data_stats)
    traj_labels = get_position(traj_labels, pred_start_pos, data_stats)
    intent_preds_prob = intent_preds.detach().clone()
    _, intent_preds = intent_preds.max(1)

    result = {'traj_hist': traj_hist, 'traj_preds': traj_preds, 'traj_labels': traj_labels,
              'intent_preds': intent_preds, 'intent_preds_prob': intent_preds_prob,
              'intent_labels': intent_labels, 'pred_start_pos': pred_start_pos}

    for k, v in result.items():
        result[k] = v.cpu().detach().numpy()

    out_str = 'Evaluation Result: \n'

    num, time_step = result['traj_labels'].shape[:2]
    mse = np.power(result['traj_labels'] - result['traj_preds'], 2).sum() / (num * time_step)
    out_str += "trajectory_mse: %.4f, \n" % (mse)

    acc = (result['intent_labels'] == result['intent_preds']).sum() / len(result['intent_labels'])
    out_str += "action_acc: %.4f, \n" % (acc)

    print(out_str)
    save_path = train_params['log_dir'] + adaptor + str(adapt_step) + '_pred.pkl'
    joblib.dump(result, save_path)
    print('save result to', save_path)
    return result


def main(dataset='vehicle_ngsim', model_type='rnn', adaptor='mekf', adapt_step=1):
    save_dir = 'output/' + dataset + '/' + model_type + '/'
    model_path = save_dir + 'model_1.pkl'
    params = hyper_parameters()
    params._load_parameters(save_dir + 'log/')
    params.params_dict['train_param']['init_model'] = model_path
    params.print_params()
    test(params, adaptor=adaptor, adapt_step=adapt_step)


if __name__ == '__main__':
    main()
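
The saved prediction file can later be reloaded for offline analysis. This is a minimal sketch; the path is an assumption following the `log_dir + adaptor + str(adapt_step) + '_pred.pkl'` pattern used in `test()` with the default arguments of `main()`.

```python
import joblib
import numpy as np

# hypothetical path, assuming log_dir='output/vehicle_ngsim/rnn/log/', adaptor='mekf', adapt_step=1
result = joblib.load('output/vehicle_ngsim/rnn/log/mekf1_pred.pkl')

# recompute the trajectory MSE reported by test()
num, time_step = result['traj_labels'].shape[:2]
mse = np.power(result['traj_labels'] - result['traj_preds'], 2).sum() / (num * time_step)
print('trajectory_mse: %.4f' % mse)
```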

Empty file added adaptation/__init__.py
Empty file.
70 changes: 70 additions & 0 deletions adaptation/lookahead.py
@@ -0,0 +1,70 @@
from collections import defaultdict
from torch.optim import Optimizer
import torch


class Lookahead(Optimizer):
    def __init__(self, optimizer, k=5, alpha=0.5):
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        self.param_groups = self.optimizer.param_groups
        self.state = defaultdict(dict)
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group["counter"] = 0

    def update(self, group):
        # interpolate the slow weights toward the fast weights, then sync the fast weights
        for fast in group["params"]:
            param_state = self.state[fast]
            if "slow_param" not in param_state:
                param_state["slow_param"] = torch.zeros_like(fast.data)
                param_state["slow_param"].copy_(fast.data)
            slow = param_state["slow_param"]
            slow += (fast.data - slow) * self.alpha
            fast.data.copy_(slow)

    def update_lookahead(self):
        for group in self.param_groups:
            self.update(group)

    def step(self, closure=None):
        # one step of the inner (fast) optimizer; slow weights are updated every k steps
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            if group["counter"] == 0:
                self.update(group)
            group["counter"] += 1
            if group["counter"] >= self.k:
                group["counter"] = 0
        return loss

    def state_dict(self):
        fast_state_dict = self.optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict["state"]
        param_groups = fast_state_dict["param_groups"]
        return {
            "fast_state": fast_state,
            "slow_state": slow_state,
            "param_groups": param_groups,
        }

    def load_state_dict(self, state_dict):
        slow_state_dict = {
            "state": state_dict["slow_state"],
            "param_groups": state_dict["param_groups"],
        }
        fast_state_dict = {
            "state": state_dict["fast_state"],
            "param_groups": state_dict["param_groups"],
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.optimizer.load_state_dict(fast_state_dict)
        self.fast_state = self.optimizer.state

    def add_param_group(self, param_group):
        param_group["counter"] = 0
        self.optimizer.add_param_group(param_group)
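
A minimal usage sketch of this `Lookahead` wrapper around a base optimizer; the toy model, data, and hyper-parameters are placeholders for illustration. In this commit the wrapper is applied in `adapt.py` when `use_lookahead` is enabled.

```python
import torch
import torch.nn as nn
from adaptation.lookahead import Lookahead

model = nn.Linear(4, 2)                                   # toy model, for illustration only
base_opt = torch.optim.SGD(model.parameters(), lr=0.01)   # fast (inner) optimizer
opt = Lookahead(base_opt, k=5, alpha=0.5)                 # slow weights sync every k steps

x, y = torch.randn(8, 4), torch.randn(8, 2)
for _ in range(10):
    base_opt.zero_grad()          # clear gradients on the shared parameters
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()                    # inner step, plus lookahead interpolation every k steps
```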
