Skip to content

Commit

Permalink
utils folder restructured
Browse files Browse the repository at this point in the history
  • Loading branch information
syedshabbirahmed committed May 31, 2024
1 parent 1dce296 commit 0656eca
Show file tree
Hide file tree
Showing 12 changed files with 1,795 additions and 1,245 deletions.
8 changes: 4 additions & 4 deletions examples/ex_imm_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def covariance(self, x, u, dt) -> np.ndarray:
off_diag_p = 0.02
Pi = np.ones((n_models, n_models)) * off_diag_p
Pi = Pi + (1 - off_diag_p * (n_models)) * np.diag(np.ones(n_models))
imm = nav.imm.InteractingModelFilter(kf_list, Pi)
imm = nav.InteractingModelFilter(kf_list, Pi)


dg = nav.DataGenerator(
Expand All @@ -102,16 +102,16 @@ def imm_trial(trial_number: int) -> List[nav.GaussianResult]:

x0_check = x0.plus(nav.randvec(P0))

estimate_list = nav.imm.run_imm_filter(
estimate_list = nav.run_imm_filter(
imm, x0_check, P0, input_list, meas_list
)

results = [
nav.imm.IMMResult(estimate_list[i], state_true[i])
nav.IMMResult(estimate_list[i], state_true[i])
for i in range(len(estimate_list))
]

return nav.imm.IMMResultList(results)
return nav.IMMResultList(results)


def ekf_trial(trial_number: int) -> List[nav.GaussianResult]:
Expand Down
3 changes: 2 additions & 1 deletion examples/ex_imm_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import navlie as nav
from navlie.lib.models import DoubleIntegrator, RangePointToAnchor, VectorState
from navlie.imm import InteractingModelFilter, run_imm_filter, IMMResultList
from navlie.filters import InteractingModelFilter, run_imm_filter
from navlie.utils import IMMResultList
import numpy as np
from typing import List
from matplotlib import pyplot as plt
Expand Down
44 changes: 29 additions & 15 deletions navlie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,51 @@
UnscentedKalmanFilter,
CubatureKalmanFilter,
GaussHermiteKalmanFilter,
InteractingModelFilter,
run_filter,
run_imm_filter,
)
from . import batch
from . import imm
from . import lib
from . import utils
from .batch import BatchEstimator

from .datagen import DataGenerator, generate_measurement
from .utils import (

from .composite import (
CompositeState,
CompositeProcessModel,
CompositeMeasurementModel,
CompositeInput,
)

from .lib.states import StampedValue # for backwards compatibility

from .utils.common import (
state_interp,
GaussianResult,
GaussianResultList,
MonteCarloResult,
plot_error,
plot_meas,
plot_poses,
plot_nees,
IMMResult,
IMMResultList,
monte_carlo,
randvec,
van_loans,
state_interp,
schedule_sequential_measurements,
associate_stamps,
set_axes_equal,
find_nearest_stamp_idx,
randvec,
jacobian,
)

from .composite import (
CompositeState,
CompositeProcessModel,
CompositeMeasurementModel,
CompositeInput,
from .utils.plot import (
plot_error,
plot_nees,
plot_meas,
plot_meas_by_model,
plot_poses,
set_axes_equal
)

from .lib.states import StampedValue # for backwards compatibility
from .utils.mixture import (
gaussian_mixing,
)
239 changes: 239 additions & 0 deletions navlie/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
Measurement,
StateWithCovariance,
)
from navlie.lib.states import IMMState
from navlie.utils.mixture import gaussian_mixing
import numpy as np
from scipy.stats.distributions import chi2
from scipy.stats import multivariate_normal
from numpy.polynomial.hermite_e import hermeroots
from math import factorial
from scipy.special import eval_hermitenorm
Expand Down Expand Up @@ -817,6 +820,166 @@ def __init__(
)


class InteractingModelFilter:
"""
On-manifold Interacting Multiple Model Filter (IMM).
References for the IMM:
H. A. P. Blom and Y. Bar-Shalom, "The interacting
multiple model algorithm for systems with Markovian switching coefficients,"
in IEEE Transactions on Automatic Control, vol. 33, no. 8, pp. 780-783, Aug.
1988, doi: 10.1109/9.1299.
The IMM involves Gaussian mixtures. Reference for mixing Gaussians on
manifolds:
J. Ćesić, I. Marković and I. Petrović, "Mixture Reduction on Matrix Lie
Groups," in IEEE Signal Processing Letters, vol. 24, no. 11, pp.
1719-1723, Nov. 2017, doi: 10.1109/LSP.2017.2723765.
"""

def __init__(
self, kf_list: List[ExtendedKalmanFilter], transition_matrix: np.ndarray
):
"""
Initialize InteractingModelFilter.
Parameters
----------
kf_list : List[ExtendedKalmanFilter]
A list of filter instances which correspond to
each model of the IMM.
transition_matrix : np.ndarray
Probability transition matrix corresponding to the IMM models.
"""
self.kf_list = kf_list
self.transition_matrix = transition_matrix

def interaction(
self,
x: IMMState,
):
"""The interaction (mixing) step of the IMM.
Parameters
----------
x : IMMState
Returns
-------
IMMState
"""

x_km_models = x.model_states.copy()
mu_models = np.array(x.model_probabilities)

n_modes = self.transition_matrix.shape[0]
c = self.transition_matrix.T @ mu_models.reshape(-1, 1)

mu_mix = np.zeros((n_modes, n_modes))
for i in range(n_modes):
for j in range(n_modes):
mu_mix[i, j] = (
1.0 / c[j] * self.transition_matrix[i, j] * mu_models[i]
)
x_mix = []

for j in range(n_modes):
weights = list(mu_mix[:, j])
x_mix.append(gaussian_mixing(weights, x_km_models))

return IMMState(x_mix, mu_models)

def predict(self, x_km: IMMState, u: Input, dt: float):
"""
Carries out prediction step for each model of the IMM.
Parameters
----------
x_km : IMMState
Model estimates from previous timestep, after mixing.
u : Input
Input
dt : Float
Timestep
Returns
-------
IMMState
"""

x_km_models = x_km.model_states.copy()
x_check = []
for lv1, kf in enumerate(self.kf_list):
x_check.append(kf.predict(x_km_models[lv1], u, dt))
return IMMState(x_check, x_km.model_probabilities)

def correct(
self,
x_check: IMMState,
y: Measurement,
u: Input,
):
"""
Carry out the correction step for each model and update model
probabilities.
Parameters
----------
x_check: IMMState mu_km_models : List[Float]
Probabilities for each model from previous timestep.
y : Measurement
Measurement to be fused into the current state estimate.
u: Input
Most recent input, to be used to predict the state forward if the
measurement stamp is larger than the state stamp.
Returns
-------
IMMState
Corrected state estimates and probabilities
"""
x_models_check = x_check.model_states.copy()
mu_km_models = x_check.model_probabilities.copy()
n_modes = len(x_models_check)
mu_k = np.zeros(n_modes)

# Compute each model's normalization constant
c_bar = np.zeros(n_modes)
for i in range(n_modes):
for j in range(n_modes):
c_bar[j] = (
c_bar[j] + self.transition_matrix[i, j] * mu_km_models[i]
)

# Correct and update model probabilities
x_hat = []
for lv1, kf in enumerate(self.kf_list):
x, details_dict = kf.correct(
x_models_check[lv1], y, u, output_details=True
)
x_hat.append(x)
z = details_dict["z"]
S = details_dict["S"]
z = z.ravel()
model_likelihood = multivariate_normal.pdf(
z, mean=np.zeros(z.shape), cov=S
)
mu_k[lv1] = model_likelihood * c_bar[lv1]

# If all model likelihoods are zero to machine tolerance, np.sum(mu_k)=0 and it fails
# Add this fudge factor to get through those cases.
if np.allclose(mu_k, np.zeros(mu_k.shape)):
mu_k = 1e-10 * np.ones(mu_k.shape)

mu_k = mu_k / np.sum(mu_k)

return IMMState(x_hat, mu_k)


def run_filter(
filter: ExtendedKalmanFilter,
x0: State,
Expand Down Expand Up @@ -883,3 +1046,79 @@ def run_filter(
x = filter.predict(x, u, dt)

return results_list


def run_imm_filter(
filter: InteractingModelFilter,
x0: State,
P0: np.ndarray,
input_data: List[Input],
meas_data: List[Measurement],
) -> List[StateWithCovariance]:
"""
Executes an InteractingMultipleModel filter
Parameters
----------
filter: An InteractingModelFilter instance:
_description_
ProcessModel: Callable, must return a process model:
_description_
Q_profile: Callable, must return a square np.array compatible with process model:
_description_
x0 : State
_description_
P0 : np.ndarray
_description_
input_data : List[Input]
_description_
meas_data : List[Measurement]
_description_
"""
x = StateWithCovariance(x0, P0)
if x.state.stamp is None:
raise ValueError("x0 must have a valid timestamp.")

# Sort the data by time
input_data.sort(key=lambda x: x.stamp)
meas_data.sort(key=lambda x: x.stamp)

# Remove all that are before the current time
for idx, u in enumerate(input_data):
if u.stamp >= x.state.stamp:
input_data = input_data[idx:]
break

for idx, y in enumerate(meas_data):
if y.stamp >= x.state.stamp:
meas_data = meas_data[idx:]
break

meas_idx = 0
if len(meas_data) > 0:
y = meas_data[meas_idx]

results_list = []
n_models = filter.transition_matrix.shape[0]

x = IMMState(
[StateWithCovariance(x0, P0)] * n_models,
1.0 / n_models * np.array(np.ones(n_models)),
)
for k in tqdm(range(len(input_data) - 1)):
results_list.append(x)
u = input_data[k]
# Fuse any measurements that have occurred.
if len(meas_data) > 0:
while y.stamp < input_data[k + 1].stamp and meas_idx < len(
meas_data
):
x = filter.interaction(x)
x = filter.correct(x, y, u)
meas_idx += 1
if meas_idx < len(meas_data):
y = meas_data[meas_idx]
dt = input_data[k + 1].stamp - x.model_states[0].stamp
x = filter.predict(x, u, dt)

return results_list
Loading

0 comments on commit 0656eca

Please sign in to comment.