-
Notifications
You must be signed in to change notification settings - Fork 5
Add Cuthbert EnKF #209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Cuthbert EnKF #209
Changes from 5 commits
7d23ceb
d8796f3
ec6d5dc
285af32
6c6130f
27601d3
c57f745
782c4ed
4787fc8
7e3c3fc
8c74062
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
mattlevine22 marked this conversation as resolved.
|
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| import numpyro | ||
| import numpyro.distributions as dist | ||
| from cuthbert import filter as cuthbert_filter | ||
| from cuthbert.enkf import ensemble_kalman_filter | ||
| from cuthbert.gaussian import kalman, taylor | ||
| from cuthbert.smc import particle_filter | ||
| from cuthbertlib.resampling import ( | ||
|
|
@@ -17,13 +18,18 @@ | |
| from dynestyx.inference.filter_configs import ( | ||
| BaseFilterConfig, | ||
| EKFConfig, | ||
| EnKFConfig, | ||
| KFConfig, | ||
| PFConfig, | ||
| _config_to_record_kwargs, | ||
| ) | ||
| from dynestyx.inference.integrations.utils import particles_to_delta_mixtures | ||
| from dynestyx.inference.integrations.utils import ( | ||
| covariance_from_cholesky, | ||
| particles_to_delta_mixtures, | ||
| ) | ||
| from dynestyx.models import ( | ||
| DynamicalModel, | ||
| GaussianObservation, | ||
| LinearGaussianObservation, | ||
| LinearGaussianStateEvolution, | ||
| ) | ||
|
|
@@ -51,6 +57,13 @@ def _config_to_filter_kwargs(config: BaseFilterConfig) -> dict: | |
| kwargs["resampling_differential_method"] = ( | ||
| config.resampling_method.differential_method | ||
| ) | ||
| elif isinstance(config, EnKFConfig): | ||
| kwargs["n_particles"] = config.n_particles | ||
| kwargs["inflation"] = ( | ||
| config.inflation_delta if config.inflation_delta is not None else 0.0 | ||
| ) | ||
| if config.perturb_measurements is not None: | ||
| kwargs["perturbed_obs"] = config.perturb_measurements | ||
| return kwargs | ||
|
|
||
|
|
||
|
|
@@ -106,14 +119,21 @@ def compute_cuthbert_filter( | |
| "or run inside a NumPyro seeded context (e.g., with numpyro.handlers.seed)." | ||
| ) | ||
| filter_obj = _cuthbert_filter_pf(dynamics, filter_kwargs) | ||
| elif isinstance(filter_config, EnKFConfig): | ||
| if key is None: | ||
| raise ValueError( | ||
| "Ensemble Kalman filter requires a PRNG key: set 'crn_seed' in the filter config, " | ||
| "or run inside a NumPyro seeded context (e.g., with numpyro.handlers.seed)." | ||
| ) | ||
| filter_obj = _cuthbert_filter_enkf(dynamics, filter_kwargs) | ||
| elif isinstance(filter_config, KFConfig): | ||
| filter_obj = _cuthbert_filter_kalman(dynamics, filter_kwargs) | ||
| elif isinstance(filter_config, EKFConfig): | ||
| filter_obj = _cuthbert_filter_taylor_kf(dynamics, filter_kwargs) | ||
| else: | ||
| raise ValueError( | ||
| f"Unsupported cuthbert config: {type(filter_config).__name__}. " | ||
| "Expected KFConfig, EKFConfig, PFConfig." | ||
| "Expected KFConfig, EKFConfig, EnKFConfig, PFConfig." | ||
| ) | ||
|
|
||
| states = cuthbert_filter(filter_obj, cuthbert_inputs, parallel=False, key=key) | ||
|
|
@@ -133,7 +153,7 @@ def run_discrete_filter( | |
| ctrl_values=None, | ||
| **kwargs, | ||
| ) -> list[dist.Distribution]: | ||
| """Run discrete-time filter via cuthbert (Kalman, Taylor KF, particle filter). | ||
| """Run discrete-time filter via cuthbert (Kalman, Taylor KF, EnKF, PF). | ||
|
|
||
| Returns: | ||
| list[dist.Distribution]: Filtered state distributions at each obs time. | ||
|
|
@@ -158,17 +178,18 @@ def run_discrete_filter( | |
|
|
||
| if isinstance(filter_config, PFConfig): | ||
| _add_sites_pf(name, states, record_kwargs) | ||
| particles = states.particles | ||
| particles = states.particles[1:] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If so, it seems like this problem should be managed upstream, e.g. in Currently, the batched filter computations would not pass through the above changes (but does use
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And is there a test that would have failed before but now can pass? Wondering how we missed this...
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yup, added a test for this
Yeah, that seems fair.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
| if particles.ndim == 2: | ||
| particles = particles[..., None] | ||
| return particles_to_delta_mixtures(particles, states.log_weights) | ||
| return particles_to_delta_mixtures(particles, states.log_weights[1:]) | ||
| else: | ||
| _add_sites_taylor_kf(name, states, record_kwargs) | ||
| chol_t = jnp.transpose(states.chol_cov, (0, 2, 1)) | ||
| cov = jnp.matmul(states.chol_cov, chol_t) | ||
| _add_sites_gaussian_filter(name, states, record_kwargs) | ||
| mean = states.mean[1:] | ||
| chol_cov = states.chol_cov[1:] | ||
| cov = covariance_from_cholesky(chol_cov) | ||
| return [ | ||
| dist.MultivariateNormal(states.mean[i], covariance_matrix=cov[i]) | ||
| for i in range(states.mean.shape[0]) | ||
| dist.MultivariateNormal(mean[i], covariance_matrix=cov[i]) | ||
| for i in range(mean.shape[0]) | ||
| ] | ||
|
|
||
|
|
||
|
|
@@ -231,6 +252,111 @@ def log_potential(x_prev, x, mi: CuthbertInputs): | |
| return pf | ||
|
|
||
|
|
||
| def _cuthbert_filter_enkf(dynamics: DynamicalModel, filter_kwargs: dict | None = None): | ||
| if filter_kwargs is None: | ||
| filter_kwargs = {} | ||
|
|
||
| state_dim = dynamics.state_dim | ||
| obs_dim = dynamics.observation_dim | ||
|
|
||
| def init_sample(key, mi: CuthbertInputs): | ||
| return jnp.atleast_1d(jnp.asarray(dynamics.initial_condition.sample(key))) | ||
|
|
||
| def get_dynamics(mi: CuthbertInputs): | ||
| def dynamics_fn(x, key): | ||
| def _noop(key): | ||
| return x | ||
|
|
||
| def _evolve(key): | ||
| d = dynamics.state_evolution(x, mi.u_prev, mi.time_prev, mi.time) # type: ignore | ||
| return jnp.atleast_1d(jnp.asarray(d.sample(key))) # type: ignore | ||
|
|
||
| return jax.lax.cond(mi.is_first_step, _noop, _evolve, key) | ||
|
|
||
| return dynamics_fn | ||
|
|
||
| def get_observations(mi: CuthbertInputs): | ||
| obs_model = dynamics.observation_model | ||
| y = jnp.atleast_1d(jnp.asarray(mi.y)) | ||
|
|
||
| if isinstance(obs_model, LinearGaussianObservation): | ||
| H = jnp.asarray(obs_model.H) | ||
| chol_R = jnp.linalg.cholesky(jnp.atleast_2d(jnp.asarray(obs_model.R))) | ||
| bias = ( | ||
| jnp.zeros((obs_dim,), dtype=y.dtype) | ||
| if obs_model.bias is None | ||
| else jnp.atleast_1d(jnp.asarray(obs_model.bias)) | ||
| ) | ||
| D = None if obs_model.D is None else jnp.asarray(obs_model.D) | ||
|
|
||
| def observation_fn(x): | ||
| loc = H @ x + bias | ||
| if D is not None: | ||
| loc = loc + D @ jnp.atleast_1d(jnp.asarray(mi.u)) | ||
| return jnp.atleast_1d(jnp.asarray(loc)) | ||
|
|
||
| return observation_fn, chol_R, y | ||
| elif isinstance(obs_model, GaussianObservation): | ||
| chol_R = jnp.linalg.cholesky(jnp.atleast_2d(jnp.asarray(obs_model.R))) | ||
|
|
||
| def observation_fn(x): | ||
| return jnp.atleast_1d(jnp.asarray(obs_model.h(x, mi.u, mi.time))) | ||
|
|
||
| return observation_fn, chol_R, y | ||
| else: | ||
| probe_x = jnp.zeros((state_dim,), dtype=y.dtype) | ||
| probe_dist = obs_model(probe_x, mi.u, mi.time) | ||
|
|
||
| if isinstance(probe_dist, dist.MultivariateNormal): | ||
| chol_R = jnp.asarray(probe_dist.scale_tril) | ||
| elif isinstance(probe_dist, dist.Normal): | ||
| scale = jnp.atleast_1d(jnp.asarray(probe_dist.scale)) | ||
| if scale.size == 1 and obs_dim > 1: | ||
| scale = jnp.full((obs_dim,), scale[0]) | ||
| chol_R = jnp.diag(scale) | ||
| elif isinstance(probe_dist, dist.Independent) and isinstance( | ||
| probe_dist.base_dist, dist.Normal | ||
| ): | ||
| scale = jnp.atleast_1d(jnp.asarray(probe_dist.base_dist.scale)) | ||
| if scale.size == 1 and obs_dim > 1: | ||
| scale = jnp.full((obs_dim,), scale[0]) | ||
| chol_R = jnp.diag(scale) | ||
| else: | ||
| raise TypeError( | ||
| "cuthbert EnKF requires Gaussian observation distributions. " | ||
| "Expected LinearGaussianObservation, GaussianObservation, or a " | ||
| "callable returning Normal, Independent(Normal), or " | ||
| f"MultivariateNormal; got {type(probe_dist).__name__}." | ||
| ) | ||
|
|
||
| def observation_fn(x): | ||
| edist = obs_model(x, mi.u, mi.time) | ||
| if not ( | ||
| isinstance(edist, (dist.MultivariateNormal, dist.Normal)) | ||
| or ( | ||
| isinstance(edist, dist.Independent) | ||
| and isinstance(edist.base_dist, dist.Normal) | ||
| ) | ||
| ): | ||
| raise TypeError( | ||
| "cuthbert EnKF observation callable must keep returning " | ||
| "Gaussian distributions; got " | ||
| f"{type(edist).__name__}." | ||
| ) | ||
|
DanWaxman marked this conversation as resolved.
Outdated
|
||
| return jnp.atleast_1d(jnp.asarray(edist.mean)) | ||
|
|
||
| return observation_fn, chol_R, y | ||
|
|
||
| return ensemble_kalman_filter.build_filter( | ||
| init_sample=init_sample, # type: ignore | ||
| get_dynamics=get_dynamics, # type: ignore | ||
| get_observations=get_observations, # type: ignore | ||
| n_particles=int(filter_kwargs.get("n_particles", 30)), | ||
| inflation=float(filter_kwargs.get("inflation", 0.0)), | ||
| perturbed_obs=bool(filter_kwargs.get("perturbed_obs", True)), | ||
| ) | ||
|
|
||
|
|
||
| def _cuthbert_filter_kalman( | ||
| dynamics: DynamicalModel, filter_kwargs: dict | None = None | ||
| ): | ||
|
|
@@ -349,7 +475,7 @@ def dynamics_log_density(x_prev, x): | |
| ) | ||
| try: | ||
| x_lin = jnp.atleast_1d(jnp.asarray(dist_at_lin.mean)) # type: ignore | ||
| except Exception as exc: | ||
| except (AttributeError, NotImplementedError) as exc: | ||
| raise ValueError( | ||
| "dist_at_lin.mean is not available. Linearized Kalman filter requires a mean-able distribution." | ||
| ) from exc | ||
|
|
@@ -439,8 +565,10 @@ def _add_sites_pf( | |
| numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov) | ||
|
|
||
|
|
||
| def _add_sites_taylor_kf( | ||
| name: str, states: taylor.LinearizedKalmanFilterState, record_kwargs: dict | ||
| def _add_sites_gaussian_filter( | ||
| name: str, | ||
| states: taylor.LinearizedKalmanFilterState | ensemble_kalman_filter.EnKFState, | ||
| record_kwargs: dict, | ||
| ): | ||
| max_elems = record_kwargs["record_max_elems"] | ||
| # Strip the init entry (index 0) — cuthbert output is (T+1, ...), | ||
|
|
@@ -472,8 +600,7 @@ def _add_sites_taylor_kf( | |
| numpyro.deterministic(f"{name}_filtered_states_chol_cov", chol_cov) | ||
|
|
||
| if add_filtered_states_cov or add_filtered_states_cov_diag: | ||
| chol_t = jnp.transpose(chol_cov, (0, 2, 1)) | ||
| filtered_cov = jnp.matmul(chol_cov, chol_t) | ||
| filtered_cov = covariance_from_cholesky(chol_cov) | ||
|
|
||
| if add_filtered_states_cov: | ||
| numpyro.deterministic(f"{name}_filtered_states_cov", filtered_cov) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At 650 lines and supporting 4 different filters and their own utilities (and some shared), I think this file deserves some refactoring into multiple files.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed. I started this actually in #192. I think it makes sense to continue the refactoring there. |
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.