Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
91a4e63
CD-Dynamax Smoothing
DanWaxman Apr 9, 2026
1cad067
Dynamax smoothers
DanWaxman Apr 9, 2026
780abab
Cuthbert smoothers
DanWaxman Apr 9, 2026
e1d0114
Smoother handler
DanWaxman Apr 9, 2026
7ccc016
Smoother Tests
DanWaxman Apr 9, 2026
d7a3784
Update documentation
DanWaxman Apr 9, 2026
4aae916
Update Smoother Config Formatting/Usage
DanWaxman Apr 9, 2026
7d5fab5
Small linting
DanWaxman Apr 9, 2026
346c7e2
Resolve linting issues
DanWaxman Apr 9, 2026
650cbf4
Merge branch 'main' into dw-smoothing
DanWaxman Apr 10, 2026
2042044
Merge remote-tracking branch 'origin/main' into dw-smoothing
DanWaxman May 8, 2026
7ae6d4b
Lint
DanWaxman May 8, 2026
c193c09
Add corrected posterior rollouts for smoothers
DanWaxman May 8, 2026
33ebd6e
Discrete-time smoothing notebook
DanWaxman May 8, 2026
15fb359
Update continuous-time smoothing notebook
DanWaxman May 8, 2026
b5e82b4
Update documentation, including smoother configs
DanWaxman May 8, 2026
4a13cab
Make FilterConfig an ABC
DanWaxman May 10, 2026
cdcafaa
Rename to BaseSmootherConfig; Move Smoother-Specific Properties There
DanWaxman May 10, 2026
22dfcec
Update dynestyx/inference/smoothers.py
DanWaxman May 10, 2026
1806039
Make simulator err if filtered_times and smoothed rollout are provided
DanWaxman May 10, 2026
8b16910
Move to shared dist and plate utilities
DanWaxman May 11, 2026
42f9b3a
Merge origin/main into dw-smoothing
DanWaxman May 11, 2026
832cf10
Cast and simplify type assertions
DanWaxman May 11, 2026
8aaf792
Cast and simplify discrete time type assertions
DanWaxman May 11, 2026
1071af3
Update exposition
DanWaxman May 12, 2026
2207712
Remove thin wrappers
DanWaxman May 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/api_reference/developer/inference/smoother_configs.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Smoother Configurations

Developer-facing API reference for smoother config classes and dispatch types.

Concrete smoother configs intentionally mirror the filter config hierarchy while
requiring users to opt into smoothing-specific classes. Backend support is
validated by `Smoother` before dispatch:

- discrete `KFSmootherConfig` and `EKFSmootherConfig`: `cuthbert` or `cd_dynamax`
- discrete `UKFSmootherConfig`: `cd_dynamax`
- discrete `PFSmootherConfig`: `cuthbert`
- continuous `ContinuousTimeKFSmootherConfig` and
`ContinuousTimeEKFSmootherConfig`: `cd_dynamax`

Smoother-specific fields live on the concrete classes rather than a nested
options object, which keeps handler dispatch and API docs aligned.

::: dynestyx.inference.smoother_configs
options:
filters: []
7 changes: 7 additions & 0 deletions docs/api_reference/developer/inference/smoothers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Smoothers

Developer-facing API reference for smoother handler internals and dispatch.

::: dynestyx.inference.smoothers
options:
filters: []
33 changes: 33 additions & 0 deletions docs/api_reference/public/inference/smoother_configs.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Smoother Configurations

`Smoother` is configured using explicit `*SmootherConfig` classes.

Use smoother configs instead of filter configs when entering a `Smoother`
handler. The classes inherit the familiar filtering options, plus smoother
recording fields such as `record_smoothed_states_mean`,
`record_smoothed_states_cov_diag`, `record_smoothed_particles`, and
`record_smoothed_log_weights`.

## Common Choices

```python
from dynestyx.inference.smoother_configs import (
ContinuousTimeKFSmootherConfig,
KFSmootherConfig,
PFSmootherConfig,
)

kf = KFSmootherConfig(filter_source="cd_dynamax")
pf = PFSmootherConfig(filter_source="cuthbert", n_particles=1_000)
ct_kf = ContinuousTimeKFSmootherConfig()
```

`PFSmootherConfig` exposes particle-smoother options:
`pf_backward_sampling_method`, `pf_mcmc_n_steps`, and
`pf_n_smoother_particles`. `ContinuousTimeKFSmootherConfig` exposes
`cdlgssm_smoother_type` for the CD-Dynamax continuous-discrete linear
Gaussian smoother variant.

