Merged
Changes from 5 commits
6 changes: 3 additions & 3 deletions docs/api_reference/public/inference/filter_configs.md
Expand Up @@ -7,9 +7,9 @@ The single `Filter()` handler is directed to the appropriate filtering algorithm
| Config class | Time domain | When it fits best |
|----------------------------|---------------------|-------------------|
| `KFConfig` | Discrete | Linear-Gaussian dynamics and linear-Gaussian observations (exact & optimal). |
| `EKFConfig` | Discrete | Nonlinear, differentiable Gaussian dynamics, nonlinear but differentiable Gaussian observations (approximate). *(default)*. |
| `EnKFConfig` | Discrete | Nonlinear or expensive models with Gaussian observations; cuthbert-backed and a good general-purpose default. *(default)* |
| `EKFConfig` | Discrete | Nonlinear, differentiable Gaussian dynamics, nonlinear but differentiable Gaussian observations (approximate). |
| `UKFConfig` | Discrete | Nonlinear, differentiable Gaussian dynamics, nonlinear but differentiable Gaussian observations (approximate). Generally more accurate, but slower than `EKFConfig`. |
| `EnKFConfig` | Discrete | High-dimensional or expensive models with lower-dimensional structure and Gaussian observations (approximate). |
| `PFConfig`                 | Discrete            | Applicable to arbitrary state-space models, but expensive, with noisy estimates (asymptotically exact in the limit of infinite particles; approximate in practice). |
| `HMMConfig` | Discrete (HMM) | Finite discrete latent state space (exact & optimal). |
| `ContinuousTimeKFConfig` | Continuous-discrete | Linear-Gaussian SDE + linear-Gaussian observations (exact and optimal). |
Expand Down Expand Up @@ -48,4 +48,4 @@ The single `Filter()` handler is directed to the appropriate filtering algorithm
::: dynestyx.inference.filter_configs
options:
members:
- HMMConfig
- HMMConfig
74 changes: 43 additions & 31 deletions docs/deep_dives/discrete_time_lti_profile_likelihood.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/faq.md
Expand Up @@ -24,13 +24,13 @@ with Filter(filter_config=HMMConfig()):
return model(obs_times=obs_times, obs_values=obs_values)
```

- **Discrete-time**: Either a **Simulator** (NUTS samples both parameters and latent states) or a **Filter** (pseudo-marginal MCMC—parameters only). Note: the usage of discrete-time filters is currently under active development (likely incorrect implementations).
- **Discrete-time**: Either a **Simulator** (NUTS samples both parameters and latent states) or a **Filter** (parameters only, with latent states marginalized by a filtering algorithm). `Filter()` defaults to the cuthbert-backed EnKF for Gaussian observation models. Use `PFConfig` when you need non-Gaussian observations or a fully particle-based approximation.
For an explicit representation of latent states (NUTS / SVI handle all parameter and latent-state inference) via the simulator approach (currently working reliably), do:
```python
with DiscreteTimeSimulator():
return model(obs_times=obs_times, obs_values=obs_values)
```
For filter-based marginalization (currently not working reliably), do:
For filter-based marginalization with the default EnKF, do:
```python
with Filter():
return model(obs_times=obs_times, obs_values=obs_values)
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Expand Up @@ -107,7 +107,7 @@ Other JAX-based libraries for dynamical systems:
- **[dynamax](https://github.com/probml/dynamax)** — Discrete-time state space models with linear/non-linear Kalman filters and Bayesian parameter estimation
- **[cd-dynamax](https://github.com/hd-UQ/cd_dynamax)** — Continuous-discrete state space models with EnKF, EKF, UKF, PF and Bayesian parameter estimation
- **[PFJax](https://pfjax.readthedocs.io/en/latest/)** — Nonlinear and non-Gaussian discrete-time models with particle filters and particle MCMC
- **[Cuthbert](https://state-space-models.github.io/cuthbert/)** — Discrete-time state space models with linear/non-linear Kalman (and Particle Filters) filters, options for associative scans.
- **[Cuthbert](https://state-space-models.github.io/cuthbert/)** — Discrete-time state space models with linear/non-linear Kalman, ensemble Kalman, and particle filters, plus options for associative scans.
- **[diffrax](https://docs.kidger.site/diffrax/)** - Numerical differential equation solvers.

Other probabilistic programming languages with support for dynamical systems:
Expand Down
14 changes: 6 additions & 8 deletions dynestyx/inference/filter_configs.py
Expand Up @@ -84,7 +84,8 @@ class BaseFilterConfig:
class EnKFConfig(BaseFilterConfig):
r"""Ensemble Kalman Filter (EnKF) for discrete-time models.

A good general-purpose filter for nonlinear models. Works with any
The **default filter** for discrete-time models. A good general-purpose
filter for nonlinear models with Gaussian observations. Works with any
differentiable or non-differentiable dynamics and scales well to moderate
state dimensions. Cheaper per-step than the particle filter, but assumes
observations are approximately Gaussian given the ensemble.
Expand All @@ -108,7 +109,7 @@ class EnKFConfig(BaseFilterConfig):
inflation_delta (float | None): Scale ensemble anomalies by
\(\sqrt{1 + \delta}\) before the update to prevent collapse.
`None` disables inflation.
filter_source (FilterSource): Backend. Defaults to `"cd_dynamax"`.
filter_source (FilterSource): Backend. Defaults to `"cuthbert"`.

??? note "Algorithm Reference"
The ensemble Kalman filter comprises ensemble members $x_t^{(i)}, i = 1, \ldots, N_{\text{particles}}$.
Expand Down Expand Up @@ -159,7 +160,7 @@ class EnKFConfig(BaseFilterConfig):
)
perturb_measurements: bool | None = None
inflation_delta: float | None = None
filter_source: CDDynamaxOnlyFilterSource = "cd_dynamax"
filter_source: CuthbertOnlyFilterSource = "cuthbert"


@dataclasses.dataclass
Expand Down Expand Up @@ -282,9 +283,6 @@ class EKFConfig(BaseFilterConfig):

This is exact (but wasteful) for linear-Gaussian models.

This is the **default discrete-time filter** when no `filter_config` is
passed to `Filter`.

Attributes:
filter_emission_order (FilterEmissionOrder): Linearisation order for
the observation function. `"first"` *(default)* is the standard
Expand Down Expand Up @@ -375,7 +373,7 @@ class KFConfig(BaseFilterConfig):
- For more details on the `cuthbert` implementation, see the [cuthbert documentation](https://state-space-models.github.io/cuthbert/cuthbert_api/gaussian/kalman/).
"""

filter_source: CDDynamaxOnlyFilterSource = "cd_dynamax"
filter_source: CuthbertOrCDDynamaxFilterSource = "cd_dynamax"


@dataclasses.dataclass
Expand Down Expand Up @@ -520,7 +518,7 @@ class ContinuousTimeEnKFConfig(EnKFConfig, ContinuousTimeConfig):
[Available Online](https://epubs.siam.org/doi/abs/10.1137/21M1434477).
"""

filter_source: CDDynamaxOnlyFilterSource = "cd_dynamax"
filter_source: CDDynamaxOnlyFilterSource = "cd_dynamax" # type: ignore[assignment]


@dataclasses.dataclass
Expand Down
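The multiplicative inflation described by `inflation_delta` above (scaling ensemble anomalies by \(\sqrt{1 + \delta}\) about the ensemble mean) can be sketched as follows. This is an illustrative implementation under the assumption of an `(n_particles, state_dim)` ensemble array, not the library's actual code:

```python
import jax.numpy as jnp

def inflate_anomalies(ensemble, delta):
    """Scale ensemble anomalies by sqrt(1 + delta) about the ensemble mean.

    The mean is unchanged; the sample covariance is multiplied by (1 + delta),
    which counteracts ensemble collapse before the EnKF update.
    """
    mean = ensemble.mean(axis=0)
    return mean + jnp.sqrt(1.0 + delta) * (ensemble - mean)
```

Setting `delta = 0.0` recovers the uninflated ensemble, matching the `None`-disables-inflation behavior documented above.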
12 changes: 6 additions & 6 deletions dynestyx/inference/filters.py
Expand Up @@ -47,6 +47,7 @@
)
from dynestyx.inference.integrations.utils import (
WeightedParticles,
covariance_from_cholesky,
particles_to_delta_mixtures,
)
from dynestyx.models import DynamicalModel
Expand Down Expand Up @@ -130,7 +131,7 @@ def _cuthbert_states_to_dists(
chol_cov = states.chol_cov[
(slice(None),) * len(plate_shapes) + (slice(1, None), ...)
]
cov = jnp.matmul(chol_cov, jnp.swapaxes(chol_cov, -1, -2))
cov = covariance_from_cholesky(chol_cov)
t_len = _time_len_from_array(mean, plate_shapes)
return [
numpyro.distributions.MultivariateNormal(
Expand Down Expand Up @@ -297,8 +298,7 @@ def _default_filter_config(dynamics: DynamicalModel):
if dynamics.continuous_time:
return ContinuousTimeEnKFConfig()

# default to particle filter in discrete time
return EKFConfig(filter_source="cuthbert")
return EnKFConfig()


@dataclasses.dataclass
Expand Down Expand Up @@ -342,7 +342,7 @@ class Filter(BaseLogFactorAdder):
If `filter_config=None`, defaults are:

- `ContinuousTimeEnKFConfig()` for continuous-time models, and
- `EKFConfig(filter_source="cuthbert")` for discrete-time models.
- `EnKFConfig()` for discrete-time models.

Notes:
- If your latent state is *discrete* (an HMM), you must use `HMMConfig`.
Expand Down Expand Up @@ -643,8 +643,8 @@ def _filter_discrete_time(
) -> list[numpyro.distributions.Distribution]:
"""Discrete-time marginal likelihood via cuthbert or cd-dynamax.

Filter type inferred from config class: KFConfig, EKFConfig, UKFConfig (cd-dynamax)
or EKFConfig (cuthbert), PFConfig (cuthbert).
Filter type inferred from config class: KFConfig, EKFConfig, UKFConfig
(cd-dynamax) or KFConfig, EKFConfig, EnKFConfig, PFConfig (cuthbert).

Args:
name: Name of the factor.
Expand Down
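The `covariance_from_cholesky` helper that replaces the inline `L Lᵀ` products above is presumably a thin batched wrapper around the expression it supersedes. A minimal sketch, assuming the last two axes hold the lower-triangular factor (this is an assumption about the utility, not its actual source):

```python
import jax.numpy as jnp

def covariance_from_cholesky(chol_cov):
    # Reconstruct the covariance Sigma = L @ L.T from a (possibly batched)
    # lower-triangular Cholesky factor; the last two axes are the matrix axes.
    return jnp.matmul(chol_cov, jnp.swapaxes(chol_cov, -1, -2))
```

Because the transpose swaps only the trailing axes, the same helper serves both the per-plate case in `_cuthbert_states_to_dists` and the `(T, d, d)` stacks in the recording sites.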
157 changes: 142 additions & 15 deletions dynestyx/inference/integrations/cuthbert/discrete.py
Expand Up @@ -5,6 +5,7 @@
import numpyro
import numpyro.distributions as dist
from cuthbert import filter as cuthbert_filter
from cuthbert.enkf import ensemble_kalman_filter
from cuthbert.gaussian import kalman, taylor
from cuthbert.smc import particle_filter
from cuthbertlib.resampling import (
Expand All @@ -17,13 +18,18 @@
from dynestyx.inference.filter_configs import (
BaseFilterConfig,
EKFConfig,
EnKFConfig,
KFConfig,
PFConfig,
_config_to_record_kwargs,
)
from dynestyx.inference.integrations.utils import particles_to_delta_mixtures
from dynestyx.inference.integrations.utils import (
covariance_from_cholesky,
particles_to_delta_mixtures,
)
from dynestyx.models import (
DynamicalModel,
GaussianObservation,
LinearGaussianObservation,
LinearGaussianStateEvolution,
)
Expand Down Expand Up @@ -51,6 +57,13 @@ def _config_to_filter_kwargs(config: BaseFilterConfig) -> dict:
kwargs["resampling_differential_method"] = (
config.resampling_method.differential_method
)
elif isinstance(config, EnKFConfig):
kwargs["n_particles"] = config.n_particles
kwargs["inflation"] = (
config.inflation_delta if config.inflation_delta is not None else 0.0
)
if config.perturb_measurements is not None:
kwargs["perturbed_obs"] = config.perturb_measurements
return kwargs


Expand Down Expand Up @@ -106,14 +119,21 @@ def compute_cuthbert_filter(
"or run inside a NumPyro seeded context (e.g., with numpyro.handlers.seed)."
)
filter_obj = _cuthbert_filter_pf(dynamics, filter_kwargs)
elif isinstance(filter_config, EnKFConfig):
if key is None:
raise ValueError(
"Ensemble Kalman filter requires a PRNG key: set 'crn_seed' in the filter config, "
"or run inside a NumPyro seeded context (e.g., with numpyro.handlers.seed)."
)
filter_obj = _cuthbert_filter_enkf(dynamics, filter_kwargs)
elif isinstance(filter_config, KFConfig):
filter_obj = _cuthbert_filter_kalman(dynamics, filter_kwargs)
elif isinstance(filter_config, EKFConfig):
filter_obj = _cuthbert_filter_taylor_kf(dynamics, filter_kwargs)
else:
raise ValueError(
f"Unsupported cuthbert config: {type(filter_config).__name__}. "
"Expected KFConfig, EKFConfig, PFConfig."
"Expected KFConfig, EKFConfig, EnKFConfig, PFConfig."
)

states = cuthbert_filter(filter_obj, cuthbert_inputs, parallel=False, key=key)
Expand All @@ -133,7 +153,7 @@ def run_discrete_filter(
ctrl_values=None,
**kwargs,
) -> list[dist.Distribution]:
"""Run discrete-time filter via cuthbert (Kalman, Taylor KF, particle filter).
"""Run discrete-time filter via cuthbert (Kalman, Taylor KF, EnKF, PF).

Returns:
list[dist.Distribution]: Filtered state distributions at each obs time.
Expand All @@ -158,17 +178,18 @@ def run_discrete_filter(

if isinstance(filter_config, PFConfig):
_add_sites_pf(name, states, record_kwargs)
particles = states.particles
particles = states.particles[1:]
Comment thread (resolved)
Collaborator: is this [1:] (and the subsequent slices) catching a previous off-by-one error??
Collaborator: If so, it seems like this problem should be managed upstream, e.g. in compute_cuthbert_filter or even cuthbert_filter. Currently, the batched filter computations would not pass through the above changes (but do use compute_cuthbert_filter).
Collaborator: And is there a test that would have failed before but now can pass? Wondering how we missed this...
Collaborator (Author): Yup, added a test for this: test_filters.py::test_cuthbert_filtered_distribution_shapes_match_observations. The error was mostly innocuous, which is why we didn't see it before. It only gets passed forward as filtered_dists, and it didn't affect the prediction behavior at all because of how scans work. Managing it upstream seems fair.
Collaborator (Author): Done!
if particles.ndim == 2:
particles = particles[..., None]
return particles_to_delta_mixtures(particles, states.log_weights)
return particles_to_delta_mixtures(particles, states.log_weights[1:])
else:
_add_sites_taylor_kf(name, states, record_kwargs)
chol_t = jnp.transpose(states.chol_cov, (0, 2, 1))
cov = jnp.matmul(states.chol_cov, chol_t)
_add_sites_gaussian_filter(name, states, record_kwargs)
mean = states.mean[1:]
chol_cov = states.chol_cov[1:]
cov = covariance_from_cholesky(chol_cov)
return [
dist.MultivariateNormal(states.mean[i], covariance_matrix=cov[i])
for i in range(states.mean.shape[0])
dist.MultivariateNormal(mean[i], covariance_matrix=cov[i])
for i in range(mean.shape[0])
]
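The `[1:]` slices above reflect cuthbert's output convention: the filter returns `(T + 1, …)` arrays whose index 0 holds the initial state, so producing one filtered distribution per observation requires dropping that entry. A minimal sketch of the shape bookkeeping (array names here are illustrative, not the library's):

```python
import jax.numpy as jnp

T, d = 5, 3
obs = jnp.zeros((T, d))
# cuthbert-style filter output: initial state at index 0, then one
# filtered state per observation, giving shape (T + 1, d).
means = jnp.zeros((T + 1, d))
filtered_means = means[1:]  # drop the init entry: one entry per observation
assert filtered_means.shape == obs.shape
```

Forgetting the slice leaves a length mismatch between filtered distributions and observations, which is exactly what the new shape test guards against.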


Expand Down Expand Up @@ -231,6 +252,111 @@ def log_potential(x_prev, x, mi: CuthbertInputs):
return pf


def _cuthbert_filter_enkf(dynamics: DynamicalModel, filter_kwargs: dict | None = None):
if filter_kwargs is None:
filter_kwargs = {}

state_dim = dynamics.state_dim
obs_dim = dynamics.observation_dim

def init_sample(key, mi: CuthbertInputs):
return jnp.atleast_1d(jnp.asarray(dynamics.initial_condition.sample(key)))

def get_dynamics(mi: CuthbertInputs):
def dynamics_fn(x, key):
def _noop(key):
return x

def _evolve(key):
d = dynamics.state_evolution(x, mi.u_prev, mi.time_prev, mi.time) # type: ignore
return jnp.atleast_1d(jnp.asarray(d.sample(key))) # type: ignore

return jax.lax.cond(mi.is_first_step, _noop, _evolve, key)

return dynamics_fn

def get_observations(mi: CuthbertInputs):
obs_model = dynamics.observation_model
y = jnp.atleast_1d(jnp.asarray(mi.y))

if isinstance(obs_model, LinearGaussianObservation):
H = jnp.asarray(obs_model.H)
chol_R = jnp.linalg.cholesky(jnp.atleast_2d(jnp.asarray(obs_model.R)))
bias = (
jnp.zeros((obs_dim,), dtype=y.dtype)
if obs_model.bias is None
else jnp.atleast_1d(jnp.asarray(obs_model.bias))
)
D = None if obs_model.D is None else jnp.asarray(obs_model.D)

def observation_fn(x):
loc = H @ x + bias
if D is not None:
loc = loc + D @ jnp.atleast_1d(jnp.asarray(mi.u))
return jnp.atleast_1d(jnp.asarray(loc))

return observation_fn, chol_R, y
elif isinstance(obs_model, GaussianObservation):
chol_R = jnp.linalg.cholesky(jnp.atleast_2d(jnp.asarray(obs_model.R)))

def observation_fn(x):
return jnp.atleast_1d(jnp.asarray(obs_model.h(x, mi.u, mi.time)))

return observation_fn, chol_R, y
else:
probe_x = jnp.zeros((state_dim,), dtype=y.dtype)
probe_dist = obs_model(probe_x, mi.u, mi.time)

if isinstance(probe_dist, dist.MultivariateNormal):
chol_R = jnp.asarray(probe_dist.scale_tril)
elif isinstance(probe_dist, dist.Normal):
scale = jnp.atleast_1d(jnp.asarray(probe_dist.scale))
if scale.size == 1 and obs_dim > 1:
scale = jnp.full((obs_dim,), scale[0])
chol_R = jnp.diag(scale)
elif isinstance(probe_dist, dist.Independent) and isinstance(
probe_dist.base_dist, dist.Normal
):
scale = jnp.atleast_1d(jnp.asarray(probe_dist.base_dist.scale))
if scale.size == 1 and obs_dim > 1:
scale = jnp.full((obs_dim,), scale[0])
chol_R = jnp.diag(scale)
else:
raise TypeError(
"cuthbert EnKF requires Gaussian observation distributions. "
"Expected LinearGaussianObservation, GaussianObservation, or a "
"callable returning Normal, Independent(Normal), or "
f"MultivariateNormal; got {type(probe_dist).__name__}."
)

def observation_fn(x):
edist = obs_model(x, mi.u, mi.time)
if not (
isinstance(edist, (dist.MultivariateNormal, dist.Normal))
or (
isinstance(edist, dist.Independent)
and isinstance(edist.base_dist, dist.Normal)
)
):
raise TypeError(
"cuthbert EnKF observation callable must keep returning "
"Gaussian distributions; got "
f"{type(edist).__name__}."
)
return jnp.atleast_1d(jnp.asarray(edist.mean))

return observation_fn, chol_R, y

return ensemble_kalman_filter.build_filter(
init_sample=init_sample, # type: ignore
get_dynamics=get_dynamics, # type: ignore
get_observations=get_observations, # type: ignore
n_particles=int(filter_kwargs.get("n_particles", 30)),
inflation=float(filter_kwargs.get("inflation", 0.0)),
perturbed_obs=bool(filter_kwargs.get("perturbed_obs", True)),
)
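For intuition about what `ensemble_kalman_filter.build_filter` assembles from `get_dynamics` and `get_observations`, here is a self-contained sketch of one perturbed-observation EnKF analysis step. This is not cuthbert's implementation; the function and argument names are illustrative:

```python
import jax
import jax.numpy as jnp

def enkf_analysis_step(key, ensemble, y, observation_fn, chol_R):
    """One perturbed-observation EnKF update on an (n, state_dim) ensemble."""
    n = ensemble.shape[0]
    hx = jax.vmap(observation_fn)(ensemble)      # predicted observations
    x_anom = ensemble - ensemble.mean(axis=0)
    y_anom = hx - hx.mean(axis=0)
    # Sample cross-covariance and observation-space covariance
    pxy = x_anom.T @ y_anom / (n - 1)
    pyy = y_anom.T @ y_anom / (n - 1) + chol_R @ chol_R.T
    gain = pxy @ jnp.linalg.inv(pyy)             # ensemble Kalman gain
    # Each member assimilates its own perturbed copy of the observation
    noise = jax.random.normal(key, hx.shape) @ chol_R.T
    return ensemble + (y + noise - hx) @ gain.T
```

The Gaussian-observation requirement enforced by `get_observations` above stems from this structure: the update only uses the observation mean map and the Cholesky factor of `R`, so a non-Gaussian emission distribution has no place to enter.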


def _cuthbert_filter_kalman(
dynamics: DynamicalModel, filter_kwargs: dict | None = None
):
Expand Down Expand Up @@ -349,7 +475,7 @@ def dynamics_log_density(x_prev, x):
)
try:
x_lin = jnp.atleast_1d(jnp.asarray(dist_at_lin.mean)) # type: ignore
except Exception as exc:
except (AttributeError, NotImplementedError) as exc:
raise ValueError(
"dist_at_lin.mean is not available. Linearized Kalman filter requires a mean-able distribution."
) from exc
Expand Down Expand Up @@ -439,8 +565,10 @@ def _add_sites_pf(
numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov)


def _add_sites_taylor_kf(
name: str, states: taylor.LinearizedKalmanFilterState, record_kwargs: dict
def _add_sites_gaussian_filter(
name: str,
states: taylor.LinearizedKalmanFilterState | ensemble_kalman_filter.EnKFState,
record_kwargs: dict,
):
max_elems = record_kwargs["record_max_elems"]
# Strip the init entry (index 0) — cuthbert output is (T+1, ...),
Expand Down Expand Up @@ -472,8 +600,7 @@ def _add_sites_taylor_kf(
numpyro.deterministic(f"{name}_filtered_states_chol_cov", chol_cov)

if add_filtered_states_cov or add_filtered_states_cov_diag:
chol_t = jnp.transpose(chol_cov, (0, 2, 1))
filtered_cov = jnp.matmul(chol_cov, chol_t)
filtered_cov = covariance_from_cholesky(chol_cov)

if add_filtered_states_cov:
numpyro.deterministic(f"{name}_filtered_states_cov", filtered_cov)
Comment thread
Collaborator: At 650 lines and supporting 4 different filters and their own utilities (and some shared), I think this file deserves some refactoring into multiple files.
Collaborator (Author): Agreed. I started this actually in #192. I think it makes sense to continue the refactoring there.

Expand Down