-
Notifications
You must be signed in to change notification settings - Fork 283
feat: Add Microcanonical Langevin Monte Carlo (MCLMC) kernel #2124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
juanitorduz
wants to merge
11
commits into
pyro-ppl:master
Choose a base branch
from
juanitorduz:mclmc
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
152af4c
draft of MCLMC
reubenharry a1e7e0b
feat: Add Microcanonical Langevin Monte Carlo (MCLMC) kernel
juanitorduz 59d7f41
coauthor
juanitorduz af9f001
add blcakjax to test
juanitorduz a7280f6
empty
juanitorduz 528741b
Merge branch 'master' into mclmc
juanitorduz f336db0
init no blackjax
juanitorduz 8d708d2
types
juanitorduz 621a95c
docstrings
juanitorduz d27da77
cleanup
juanitorduz e2d754c
tests
juanitorduz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,219 @@ | ||
| # Copyright Contributors to the Pyro project. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from collections import namedtuple | ||
|
|
||
| import jax | ||
|
|
||
| from numpyro.infer.mcmc import MCMCKernel | ||
| from numpyro.infer.util import initialize_model | ||
| from numpyro.util import identity | ||
|
|
||
| try: | ||
| import blackjax | ||
| from blackjax.mcmc.integrators import IntegratorState | ||
| from blackjax.util import pytree_size | ||
|
|
||
| _BLACKJAX_AVAILABLE = True | ||
| except ImportError: | ||
| _BLACKJAX_AVAILABLE = False | ||
| blackjax = None | ||
| IntegratorState = None | ||
| pytree_size = None | ||
|
|
||
| FullState = namedtuple( | ||
| "FullState", ["position", "momentum", "logdensity", "logdensity_grad", "rng_key"] | ||
| ) | ||
|
|
||
|
|
||
| class MCLMC(MCMCKernel): | ||
| """ | ||
| Microcanonical Langevin Monte Carlo (MCLMC) kernel. | ||
|
|
||
| MCLMC is a gradient-based MCMC algorithm that uses Hamiltonian dynamics | ||
| on an extended state space. It requires the `blackjax` package. | ||
|
|
||
| **References:** | ||
|
|
||
| 1. *Microcanonical Hamiltonian Monte Carlo*, | ||
| Jakob Robnik, G. Bruno De Luca, Eva Silverstein, Uroš Seljak | ||
| https://arxiv.org/abs/2212.08549 | ||
|
|
||
| .. note:: The model must have at least 2 latent dimensions for MCLMC to work | ||
| (this is a limitation of the blackjax implementation). | ||
|
|
||
| :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. | ||
| :param float desired_energy_var: Target energy variance for step size and | ||
| trajectory length tuning. Smaller values lead to more conservative | ||
| step sizes. Defaults to 5e-4. | ||
| :param bool diagonal_preconditioning: Whether to use diagonal preconditioning | ||
| for the mass matrix. Defaults to True. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| model=None, | ||
| desired_energy_var=5e-4, | ||
| diagonal_preconditioning=True, | ||
| ): | ||
| if not _BLACKJAX_AVAILABLE: | ||
| raise ImportError( | ||
| "MCLMC requires the 'blackjax' package. " | ||
| "Please install it with: pip install blackjax" | ||
| ) | ||
| if model is None: | ||
| raise ValueError("Model must be specified for MCLMC") | ||
| self._model = model | ||
| self._diagonal_preconditioning = diagonal_preconditioning | ||
| self._desired_energy_var = desired_energy_var | ||
| self._init_fn = None | ||
| self._sample_fn = None | ||
| self._postprocess_fn = None | ||
|
|
||
| @property | ||
| def model(self): | ||
| return self._model | ||
|
|
||
| @property | ||
| def sample_field(self): | ||
| return "position" | ||
|
|
||
| @property | ||
| def default_fields(self): | ||
| return (self.sample_field,) | ||
|
|
||
| def get_diagnostics_str(self, state): | ||
| """ | ||
| Return a diagnostics string for the progress bar. | ||
| """ | ||
| return "step_size={:.2e}, L={:.2e}".format( | ||
| self.adapt_state.step_size, self.adapt_state.L | ||
| ) | ||
|
|
||
| def postprocess_fn(self, args, kwargs): | ||
| """ | ||
| Get a function that transforms unconstrained values at sample sites to values | ||
| constrained to the site's support, in addition to returning deterministic | ||
| sites in the model. | ||
|
|
||
| :param args: Arguments to the model. | ||
| :param kwargs: Keyword arguments to the model. | ||
| """ | ||
| if self._postprocess_fn is None: | ||
| return identity | ||
| return self._postprocess_fn(*args, **kwargs) | ||
|
|
||
| def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): | ||
| """ | ||
| Initialize the MCLMC kernel. | ||
|
|
||
| :param rng_key: Random number generator key | ||
| :param num_warmup: Number of warmup steps | ||
| :param init_params: Initial parameters | ||
| :param model_args: Model arguments | ||
| :param model_kwargs: Model keyword arguments | ||
| :return: Initial state | ||
| """ | ||
|
|
||
| init_model_key, init_state_key, run_key, rng_key_tune = jax.random.split( | ||
| rng_key, 4 | ||
| ) | ||
|
|
||
| init_params, potential_fn_gen, postprocess_fn, _ = initialize_model( | ||
| init_model_key, | ||
| self._model, | ||
| model_args=model_args, | ||
| model_kwargs=model_kwargs, | ||
| dynamic_args=True, | ||
| ) | ||
| self._postprocess_fn = postprocess_fn | ||
|
|
||
| def logdensity_fn(position): | ||
| return -potential_fn_gen(*model_args, **model_kwargs)(position) | ||
|
|
||
| initial_position = init_params.z | ||
| self.logdensity_fn = logdensity_fn | ||
|
|
||
| sampler_state = blackjax.mcmc.mclmc.init( | ||
| position=initial_position, | ||
| logdensity_fn=self.logdensity_fn, | ||
| rng_key=init_state_key, | ||
| ) | ||
|
|
||
| def kernel(inverse_mass_matrix): | ||
| return blackjax.mcmc.mclmc.build_kernel( | ||
| logdensity_fn=logdensity_fn, | ||
| integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, | ||
| inverse_mass_matrix=inverse_mass_matrix, | ||
| ) | ||
|
|
||
| self.dim = pytree_size(initial_position) | ||
|
|
||
| # num_steps is a dummy param here (used for tuning fractions) | ||
| num_tuning_steps = 100 | ||
| ( | ||
| blackjax_state_after_tuning, | ||
| blackjax_mclmc_sampler_params, | ||
| _, | ||
| ) = blackjax.mclmc_find_L_and_step_size( | ||
| mclmc_kernel=kernel, | ||
| num_steps=num_tuning_steps, | ||
| state=sampler_state, | ||
| rng_key=rng_key_tune, | ||
| diagonal_preconditioning=self._diagonal_preconditioning, | ||
| frac_tune3=num_warmup / (3 * num_tuning_steps), | ||
| frac_tune2=num_warmup / (3 * num_tuning_steps), | ||
| frac_tune1=num_warmup / (3 * num_tuning_steps), | ||
| desired_energy_var=self._desired_energy_var, | ||
| ) | ||
|
|
||
| self.adapt_state = blackjax_mclmc_sampler_params | ||
|
|
||
| return FullState( | ||
| blackjax_state_after_tuning.position, | ||
| blackjax_state_after_tuning.momentum, | ||
| blackjax_state_after_tuning.logdensity, | ||
| blackjax_state_after_tuning.logdensity_grad, | ||
| run_key, | ||
| ) | ||
|
|
||
| def sample(self, state, model_args, model_kwargs): | ||
| """ | ||
| Run MCLMC from the given state and return the resulting state. | ||
|
|
||
| :param state: Current state | ||
| :param model_args: Model arguments | ||
| :param model_kwargs: Model keyword arguments | ||
| :return: Next state after running MCLMC | ||
| """ | ||
|
|
||
| mclmc_state = IntegratorState( | ||
| state.position, state.momentum, state.logdensity, state.logdensity_grad | ||
| ) | ||
| rng_key, rng_key_sample = jax.random.split(state.rng_key, 2) | ||
|
|
||
| kernel = blackjax.mcmc.mclmc.build_kernel( | ||
| logdensity_fn=self.logdensity_fn, | ||
| integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, | ||
| inverse_mass_matrix=self.adapt_state.inverse_mass_matrix, | ||
| ) | ||
|
|
||
| new_state, info = kernel( | ||
| rng_key=rng_key_sample, | ||
| state=mclmc_state, | ||
| step_size=self.adapt_state.step_size, | ||
| L=self.adapt_state.L, | ||
| ) | ||
|
|
||
| return FullState( | ||
| new_state.position, | ||
| new_state.momentum, | ||
| new_state.logdensity, | ||
| new_state.logdensity_grad, | ||
| rng_key, | ||
| ) | ||
|
|
||
| def __getstate__(self): | ||
| state = self.__dict__.copy() | ||
| state["_postprocess_fn"] = None | ||
| return state | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| # Copyright Contributors to the Pyro project. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from numpy.testing import assert_allclose | ||
| import pytest | ||
|
|
||
| from jax import random | ||
| import jax.numpy as jnp | ||
|
|
||
| import numpyro | ||
| import numpyro.distributions as dist | ||
| from numpyro.infer import MCMC | ||
| from numpyro.infer.mclmc import MCLMC | ||
|
|
||
|
|
||
| def test_mclmc_model_required(): | ||
| """Test that ValueError is raised when model is None.""" | ||
| with pytest.raises(ValueError, match="Model must be specified"): | ||
| MCLMC(model=None) | ||
|
|
||
|
|
||
| def test_mclmc_blackjax_not_installed(monkeypatch): | ||
| """Test that ImportError is raised with informative message when blackjax is not installed.""" | ||
| import numpyro.infer.mclmc as mclmc_module | ||
|
|
||
| # Temporarily set _BLACKJAX_AVAILABLE to False | ||
| monkeypatch.setattr(mclmc_module, "_BLACKJAX_AVAILABLE", False) | ||
|
|
||
| def dummy_model(): | ||
| numpyro.sample("x", dist.Normal(0, 1)) | ||
|
|
||
| with pytest.raises(ImportError, match="MCLMC requires the 'blackjax' package"): | ||
| MCLMC(model=dummy_model) | ||
|
|
||
|
|
||
| def test_mclmc_normal(): | ||
| """Test MCLMC with a 2D normal distribution. | ||
|
|
||
| Note: MCLMC requires at least 2 dimensions (blackjax limitation). | ||
| """ | ||
| true_mean = jnp.array([1.0, 2.0]) | ||
| true_std = jnp.array([0.5, 1.0]) | ||
| num_warmup, num_samples = 1000, 2000 | ||
|
|
||
| def model(): | ||
| numpyro.sample("x", dist.Normal(true_mean, true_std).to_event(1)) | ||
|
|
||
| kernel = MCLMC(model=model) | ||
| mcmc = MCMC( | ||
| kernel, | ||
| num_warmup=num_warmup, | ||
| num_samples=num_samples, | ||
| num_chains=1, | ||
| progress_bar=False, | ||
| ) | ||
| mcmc.run(random.PRNGKey(0)) | ||
| samples = mcmc.get_samples() | ||
|
|
||
| assert "x" in samples | ||
| assert samples["x"].shape == (num_samples, 2) | ||
| assert_allclose(jnp.mean(samples["x"], axis=0), true_mean, atol=0.1) | ||
| assert_allclose(jnp.std(samples["x"], axis=0), true_std, atol=0.2) | ||
|
|
||
|
|
||
| def test_mclmc_gaussian_2d(): | ||
| """Test MCLMC with a 2D Gaussian model with observation.""" | ||
| num_warmup, num_samples = 1000, 1000 | ||
|
|
||
| def model(): | ||
| x = numpyro.sample("x", dist.Normal(0.0, 1.0)) | ||
| y = numpyro.sample("y", dist.Normal(0.0, 1.0)) | ||
| numpyro.sample("obs", dist.Normal(x + y, 0.5), obs=jnp.array(0.0)) | ||
|
|
||
| kernel = MCLMC( | ||
| model=model, | ||
| diagonal_preconditioning=True, | ||
| desired_energy_var=5e-4, | ||
| ) | ||
| mcmc = MCMC( | ||
| kernel, | ||
| num_warmup=num_warmup, | ||
| num_samples=num_samples, | ||
| num_chains=1, | ||
| progress_bar=False, | ||
| ) | ||
| mcmc.run(random.PRNGKey(0)) | ||
| samples = mcmc.get_samples() | ||
|
|
||
| assert "x" in samples | ||
| assert "y" in samples | ||
| assert samples["x"].shape == (num_samples,) | ||
| assert samples["y"].shape == (num_samples,) | ||
| # With obs=0, x+y should be close to 0, so means should be near 0 | ||
| assert_allclose(jnp.mean(samples["x"]) + jnp.mean(samples["y"]), 0.0, atol=0.2) | ||
|
|
||
|
|
||
| def test_mclmc_logistic_regression(): | ||
| """Test MCLMC with a logistic regression model. | ||
|
|
||
| Note: MCLMC currently doesn't pass model_args, so we use a closure pattern. | ||
| """ | ||
| N, dim = 1000, 3 | ||
| num_warmup, num_samples = 1000, 2000 | ||
|
|
||
| key1, key2, key3 = random.split(random.PRNGKey(0), 3) | ||
| data = random.normal(key1, (N, dim)) | ||
| true_coefs = jnp.arange(1.0, dim + 1.0) | ||
| logits = jnp.sum(true_coefs * data, axis=-1) | ||
| labels = dist.Bernoulli(logits=logits).sample(key2) | ||
|
|
||
| # Use closure pattern since MCLMC doesn't pass model_args | ||
| def model(): | ||
| coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) | ||
| logits = jnp.sum(coefs * data, axis=-1) | ||
| numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) | ||
|
|
||
| kernel = MCLMC(model=model) | ||
| mcmc = MCMC( | ||
| kernel, | ||
| num_warmup=num_warmup, | ||
| num_samples=num_samples, | ||
| num_chains=1, | ||
| progress_bar=False, | ||
| ) | ||
| mcmc.run(key3) | ||
| samples = mcmc.get_samples() | ||
|
|
||
| assert "coefs" in samples | ||
| assert samples["coefs"].shape == (num_samples, dim) | ||
| assert_allclose(jnp.mean(samples["coefs"], 0), true_coefs, atol=0.5) | ||
|
|
||
|
|
||
| def test_mclmc_sample_shape(): | ||
| """Test that MCLMC produces samples with expected shapes.""" | ||
| num_warmup, num_samples = 500, 500 | ||
|
|
||
| def model(): | ||
| numpyro.sample("a", dist.Normal(0, 1)) | ||
| numpyro.sample("b", dist.Normal(0, 1).expand([3])) | ||
| numpyro.sample("c", dist.Normal(0, 1).expand([2, 4])) | ||
|
|
||
| kernel = MCLMC(model=model) | ||
| mcmc = MCMC( | ||
| kernel, | ||
| num_warmup=num_warmup, | ||
| num_samples=num_samples, | ||
| num_chains=1, | ||
| progress_bar=False, | ||
| ) | ||
| mcmc.run(random.PRNGKey(0)) | ||
| samples = mcmc.get_samples() | ||
|
|
||
| assert samples["a"].shape == (num_samples,) | ||
| assert samples["b"].shape == (num_samples, 3) | ||
| assert samples["c"].shape == (num_samples, 2, 4) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
presumably many of the private methods in this file have corresponding tests in blackjax that you can pull into this PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added in e2d754c