::: dynestyx.inference.smoother_configs
options:
filters: []
54 changes: 54 additions & 0 deletions docs/api_reference/public/inference/smoothers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Smoothers

`Smoother` computes posterior smoothing distributions \(p(x_t \mid y_{1:T})\) and adds the corresponding marginal log-likelihood factor for parameter inference, mirroring `Filter` semantics.

## Usage

```python
from dynestyx import DiscreteTimeSimulator, Smoother
from dynestyx.inference.smoother_configs import KFSmootherConfig

with DiscreteTimeSimulator(n_simulations=4):
with Smoother(
smoother_config=KFSmootherConfig(
filter_source="cd_dynamax",
record_smoothed_states_mean=True,
)
):
samples = model(
obs_times=obs_times,
obs_values=obs_values,
predict_times=future_times,
)
```

`obs_times` and `obs_values` are required together. `Smoother` consumes them,
adds a marginal log-likelihood factor, and can record deterministic sites such
as `f_smoothed_states_mean`, `f_smoothed_states_cov`, and
`f_smoothed_states_cov_diag`.

## Prediction Semantics

For this release, smoother-backed prediction is intentionally future-only:
every `predict_time` must satisfy `predict_time >= max(obs_times)`. The
downstream simulator rolls out from the final smoothed state distribution.

Prediction times inside the smoothing window currently raise a clear error
instead of silently using incorrect indexing or backend-specific missing-data
behavior.

## Support Matrix

| Model class | Config | Backend |
| --- | --- | --- |
| Discrete linear-Gaussian | `KFSmootherConfig` | `cuthbert`, `cd_dynamax` |
| Discrete nonlinear Gaussian | `EKFSmootherConfig` | `cuthbert`, `cd_dynamax` |
| Discrete nonlinear Gaussian | `UKFSmootherConfig` | `cd_dynamax` |
| Discrete non-Gaussian/nonlinear | `PFSmootherConfig` | `cuthbert` |
| Continuous-discrete linear-Gaussian | `ContinuousTimeKFSmootherConfig` | `cd_dynamax` |
| Continuous-discrete nonlinear Gaussian | `ContinuousTimeEKFSmootherConfig` | `cd_dynamax` |

::: dynestyx.inference.smoothers
options:
members:
- Smoother
8 changes: 6 additions & 2 deletions docs/tutorials/gentle_intro/00_index.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@
"\n",
"7. **[Part 6b: Continuous-time dynamical systems (ODEs)](../06b_odes/)** - `ContinuousTimeStateEvolution` (drift and no diffusions); `ODESimulator`. Probabilistic numerics coming soon!\n",
"\n",
"7. **[Part 7: Hidden Markov Models (HMMs)](../07_hmm/)** — Working with categorical state spaces models (filtering and Bayesian parameter estimation).\n",
"8. **[Part 7: Hidden Markov Models (HMMs)](../07_hmm/)** — Working with categorical state spaces models (filtering and Bayesian parameter estimation).\n",
"\n",
"8. **[Part 8: Hierarchical / mixed-effect inference](../08_hierarchical_inference/)** — `plate` for trajectory-level parameters; hierarchical dynamical models and inference."
"9. **[Part 8: Hierarchical / mixed-effect inference](../08_hierarchical_inference/)** — `plate` for trajectory-level parameters; hierarchical dynamical models and inference.\n",
"\n",
"10. **[Part 9: Discrete-time smoothing](../09_discrete_smoothing/)** — `Smoother` for discrete-time models; smoothed state summaries and future-only posterior rollout.\n",
"\n",
"11. **[Part 10: Continuous-time smoothing](../10_continuous_smoothing/)** — Continuous-discrete smoothing; smoothed state summaries and future-only posterior rollout."
]
},
{
Expand Down
659 changes: 659 additions & 0 deletions docs/tutorials/gentle_intro/09_discrete_smoothing.ipynb

Large diffs are not rendered by default.

801 changes: 801 additions & 0 deletions docs/tutorials/gentle_intro/10_continuous_smoothing.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions dynestyx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dynestyx.discretizers import Discretizer, euler_maruyama
from dynestyx.handlers import plate, sample
from dynestyx.inference.filters import Filter
from dynestyx.inference.smoothers import Smoother
from dynestyx.models import (
ContinuousTimeStateEvolution,
DiracIdentityObservation,
Expand Down Expand Up @@ -41,6 +42,7 @@
"Discretizer",
"ObservationModel",
"Filter",
"Smoother",
"flatten_draws",
"plate",
"sample",
Expand Down
2 changes: 2 additions & 0 deletions dynestyx/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@
"integrations",
"mcmc",
"mcmc_configs",
"smoother_configs",
"smoothers",
]
134 changes: 134 additions & 0 deletions dynestyx/inference/distribution_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from typing import Literal

import jax
import jax.numpy as jnp
import numpyro.distributions as dist

from dynestyx.inference.integrations.utils import (
WeightedParticles,
covariance_from_cholesky,
)
from dynestyx.inference.plate_utils import _slice_time_axis, _time_len_from_array

MissingPolicy = Literal["raise", "empty"]


def _handle_missing_gaussian_sequence(
*,
missing: MissingPolicy,
missing_message: str | None,
) -> list[dist.Distribution]:
if missing == "empty":
return []
if missing == "raise":
raise ValueError(
missing_message or "Gaussian means/covariances were unavailable."
)
raise ValueError(f"Unknown missing Gaussian sequence policy: {missing!r}.")


def _gaussian_sequence_to_dists(
means: jax.Array | None,
covariances: jax.Array | None,
*,
plate_shapes: tuple[int, ...] = (),
missing: MissingPolicy = "raise",
missing_message: str | None = None,
) -> list[dist.Distribution]:
"""Convert time-indexed Gaussian parameters to per-time distributions."""
if means is None or covariances is None:
return _handle_missing_gaussian_sequence(
missing=missing,
missing_message=missing_message,
)

t_len = _time_len_from_array(means, plate_shapes)
return [
dist.MultivariateNormal(
_slice_time_axis(means, t, plate_shapes),
covariance_matrix=_slice_time_axis(covariances, t, plate_shapes),
)
for t in range(t_len)
]


def _particle_sequence_to_dists(
particles: jax.Array,
log_weights: jax.Array,
*,
plate_shapes: tuple[int, ...] = (),
) -> list[dist.Distribution]:
"""Convert time-indexed particle arrays to per-time weighted particles."""
if particles.ndim == len(plate_shapes) + 2:
particles = particles[..., None]

normalized_log_weights = jax.nn.log_softmax(log_weights, axis=-1)
t_len = _time_len_from_array(normalized_log_weights, plate_shapes)
return [
WeightedParticles(
particles=_slice_time_axis(particles, t, plate_shapes),
log_weights=_slice_time_axis(normalized_log_weights, t, plate_shapes),
)
for t in range(t_len)
]


def _posterior_sequence_to_dists(
posterior,
*,
means_attr: str,
covariances_attr: str,
particle_mode: bool,
plate_shapes: tuple[int, ...] = (),
missing: MissingPolicy = "raise",
missing_message: str | None = None,
) -> list[dist.Distribution]:
"""Convert a backend posterior object to per-time distributions."""
if particle_mode:
return _particle_sequence_to_dists(
posterior.particles,
posterior.log_weights,
plate_shapes=plate_shapes,
)

return _gaussian_sequence_to_dists(
getattr(posterior, means_attr),
getattr(posterior, covariances_attr),
plate_shapes=plate_shapes,
missing=missing,
missing_message=missing_message,
)


def _cholesky_state_sequence_to_dists(
states,
*,
particle_mode: bool,
plate_shapes: tuple[int, ...] = (),
) -> list[dist.Distribution]:
"""Convert cuthbert state objects to per-time distributions."""
if particle_mode:
return _particle_sequence_to_dists(
states.particles,
states.log_weights,
plate_shapes=plate_shapes,
)

return _gaussian_sequence_to_dists(
states.mean,
covariance_from_cholesky(states.chol_cov),
plate_shapes=plate_shapes,
)


def _categorical_log_probs_to_dists(
log_probs: jax.Array,
*,
plate_shapes: tuple[int, ...] = (),
) -> list[dist.Distribution]:
"""Convert time-indexed categorical log-probs to per-time distributions."""
t_len = _time_len_from_array(log_probs, plate_shapes)
return [
dist.Categorical(probs=jnp.exp(_slice_time_axis(log_probs, t, plate_shapes)))
for t in range(t_len)
]
3 changes: 2 additions & 1 deletion dynestyx/inference/filter_configs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Filter configuration dataclasses. Shared by dispatchers and integration backends."""

import abc
import dataclasses
import math
from typing import Literal
Expand All @@ -22,7 +23,7 @@


@dataclasses.dataclass
class BaseFilterConfig:
class BaseFilterConfig(abc.ABC):
r"""Shared configuration options inherited by all filter configs.

You do not instantiate this class directly; use one of the concrete
Expand Down
Loading
Loading