diff --git a/numpyro/_typing.py b/numpyro/_typing.py
index d53d60c17..4639fda43 100644
--- a/numpyro/_typing.py
+++ b/numpyro/_typing.py
@@ -35,4 +35,8 @@
 """A generic type for a pytree, i.e. a nested structure of lists, tuples, dicts, and arrays."""
+LogDensityFn: TypeAlias = Callable[[PyTree], NumLike]
+"""Callable log-density signature used by gradient-based kernels."""
+
+
 NumLikeT = TypeVar("NumLikeT", bound=NumLike)
diff --git a/numpyro/infer/mclmc.py b/numpyro/infer/mclmc.py
new file mode 100644
index 000000000..9cd0ad340
--- /dev/null
+++ b/numpyro/infer/mclmc.py
@@ -0,0 +1,859 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from collections.abc import Callable
+from typing import Any, NamedTuple, cast
+
+import jax
+from jax.flatten_util import ravel_pytree
+import jax.numpy as jnp
+from jax.typing import ArrayLike
+
+from numpyro._typing import LogDensityFn, NumLike, PyTree
+from numpyro.diagnostics import effective_sample_size
+from numpyro.infer.mcmc import MCMCKernel
+from numpyro.infer.util import initialize_model
+from numpyro.util import identity
+
+
+class MCLMCState(NamedTuple):
+    position: PyTree
+    momentum: PyTree
+    logdensity: NumLike
+    logdensity_grad: PyTree
+
+
+class MCLMCInfo(NamedTuple):
+    logdensity: NumLike
+    kinetic_change: NumLike
+    energy_change: NumLike
+
+
+class MCLMCAdaptationState(NamedTuple):
+    L: NumLike
+    step_size: NumLike
+    inverse_mass_matrix: ArrayLike
+
+
+class FullState(NamedTuple):
+    position: PyTree
+    momentum: PyTree
+    logdensity: NumLike
+    logdensity_grad: PyTree
+    rng_key: jax.dtypes.prng_key
+
+
+class _AdaptationAverages(NamedTuple):
+    time: NumLike
+    x_average: NumLike
+    step_size_max: NumLike
+
+
+class _StreamingAverage(NamedTuple):
+    total: NumLike
+    average: ArrayLike
+
+
+class _AdaptationIterationState(NamedTuple):
+    state: MCLMCState
+    params: MCLMCAdaptationState
+    adaptive_state: _AdaptationAverages
+    streaming_avg: _StreamingAverage
+
+
+KernelFn = Callable[
+    [jax.dtypes.prng_key, MCLMCState, NumLike, NumLike], tuple[MCLMCState, MCLMCInfo]
+]
+KernelFactoryFn = Callable[[ArrayLike], KernelFn]
+
+
+# First momentum-stage coefficient in the 5-stage McLachlan splitting scheme.
+_MCLACHLAN_B1: float = 0.1931833275037836
+# Palindromic integrator coefficients for one isokinetic McLachlan update.
+_MCLACHLAN_COEFS: tuple[float, ...] = (
+    _MCLACHLAN_B1,
+    0.5,
+    1 - 2 * _MCLACHLAN_B1,
+    0.5,
+    _MCLACHLAN_B1,
+)
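+# A quick consistency note (not enforced at import time): the momentum-stage
+# coefficients at even indices sum to b1 + (1 - 2 * b1) + b1 = 1, and the two
+# position-stage coefficients at odd indices sum to 0.5 + 0.5 = 1, so one pass
+# over _MCLACHLAN_COEFS advances both updates by exactly one full step.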
+# When NaNs are detected during adaptation, shrink step size by this factor.
+_DELTA_NAN_STEP_SIZE_FACTOR: float = 0.8
+
+
+def _pytree_size(pytree: PyTree) -> int:
+    return sum(jnp.size(leaf) for leaf in jax.tree.leaves(pytree))
+
+
+def _generate_unit_vector(rng_key: jax.dtypes.prng_key, position: PyTree) -> PyTree:
+    flat_position, unravel_fn = ravel_pytree(position)
+    sample = jax.random.normal(
+        rng_key, shape=flat_position.shape, dtype=flat_position.dtype
+    )
+    return unravel_fn(sample / jnp.linalg.norm(sample))
+
+
+def _incremental_value_update(
+    expectation: ArrayLike,
+    incremental_val: _StreamingAverage,
+    weight: NumLike = 1.0,
+    zero_prevention: NumLike = 0.0,
+) -> _StreamingAverage:
+    total, average = incremental_val
+    total_weight = total + weight + zero_prevention
+
+    def _update_average(exp: ArrayLike, av: ArrayLike) -> ArrayLike:
+        numerator = total * av + weight * exp
+        return jnp.where(numerator == 0.0, 0.0, numerator / total_weight)
+
+    average = jax.tree.map(
+        _update_average,
+        expectation,
+        average,
+    )
+    return _StreamingAverage(total + weight, average)
+
+
+def _init_mclmc(
+    position: PyTree, logdensity_fn: LogDensityFn, rng_key: jax.dtypes.prng_key
+) -> MCLMCState:
+    if _pytree_size(position) < 2:
+        raise ValueError(
+            "The target distribution must have more than 1 dimension for MCLMC."
+        )
+    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
+    return MCLMCState(
+        position=position,
+        momentum=_generate_unit_vector(rng_key, position),
+        logdensity=logdensity,
+        logdensity_grad=logdensity_grad,
+    )
+
+
+def _position_update(
+    position: PyTree,
+    kinetic_grad: PyTree,
+    step_size: NumLike,
+    coef: NumLike,
+    logdensity_fn: LogDensityFn,
+) -> tuple[PyTree, NumLike, PyTree]:
+    new_position = jax.tree.map(
+        lambda x, grad: x + step_size * coef * grad,
+        position,
+        kinetic_grad,
+    )
+    logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(new_position)
+    return new_position, logdensity, logdensity_grad
+
+
+def _normalized_flatten(x: ArrayLike, tol: float = 1e-13) -> tuple[ArrayLike, NumLike]:
+    norm = jnp.linalg.norm(x)
+    return jnp.where(norm > tol, x / norm, x), norm
+
+
+def _esh_dynamics_momentum_update_one_step(
+    momentum: PyTree,
+    logdensity_grad: PyTree,
+    step_size: NumLike,
+    coef: NumLike,
+    inverse_mass_matrix: ArrayLike,
+    previous_kinetic_energy_change: NumLike | None = None,
+) -> tuple[PyTree, PyTree, NumLike]:
+    sqrt_inverse_mass_matrix = jnp.sqrt(inverse_mass_matrix)
+    flatten_grads, unravel_fn = ravel_pytree(logdensity_grad)
+    flatten_grads = flatten_grads * sqrt_inverse_mass_matrix
+    flatten_momentum, _ = ravel_pytree(momentum)
+    dims = flatten_momentum.shape[0]
+
+    normalized_gradient, gradient_norm = _normalized_flatten(flatten_grads)
+    momentum_proj = jnp.dot(flatten_momentum, normalized_gradient)
+    delta = step_size * coef * gradient_norm / (dims - 1)
+    zeta = jnp.exp(-delta)
+    new_momentum_raw = (
+        normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta))
+        + 2 * zeta * flatten_momentum
+    )
+    new_momentum_normalized, _ = _normalized_flatten(new_momentum_raw)
+    next_momentum = unravel_fn(new_momentum_normalized)
+    kinetic_grad = unravel_fn(new_momentum_normalized * sqrt_inverse_mass_matrix)
+    kinetic_energy_change = (
+        delta
+        - jnp.log(2.0)
+        + jnp.log(1 + momentum_proj + (1 - momentum_proj) * zeta**2)
+    ) * (dims - 1)
+    if previous_kinetic_energy_change is not None:
+        kinetic_energy_change = kinetic_energy_change + previous_kinetic_energy_change
+    return next_momentum, kinetic_grad, kinetic_energy_change
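+
+
+# Note on the update above: the raw ESH momentum update is explicitly
+# renormalized, so the momentum always stays on the unit sphere. Kinetic
+# energy is therefore held fixed ("isokinetic"), and only the direction of
+# travel changes between position updates.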
+
+
+def _isokinetic_mclachlan_step(
+    state: MCLMCState,
+    step_size: NumLike,
+    logdensity_fn: LogDensityFn,
+    inverse_mass_matrix: ArrayLike,
+) -> tuple[MCLMCState, NumLike]:
+    position, momentum, _, logdensity_grad = state
+    kinetic_energy_change = None
+
+    for i, coef in enumerate(_MCLACHLAN_COEFS[:-1]):
+        if i % 2 == 0:
+            momentum, kinetic_grad, kinetic_energy_change = (
+                _esh_dynamics_momentum_update_one_step(
+                    momentum=momentum,
+                    logdensity_grad=logdensity_grad,
+                    step_size=step_size,
+                    coef=coef,
+                    inverse_mass_matrix=inverse_mass_matrix,
+                    previous_kinetic_energy_change=kinetic_energy_change,
+                )
+            )
+        else:
+            position, logdensity, logdensity_grad = _position_update(
+                position=position,
+                kinetic_grad=kinetic_grad,
+                step_size=step_size,
+                coef=coef,
+                logdensity_fn=logdensity_fn,
+            )
+
+    momentum, _, kinetic_energy_change = _esh_dynamics_momentum_update_one_step(
+        momentum=momentum,
+        logdensity_grad=logdensity_grad,
+        step_size=step_size,
+        coef=_MCLACHLAN_COEFS[-1],
+        inverse_mass_matrix=inverse_mass_matrix,
+        previous_kinetic_energy_change=kinetic_energy_change,
+    )
+    return MCLMCState(
+        position, momentum, logdensity, logdensity_grad
+    ), kinetic_energy_change
+
+
+def _partially_refresh_momentum(
+    momentum: PyTree, rng_key: jax.dtypes.prng_key, step_size: NumLike, L: NumLike
+) -> PyTree:
+    flat_momentum, unravel_fn = ravel_pytree(momentum)
+    dim = flat_momentum.shape[0]
+    nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim)
+    z = nu * jax.random.normal(
+        rng_key, shape=flat_momentum.shape, dtype=flat_momentum.dtype
+    )
+    new_momentum = unravel_fn((flat_momentum + z) / jnp.linalg.norm(flat_momentum + z))
+    return jax.lax.cond(
+        jnp.isinf(L), lambda _: momentum, lambda _: new_momentum, operand=None
+    )
+
+
+def _maruyama_step(
+    init_state: MCLMCState,
+    step_size: NumLike,
+    L: NumLike,
+    rng_key: jax.dtypes.prng_key,
+    logdensity_fn: LogDensityFn,
+    inverse_mass_matrix: ArrayLike,
+) -> tuple[MCLMCState, NumLike]:
+    key1, key2 = jax.random.split(rng_key)
+    state = init_state._replace(
+        momentum=_partially_refresh_momentum(
+            momentum=init_state.momentum,
+            rng_key=key1,
+            L=L,
+            step_size=step_size * 0.5,
+        )
+    )
+    state, kinetic_change = _isokinetic_mclachlan_step(
+        state=state,
+        step_size=step_size,
+        logdensity_fn=logdensity_fn,
+        inverse_mass_matrix=inverse_mass_matrix,
+    )
+    state = state._replace(
+        momentum=_partially_refresh_momentum(
+            momentum=state.momentum,
+            rng_key=key2,
+            L=L,
+            step_size=step_size * 0.5,
+        )
+    )
+    return state, kinetic_change
+
+
+def _state_is_finite(state: MCLMCState) -> NumLike:
+    flat_position, _ = ravel_pytree(state.position)
+    flat_momentum, _ = ravel_pytree(state.momentum)
+    return jnp.logical_and(
+        jnp.all(jnp.isfinite(flat_position)), jnp.all(jnp.isfinite(flat_momentum))
+    )
+
+
+def _fallback_state_with_fresh_momentum(
+    previous_state: MCLMCState, key: jax.dtypes.prng_key
+) -> MCLMCState:
+    return previous_state._replace(
+        momentum=_generate_unit_vector(key, previous_state.position)
+    )
+
+
+def _handle_nans(
+    previous_state: MCLMCState,
+    next_state: MCLMCState,
+    info: MCLMCInfo,
+    key: jax.dtypes.prng_key,
+) -> tuple[MCLMCState, MCLMCInfo]:
+    nonans = _state_is_finite(next_state)
+    state, info = jax.lax.cond(
+        nonans,
+        lambda: (next_state, info),
+        lambda: (
+            _fallback_state_with_fresh_momentum(previous_state, key),
+            MCLMCInfo(
+                logdensity=previous_state.logdensity,
+                energy_change=jnp.zeros_like(info.energy_change),
+                kinetic_change=jnp.zeros_like(info.kinetic_change),
+            ),
+        ),
+    )
+    return state, info
+
+
+def _build_kernel(
+    logdensity_fn: LogDensityFn, inverse_mass_matrix: ArrayLike
+) -> KernelFn:
+    def kernel(
+        rng_key: jax.dtypes.prng_key, state: MCLMCState, L: NumLike, step_size: NumLike
+    ) -> tuple[MCLMCState, MCLMCInfo]:
+        kernel_key, nan_key = jax.random.split(rng_key)
+        next_state, kinetic_change = _maruyama_step(
+            init_state=state,
+            step_size=step_size,
+            L=L,
+            rng_key=kernel_key,
+            logdensity_fn=logdensity_fn,
+            inverse_mass_matrix=inverse_mass_matrix,
+        )
+        energy_change = kinetic_change - next_state.logdensity + state.logdensity
+        next_state, info = _handle_nans(
+            previous_state=state,
+            next_state=next_state,
+            info=MCLMCInfo(
+                logdensity=next_state.logdensity,
+                energy_change=energy_change,
+                kinetic_change=kinetic_change,
+            ),
+            key=nan_key,
+        )
+        return next_state, info
+
+    return kernel
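+
+
+# A minimal usage sketch for the pieces above (assuming a 2D standard-normal
+# target; not executed anywhere in this module):
+#
+#     logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
+#     state = _init_mclmc(jnp.zeros(2), logdensity_fn, jax.random.key(0))
+#     kernel = _build_kernel(logdensity_fn, jnp.ones(2))
+#     state, info = kernel(jax.random.key(1), state, 1.0, 0.1)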
+
+
+def _adaptation_handle_nans(
+    previous_state: MCLMCState,
+    next_state: MCLMCState,
+    step_size: NumLike,
+    step_size_max: NumLike,
+    kinetic_change: NumLike,
+    key: jax.dtypes.prng_key,
+) -> tuple[NumLike, MCLMCState, NumLike, NumLike]:
+    nonans = _state_is_finite(next_state)
+    state, step_size, kinetic_change = jax.tree.map(
+        lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
+        (next_state, step_size_max, kinetic_change),
+        (previous_state, step_size * _DELTA_NAN_STEP_SIZE_FACTOR, 0.0),
+    )
+    state = jax.lax.cond(
+        jnp.isnan(next_state.logdensity),
+        lambda: _fallback_state_with_fresh_momentum(state, key),
+        lambda: state,
+    )
+    return nonans, state, step_size, kinetic_change
+
+
+def _make_l_step_size_adaptation(
+    kernel_fn: KernelFactoryFn,
+    dim: int,
+    frac_tune1: NumLike,
+    frac_tune2: NumLike,
+    diagonal_preconditioning: bool,
+    desired_energy_var: NumLike = 1e-3,
+    trust_in_estimate: NumLike = 1.5,
+    num_effective_samples: int = 150,
+) -> Callable[
+    [MCLMCState, MCLMCAdaptationState, int, jax.dtypes.prng_key],
+    tuple[MCLMCState, MCLMCAdaptationState],
+]:
+    decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)
+
+    def predictor(
+        previous_state: MCLMCState,
+        params: MCLMCAdaptationState,
+        adaptive_state: _AdaptationAverages,
+        rng_key: jax.dtypes.prng_key,
+    ) -> tuple[MCLMCState, MCLMCAdaptationState, _AdaptationAverages, NumLike]:
+        time, x_average, step_size_max = adaptive_state
+        rng_key, nan_key = jax.random.split(rng_key)
+        next_state, info = kernel_fn(params.inverse_mass_matrix)(
+            rng_key=rng_key,
+            state=previous_state,
+            L=params.L,
+            step_size=params.step_size,
+        )
+        success, state, step_size_max, energy_change = _adaptation_handle_nans(
+            previous_state=previous_state,
+            next_state=next_state,
+            step_size=params.step_size,
+            step_size_max=step_size_max,
+            kinetic_change=info.energy_change,
+            key=nan_key,
+        )
+        xi = jnp.square(energy_change) / (dim * desired_energy_var) + 1e-8
+        weight = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * trust_in_estimate)))
+        x_average = decay_rate * x_average + weight * (
+            xi / jnp.power(params.step_size, 6.0)
+        )
+        time = decay_rate * time + weight
+        step_size = jnp.power(x_average / time, -1.0 / 6.0)
+        step_size = jnp.where(step_size < step_size_max, step_size, step_size_max)
+        params_new = params._replace(step_size=step_size)
+        return (
+            state,
+            params_new,
+            _AdaptationAverages(time, x_average, step_size_max),
+            success,
+        )
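+
+    # The controller in `predictor` drives E[energy_change**2] toward
+    # dim * desired_energy_var: `xi` measures the squared energy error against
+    # that target, `weight` softly discards outlier transitions, and the new
+    # step size is read off from the assumed xi ~ step_size**6 scaling.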
+
+    def step(
+        iteration_state: _AdaptationIterationState,
+        weight_and_key: tuple[NumLike, jax.dtypes.prng_key],
+    ) -> tuple[_AdaptationIterationState, None]:
+        mask, rng_key = weight_and_key
+        state, params, adaptive_state, streaming_avg = iteration_state
+        state, params, adaptive_state, success = predictor(
+            state, params, adaptive_state, rng_key
+        )
+        x = ravel_pytree(state.position)[0]
+        streaming_avg = _incremental_value_update(
+            expectation=jnp.array([x, jnp.square(x)]),
+            incremental_val=streaming_avg,
+            weight=mask * success * params.step_size,
+        )
+        return _AdaptationIterationState(
+            state, params, adaptive_state, streaming_avg
+        ), None
+
+    def run_steps(
+        xs: tuple[ArrayLike, jax.dtypes.prng_key],
+        state: MCLMCState,
+        params: MCLMCAdaptationState,
+    ) -> _AdaptationIterationState:
+        return jax.lax.scan(
+            step,
+            init=_AdaptationIterationState(
+                state,
+                params,
+                _AdaptationAverages(0.0, 0.0, jnp.inf),
+                _StreamingAverage(0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
+            ),
+            xs=xs,
+        )[0]
+
+    def adaptation(
+        state: MCLMCState,
+        params: MCLMCAdaptationState,
+        num_steps: int,
+        rng_key: jax.dtypes.prng_key,
+    ) -> tuple[MCLMCState, MCLMCAdaptationState]:
+        num_steps1 = round(num_steps * frac_tune1)
+        num_steps2 = round(num_steps * frac_tune2)
+        keys = jax.random.split(rng_key, num_steps1 + num_steps2 + 1)
+        tune_keys, final_key = keys[:-1], keys[-1]
+        mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
+
+        iteration_state = run_steps((mask, tune_keys), state, params)
+        state, params = iteration_state.state, iteration_state.params
+        average = iteration_state.streaming_avg.average
+        L = params.L
+        inverse_mass_matrix = params.inverse_mass_matrix
+        if num_steps2 > 1:
+            x_average, x_squared_average = average[0], average[1]
+            variances = x_squared_average - jnp.square(x_average)
+            L = jnp.sqrt(jnp.sum(variances))
+            if diagonal_preconditioning:
+                inverse_mass_matrix = variances
+                params = params._replace(inverse_mass_matrix=inverse_mass_matrix)
+                L = jnp.sqrt(dim)
+                steps = num_steps2 // 3
+                keys = jax.random.split(final_key, steps)
+                iteration_state = run_steps((jnp.ones(steps), keys), state, params)
+                state, params = iteration_state.state, iteration_state.params
+        return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix)
+
+    return adaptation
+
+
+def _make_adaptation_l(
+    kernel: KernelFn, frac: NumLike, lfactor: NumLike
+) -> Callable[
+    [MCLMCState, MCLMCAdaptationState, int, jax.dtypes.prng_key],
+    tuple[MCLMCState, MCLMCAdaptationState],
+]:
+    def adaptation_l(
+        state: MCLMCState,
+        params: MCLMCAdaptationState,
+        num_steps: int,
+        rng_key: jax.dtypes.prng_key,
+    ) -> tuple[MCLMCState, MCLMCAdaptationState]:
+        num_steps3 = round(num_steps * frac)
+        keys = jax.random.split(rng_key, num_steps3)
+
+        def step(
+            curr_state: MCLMCState, key: jax.dtypes.prng_key
+        ) -> tuple[MCLMCState, PyTree]:
+            next_state, _ = kernel(
+                rng_key=key,
+                state=curr_state,
+                L=params.L,
+                step_size=params.step_size,
+            )
+            return next_state, next_state.position
+
+        state, samples = jax.lax.scan(step, init=state, xs=keys)
+        flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples)
+        ess = effective_sample_size(flat_samples[None, ...])
+        ess = jnp.nan_to_num(
+            ess,
+            nan=1.0,
+            posinf=float(num_steps3),
+            neginf=1.0,
+        )
+        ess = jnp.clip(ess, min=1.0)
+        return state, params._replace(
+            L=lfactor * params.step_size * jnp.mean(num_steps3 / ess)
+        )
+
+    return adaptation_l
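+
+
+# Warmup tuning runs in up to three phases: (1) step-size adaptation toward
+# the desired energy variance, (2) joint step-size / L adaptation with
+# optional diagonal preconditioning, and (3) an ESS-based refinement of L.
+# The `frac_tune*` arguments below split the warmup budget between phases.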
+
+
+def _mclmc_find_l_and_step_size(
+    mclmc_kernel: KernelFactoryFn,
+    num_steps: int,
+    state: MCLMCState,
+    rng_key: jax.dtypes.prng_key,
+    frac_tune1: NumLike = 0.1,
+    frac_tune2: NumLike = 0.1,
+    frac_tune3: NumLike = 0.1,
+    desired_energy_var: NumLike = 5e-4,
+    trust_in_estimate: NumLike = 1.5,
+    num_effective_samples: int = 150,
+    diagonal_preconditioning: bool = True,
+    params: MCLMCAdaptationState | None = None,
+    lfactor: NumLike = 0.4,
+) -> tuple[MCLMCState, MCLMCAdaptationState, int]:
+    dim = _pytree_size(state.position)
+    if params is None:
+        params = MCLMCAdaptationState(
+            L=jnp.sqrt(dim),
+            step_size=jnp.sqrt(dim) * 0.25,
+            inverse_mass_matrix=jnp.ones((dim,)),
+        )
+
+    part1_key, part2_key = jax.random.split(rng_key, 2)
+    num_steps1 = round(num_steps * frac_tune1)
+    num_steps2 = round(num_steps * frac_tune2)
+    num_steps2 += diagonal_preconditioning * (num_steps2 // 3)
+    num_steps3 = round(num_steps * frac_tune3)
+
+    state, params = _make_l_step_size_adaptation(
+        kernel_fn=mclmc_kernel,
+        dim=dim,
+        frac_tune1=frac_tune1,
+        frac_tune2=frac_tune2,
+        desired_energy_var=desired_energy_var,
+        trust_in_estimate=trust_in_estimate,
+        num_effective_samples=num_effective_samples,
+        diagonal_preconditioning=diagonal_preconditioning,
+    )(state, params, num_steps, part1_key)
+
+    total_num_tuning_integrator_steps = num_steps1 + num_steps2
+    if num_steps3 >= 2:
+        state, params = _make_adaptation_l(
+            kernel=mclmc_kernel(params.inverse_mass_matrix),
+            frac=frac_tune3,
+            lfactor=lfactor,
+        )(state, params, num_steps, part2_key)
+        total_num_tuning_integrator_steps += num_steps3
+    return state, params, total_num_tuning_integrator_steps
+
+
+class MCLMC(MCMCKernel):
+    """
+    Microcanonical Langevin Monte Carlo (MCLMC) kernel.
+
+    This kernel implements an isokinetic integrator with stochastic momentum
+    refreshment. During warmup, it automatically tunes the step size, the
+    momentum decoherence length ``L``, and optionally a diagonal preconditioner.
+    The kernel can then be used with :class:`~numpyro.infer.mcmc.MCMC`.
+
+    Example
+    -------
+
+    A minimal 2D model:
+
+    .. code-block:: python
+
+        import jax
+        import jax.numpy as jnp
+        import numpyro
+        import numpyro.distributions as dist
+        from numpyro.infer import MCMC
+        from numpyro.infer.mclmc import MCLMC
+
+        def model():
+            numpyro.sample("x", dist.Normal(jnp.array([0.0, 0.0]), 1.0).to_event(1))
+
+        kernel = MCLMC(model=model)
+        mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=False)
+        mcmc.run(jax.random.key(0))
+        samples = mcmc.get_samples()
+
+    Model with observed data and tuned energy variance:
+
+    .. code-block:: python
+
+        def model(X, y=None):
+            w = numpyro.sample("w", dist.Normal(jnp.zeros(X.shape[-1]), 1.0))
+            logits = X @ w
+            numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=y)
+
+        kernel = MCLMC(
+            model=model,
+            desired_energy_var=5e-4,
+            diagonal_preconditioning=True,
+        )
+        mcmc = MCMC(kernel, num_warmup=1500, num_samples=1000, progress_bar=False)
+        mcmc.run(jax.random.key(1), X, y)
+
+    **References:**
+
+    1. *Microcanonical Hamiltonian Monte Carlo*,
+       Jakob Robnik, G. Bruno De Luca, Eva Silverstein, Uroš Seljak
+       https://arxiv.org/abs/2212.08549
+
+    .. note:: The model must have at least 2 unconstrained latent dimensions.
+        This limitation comes from the isokinetic MCLMC dynamics.
+
+    :param model: Python callable containing NumPyro primitives.
+    :param float desired_energy_var: Target energy variance used in warmup to tune
+        the step size. Smaller values generally lead to more conservative
+        integration steps. Defaults to ``5e-4``.
+    :param bool diagonal_preconditioning: Whether warmup should estimate a diagonal
+        inverse mass matrix. If ``False``, adaptation uses isotropic scaling.
+        Defaults to ``True``.
+    """
+
+    def __init__(
+        self: "MCLMC",
+        model: Callable[..., Any] | None = None,
+        desired_energy_var: NumLike = 5e-4,
+        diagonal_preconditioning: bool = True,
+    ) -> None:
+        """
+        Construct an MCLMC kernel.
+
+        :param model: NumPyro model callable that defines latent variables and
+            observations.
+        :param desired_energy_var: Target energy variance used during warmup to
+            tune the step size. Smaller values typically produce more
+            conservative integrator updates.
+        :param diagonal_preconditioning: Whether to estimate a diagonal inverse
+            mass matrix during warmup. If ``False``, adaptation uses isotropic
+            scaling.
+        :raises ValueError: If ``model`` is not provided.
+        """
+        if model is None:
+            raise ValueError("Model must be specified for MCLMC")
+        self._model = model
+        self._diagonal_preconditioning = diagonal_preconditioning
+        self._desired_energy_var = desired_energy_var
+        self._postprocess_fn: Callable[..., Callable[[PyTree], PyTree]] | None = None
+        self.logdensity_fn: LogDensityFn | None = None
+        self.adapt_state: MCLMCAdaptationState | None = None
+        self._kernel: KernelFn | None = None
+
+    @property
+    def model(self: "MCLMC") -> Callable[..., Any]:
+        """Return the model callable associated with this kernel."""
+        return self._model
+
+    @property
+    def sample_field(self: "MCLMC") -> str:
+        """
+        Name of the state attribute treated as the MCMC sample.
+
+        This is used by :class:`~numpyro.infer.mcmc.MCMC` for collection and
+        postprocessing.
+        """
+        return "position"
+
+    @property
+    def default_fields(self: "MCLMC") -> tuple[str, ...]:
+        """
+        State attributes collected by default during sampling.
+
+        :return: Tuple of field names to collect from each state.
+        """
+        return (self.sample_field,)
+
+    def get_diagnostics_str(self: "MCLMC", state: FullState) -> str:
+        """
+        Return progress-bar diagnostics for the current adaptation parameters.
+
+        :param state: Current full sampler state (unused; present for kernel API
+            compatibility).
+        :return: A formatted diagnostics string once adaptation parameters
+            exist, or an empty string before warmup adaptation has run.
+        """
+        if self.adapt_state is None:
+            return ""
+        return "step_size={:.2e}, L={:.2e}".format(
+            self.adapt_state.step_size, self.adapt_state.L
+        )
+
+    def postprocess_fn(
+        self: "MCLMC", args: tuple[Any, ...], kwargs: dict[str, Any]
+    ) -> Callable[[PyTree], PyTree]:
+        """
+        Build a transform from unconstrained latent space to constrained space.
+
+        :param args: Positional model arguments used to initialize transforms.
+        :param kwargs: Keyword model arguments used to initialize transforms.
+        :return: Callable that maps unconstrained latent samples to constrained
+            values and includes deterministic sites.
+        """
+        if self._postprocess_fn is None:
+            return cast(Callable[[PyTree], PyTree], identity)
+        return self._postprocess_fn(*args, **kwargs)
+
+    def init(
+        self: "MCLMC",
+        rng_key: jax.dtypes.prng_key,
+        num_warmup: int,
+        init_params: Any,
+        model_args: tuple[Any, ...],
+        model_kwargs: dict[str, Any],
+    ) -> FullState:
+        """
+        Initialize sampler state and run warmup adaptation.
+
+        This method initializes the model state, builds the log-density
+        function, adapts ``step_size``, ``L``, and (optionally) diagonal
+        preconditioning, then returns a ready-to-sample state.
+
+        :param rng_key: JAX PRNG key.
+        :param num_warmup: Number of warmup steps requested by the outer MCMC
+            driver; used to set adaptation phase fractions.
+        :param init_params: Optional initial parameters (kept for kernel API
+            compatibility; model initialization is delegated to
+            :func:`~numpyro.infer.util.initialize_model`).
+        :param model_args: Positional arguments passed to the model.
+        :param model_kwargs: Keyword arguments passed to the model.
+        :return: Fully initialized :class:`FullState`.
+        """
+        init_model_key, init_state_key, run_key, tune_key = jax.random.split(rng_key, 4)
+        init_params, potential_fn_gen, postprocess_fn, _ = initialize_model(
+            init_model_key,
+            self._model,
+            model_args=model_args,
+            model_kwargs=model_kwargs,
+            dynamic_args=True,
+        )
+        self._postprocess_fn = postprocess_fn
+
+        def logdensity_fn(position: PyTree) -> NumLike:
+            return -potential_fn_gen(*model_args, **model_kwargs)(position)
+
+        self.logdensity_fn = logdensity_fn
+        sampler_state = _init_mclmc(
+            position=init_params.z,
+            logdensity_fn=self.logdensity_fn,
+            rng_key=init_state_key,
+        )
+
+        def kernel_fn(
+            inverse_mass_matrix: ArrayLike,
+        ) -> KernelFn:
+            return _build_kernel(
+                logdensity_fn=self.logdensity_fn,
+                inverse_mass_matrix=inverse_mass_matrix,
+            )
+
+        num_tuning_steps = 100
+        tuned_state, self.adapt_state, _ = _mclmc_find_l_and_step_size(
+            mclmc_kernel=kernel_fn,
+            num_steps=num_tuning_steps,
+            state=sampler_state,
+            rng_key=tune_key,
+            diagonal_preconditioning=self._diagonal_preconditioning,
+            frac_tune1=num_warmup / (3 * num_tuning_steps),
+            frac_tune2=num_warmup / (3 * num_tuning_steps),
+            frac_tune3=num_warmup / (3 * num_tuning_steps),
+            desired_energy_var=self._desired_energy_var,
+        )
+        self._kernel = _build_kernel(
+            logdensity_fn=self.logdensity_fn,
+            inverse_mass_matrix=self.adapt_state.inverse_mass_matrix,
+        )
+        return FullState(
+            tuned_state.position,
+            tuned_state.momentum,
+            tuned_state.logdensity,
+            tuned_state.logdensity_grad,
+            run_key,
+        )
+
+    def sample(
+        self: "MCLMC",
+        state: FullState,
+        model_args: tuple[Any, ...],
+        model_kwargs: dict[str, Any],
+    ) -> FullState:
+        """
+        Advance the Markov chain by one MCLMC transition.
+
+        :param state: Current full sampler state.
+        :param model_args: Unused after initialization (kept for API
+            compatibility with :class:`~numpyro.infer.mcmc.MCMCKernel`).
+        :param model_kwargs: Unused after initialization (kept for API
+            compatibility with :class:`~numpyro.infer.mcmc.MCMCKernel`).
+        :return: Next :class:`FullState` after one transition.
+        :raises RuntimeError: If called before :meth:`init`.
+        """
+        del model_args, model_kwargs
+        if self._kernel is None or self.adapt_state is None:
+            msg = "MCLMC must be initialized before calling sample."
+            raise RuntimeError(msg)
+        mclmc_state = MCLMCState(
+            state.position, state.momentum, state.logdensity, state.logdensity_grad
+        )
+        rng_key, sample_key = jax.random.split(state.rng_key, 2)
+        new_state, _ = self._kernel(
+            rng_key=sample_key,
+            state=mclmc_state,
+            step_size=self.adapt_state.step_size,
+            L=self.adapt_state.L,
+        )
+        return FullState(
+            new_state.position,
+            new_state.momentum,
+            new_state.logdensity,
+            new_state.logdensity_grad,
+            rng_key,
+        )
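+
+    # A hypothetical manual-stepping sketch (MCMC normally drives init/sample):
+    #
+    #     kernel = MCLMC(model=model)
+    #     state = kernel.init(jax.random.key(0), 100, None, (), {})
+    #     for _ in range(10):
+    #         state = kernel.sample(state, (), {})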
+
+    def __getstate__(self: "MCLMC") -> dict[str, Any]:
+        """
+        Return a pickle-safe object state.
+
+        The cached postprocess closure is intentionally cleared because closures
+        from ``initialize_model`` are not reliably serializable.
+
+        :return: Serializable state dictionary for this kernel instance.
+        """
+        state = self.__dict__.copy()
+        state["_postprocess_fn"] = None
+        return state
diff --git a/pyproject.toml b/pyproject.toml
index 46486b1ac..b3cb374f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -224,6 +224,7 @@ module = [
     "numpyro.diagnostics.*",
     "numpyro.handlers.*",
     "numpyro.infer.elbo.*",
+    "numpyro.infer.mclmc.*",
     "numpyro.optim.*",
     "numpyro.primitives.*",
     "numpyro.patch.*",
diff --git a/test/infer/test_mclmc.py b/test/infer/test_mclmc.py
new file mode 100644
index 000000000..68e231a2d
--- /dev/null
+++ b/test/infer/test_mclmc.py
@@ -0,0 +1,659 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+
+from numpy.testing import assert_allclose
+import pytest
+
+import jax
+from jax import random
+import jax.numpy as jnp
+
+import numpyro
+import numpyro.distributions as dist
+from numpyro.infer import MCMC
+import numpyro.infer.mclmc as mclmc_module
+from numpyro.infer.mclmc import MCLMC
+
+
+def _two_dim_model():
+    numpyro.sample("x", dist.Normal(jnp.array([0.0, 0.0]), 1.0).to_event(1))
+
+
+def _model_with_args(loc, scale=1.0):
+    numpyro.sample("x", dist.Normal(loc, scale).to_event(1))
+
+
+def _gaussian_logdensity(x):
+    return -0.5 * jnp.sum(jnp.square(x))
+
+
+def _make_test_state(key=None):
+    if key is None:
+        key = random.PRNGKey(0)
+    return mclmc_module._init_mclmc(
+        position=jnp.array([0.3, -0.7]),
+        logdensity_fn=_gaussian_logdensity,
+        rng_key=key,
+    )
+
+
+def test_pytree_size_counts_all_leaves():
+    pytree = {"a": jnp.zeros((2, 3)), "b": [jnp.ones((4,)), jnp.ones(())]}
+    assert mclmc_module._pytree_size(pytree) == 11
+
+
+def test_generate_unit_vector_has_unit_norm():
+    position = jnp.array([1.0, 2.0, 3.0])
+    vec = mclmc_module._generate_unit_vector(random.PRNGKey(0), position)
+    flat_vec, _ = jax.flatten_util.ravel_pytree(vec)
+    assert_allclose(jnp.linalg.norm(flat_vec), 1.0, atol=1e-6)
+    assert flat_vec.shape == position.shape
+
+
+def test_incremental_value_update_weighted_average():
+    avg = mclmc_module._StreamingAverage(
+        total=jnp.array(0.0), average=jnp.array([0.0, 0.0])
+    )
+    avg = mclmc_module._incremental_value_update(
+        expectation=jnp.array([2.0, 4.0]), incremental_val=avg, weight=2.0
+    )
+    avg = mclmc_module._incremental_value_update(
+        expectation=jnp.array([4.0, 8.0]), incremental_val=avg, weight=2.0
+    )
+    assert_allclose(avg.average, jnp.array([3.0, 6.0]), atol=1e-6)
+    assert_allclose(avg.total, 4.0, atol=1e-6)
+
+
+def test_incremental_value_update_zero_numerator_safe():
+    avg = mclmc_module._StreamingAverage(total=jnp.array(0.0), average=jnp.array([0.0]))
+    updated = mclmc_module._incremental_value_update(
+        expectation=jnp.array([0.0]), incremental_val=avg, weight=0.0
+    )
+    assert_allclose(updated.average, jnp.array([0.0]))
+
+
+def test_init_mclmc_rejects_low_dimension():
+    with pytest.raises(
+        ValueError, match="target distribution must have more than 1 dimension"
+    ):
+        mclmc_module._init_mclmc(
+            position=jnp.array([0.0]),
+            logdensity_fn=lambda x: -0.5 * jnp.sum(x**2),
+            rng_key=random.PRNGKey(0),
+        )
+
+
+def test_init_mclmc_returns_valid_state():
+    state = _make_test_state(random.PRNGKey(1))
+    flat_momentum, _ = jax.flatten_util.ravel_pytree(state.momentum)
+    assert jnp.isfinite(state.logdensity)
+    assert jnp.all(jnp.isfinite(state.logdensity_grad))
+    assert_allclose(jnp.linalg.norm(flat_momentum), 1.0, atol=1e-6)
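+
+
+def test_incremental_value_update_zero_prevention_keeps_average_finite():
+    # Sketch of a property check: `zero_prevention` only pads the divisor, so
+    # with zero weight the running average is (numerically) unchanged.
+    avg = mclmc_module._StreamingAverage(total=jnp.array(1.0), average=jnp.array([2.0]))
+    updated = mclmc_module._incremental_value_update(
+        expectation=jnp.array([5.0]),
+        incremental_val=avg,
+        weight=0.0,
+        zero_prevention=1e-6,
+    )
+    assert_allclose(updated.average, jnp.array([2.0]), atol=1e-5)
+    assert_allclose(updated.total, 1.0)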
+
+
+def test_position_update_matches_expected_gaussian_update():
+    position = jnp.array([1.0, 2.0])
+    kinetic_grad = jnp.array([0.5, -1.0])
+    new_position, logdensity, grad = mclmc_module._position_update(
+        position=position,
+        kinetic_grad=kinetic_grad,
+        step_size=0.1,
+        coef=0.5,
+        logdensity_fn=_gaussian_logdensity,
+    )
+    expected_position = jnp.array([1.025, 1.95])
+    assert_allclose(new_position, expected_position, atol=1e-7)
+    assert_allclose(logdensity, _gaussian_logdensity(expected_position), atol=1e-7)
+    assert_allclose(grad, -expected_position, atol=1e-7)
+
+
+def test_normalized_flatten_for_nonzero_and_zero_vectors():
+    normalized, norm = mclmc_module._normalized_flatten(jnp.array([3.0, 4.0]))
+    assert_allclose(normalized, jnp.array([0.6, 0.8]), atol=1e-7)
+    assert_allclose(norm, 5.0, atol=1e-7)
+
+    normalized_zero, norm_zero = mclmc_module._normalized_flatten(jnp.zeros((3,)))
+    assert_allclose(normalized_zero, jnp.zeros((3,)), atol=1e-7)
+    assert_allclose(norm_zero, 0.0, atol=1e-7)
+
+
+def test_esh_dynamics_momentum_update_matches_naive_formula():
+    step_size = 1e-3
+    key0, key1 = random.split(random.PRNGKey(62))
+    gradient = random.uniform(key0, shape=(3,))
+    momentum = random.uniform(key1, shape=(3,))
+    momentum = momentum / jnp.linalg.norm(momentum)
+
+    gradient_norm = jnp.linalg.norm(gradient)
+    gradient_normalized = gradient / gradient_norm
+    delta = step_size * gradient_norm / (momentum.shape[0] - 1)
+    naive_next = (
+        momentum
+        + gradient_normalized
+        * (
+            jnp.sinh(delta)
+            + jnp.dot(gradient_normalized, momentum * (jnp.cosh(delta) - 1))
+        )
+    ) / (jnp.cosh(delta) + jnp.dot(gradient_normalized, momentum * jnp.sinh(delta)))
+    naive_next = naive_next / jnp.linalg.norm(naive_next)
+
+    next_momentum, _, _ = mclmc_module._esh_dynamics_momentum_update_one_step(
+        momentum=momentum,
+        logdensity_grad=gradient,
+        step_size=step_size,
+        coef=1.0,
+        inverse_mass_matrix=jnp.ones((3,)),
+    )
+    assert_allclose(next_momentum, naive_next, atol=1e-6)
+
+
+def test_isokinetic_mclachlan_step_returns_finite_state_and_unit_momentum():
+    state = _make_test_state(random.PRNGKey(0))
+    next_state, kinetic_change = mclmc_module._isokinetic_mclachlan_step(
+        state=state,
+        step_size=1e-3,
+        logdensity_fn=_gaussian_logdensity,
+        inverse_mass_matrix=jnp.ones((2,)),
+    )
+    flat_momentum, _ = jax.flatten_util.ravel_pytree(next_state.momentum)
+    assert jnp.isfinite(kinetic_change)
+    assert jnp.isfinite(next_state.logdensity)
+    assert jnp.all(jnp.isfinite(next_state.logdensity_grad))
+    assert_allclose(jnp.linalg.norm(flat_momentum), 1.0, atol=1e-5)
+
+
+def test_partially_refresh_momentum_respects_infinite_l():
+    momentum = jnp.array([1.0, 0.0])
+    refreshed_inf = mclmc_module._partially_refresh_momentum(
+        momentum=momentum,
+        rng_key=random.PRNGKey(0),
+        step_size=0.1,
+        L=jnp.inf,
+    )
+    assert_allclose(refreshed_inf, momentum)
+
+    refreshed = mclmc_module._partially_refresh_momentum(
+        momentum=momentum,
+        rng_key=random.PRNGKey(0),
+        step_size=0.1,
+        L=1.0,
+    )
+    assert_allclose(jnp.linalg.norm(refreshed), 1.0, atol=1e-6)
+
+
+def test_maruyama_step_returns_finite_values():
+    state = _make_test_state(random.PRNGKey(0))
+    next_state, kinetic_change = mclmc_module._maruyama_step(
+        init_state=state,
+        step_size=1e-2,
+        L=1.0,
+        rng_key=random.PRNGKey(1),
+        logdensity_fn=_gaussian_logdensity,
+        inverse_mass_matrix=jnp.ones((2,)),
+    )
+    assert jnp.isfinite(kinetic_change)
+    assert jnp.isfinite(next_state.logdensity)
+    assert jnp.all(jnp.isfinite(next_state.logdensity_grad))
+    assert mclmc_module._state_is_finite(next_state)
+
+
+def test_state_is_finite_detects_nan_and_inf():
+    state = _make_test_state(random.PRNGKey(0))
+    assert mclmc_module._state_is_finite(state)
+
+    nan_state = state._replace(position=jnp.array([jnp.nan, 0.0]))
+    inf_state = state._replace(momentum=jnp.array([jnp.inf, 0.0]))
+    assert not mclmc_module._state_is_finite(nan_state)
+    assert not mclmc_module._state_is_finite(inf_state)
+
+
+def test_fallback_state_with_fresh_momentum_preserves_position_and_logdensity():
+    state = _make_test_state(random.PRNGKey(0))
+    new_state = mclmc_module._fallback_state_with_fresh_momentum(
+        previous_state=state, key=random.PRNGKey(2)
+    )
+    assert_allclose(new_state.position, state.position)
+    assert_allclose(new_state.logdensity, state.logdensity)
+    assert_allclose(jnp.linalg.norm(new_state.momentum), 1.0, atol=1e-6)
+
+
+def test_handle_nans_keeps_valid_state_and_falls_back_for_invalid_state():
+    previous = _make_test_state(random.PRNGKey(0))
+    valid_next = previous._replace(position=previous.position + 0.1)
+    info = mclmc_module.MCLMCInfo(
+        logdensity=valid_next.logdensity,
+        kinetic_change=jnp.array(0.3),
+        energy_change=jnp.array(0.2),
+    )
+    state_ok, info_ok = mclmc_module._handle_nans(
+        previous_state=previous, next_state=valid_next, info=info, key=random.PRNGKey(1)
+    )
+    assert_allclose(state_ok.position, valid_next.position)
+    assert_allclose(info_ok.energy_change, info.energy_change)
+
+    invalid_next = valid_next._replace(position=jnp.array([jnp.nan, 0.0]))
+    state_bad, info_bad = mclmc_module._handle_nans(
+        previous_state=previous,
+        next_state=invalid_next,
+        info=info,
+        key=random.PRNGKey(3),
+    )
+    assert_allclose(state_bad.position, previous.position)
+    assert_allclose(info_bad.logdensity, previous.logdensity)
+    assert_allclose(info_bad.energy_change, 0.0)
+    assert_allclose(info_bad.kinetic_change, 0.0)
+
+
+def test_build_kernel_single_step_outputs_finite_state_and_info():
+    kernel = mclmc_module._build_kernel(
+        logdensity_fn=_gaussian_logdensity, inverse_mass_matrix=jnp.ones((2,))
+    )
+    state = _make_test_state(random.PRNGKey(0))
+    next_state, info = kernel(
+        rng_key=random.PRNGKey(1), state=state, L=1.0, step_size=1e-2
+    )
+    assert mclmc_module._state_is_finite(next_state)
+    assert jnp.isfinite(info.logdensity)
+    assert jnp.isfinite(info.energy_change)
+    assert jnp.isfinite(info.kinetic_change)
+
+
+def test_adaptation_handle_nans_behavior():
+    previous = _make_test_state(random.PRNGKey(0))
+    next_state = previous._replace(position=previous.position + 0.1)
+    success, state, new_step_size_max, new_kinetic = (
+        mclmc_module._adaptation_handle_nans(
+            previous_state=previous,
+            next_state=next_state,
+            step_size=jnp.array(0.2),
+            step_size_max=jnp.array(0.5),
+            kinetic_change=jnp.array(0.1),
+            key=random.PRNGKey(2),
+        )
+    )
+    assert success
+    assert_allclose(state.position, next_state.position)
+    assert_allclose(new_step_size_max, 0.5)
+    assert_allclose(new_kinetic, 0.1)
+
+    invalid = next_state._replace(
+        position=jnp.array([jnp.nan, 0.0]), logdensity=jnp.nan
+    )
+    success, state, new_step_size_max, new_kinetic = (
+        mclmc_module._adaptation_handle_nans(
+            previous_state=previous,
+            next_state=invalid,
+            step_size=jnp.array(0.2),
+            step_size_max=jnp.array(0.5),
+            kinetic_change=jnp.array(0.1),
+            key=random.PRNGKey(3),
+        )
+    )
+    assert not success
+    assert_allclose(new_step_size_max, 0.2 * mclmc_module._DELTA_NAN_STEP_SIZE_FACTOR)
+    assert_allclose(new_kinetic, 0.0)
+    assert_allclose(state.position, previous.position)
+
+
+def test_make_l_step_size_adaptation_returns_finite_positive_params():
+    dim = 2
+    initial_state = _make_test_state(random.PRNGKey(0))
+    params = mclmc_module.MCLMCAdaptationState(
+        L=jnp.sqrt(dim),
+        step_size=0.2,
+        inverse_mass_matrix=jnp.ones((dim,)),
+    )
+    adaptation = mclmc_module._make_l_step_size_adaptation(
+        kernel_fn=lambda imm: mclmc_module._build_kernel(_gaussian_logdensity, imm),
+        dim=dim,
+        frac_tune1=0.2,
+        frac_tune2=0.2,
+        diagonal_preconditioning=True,
+    )
+    state, new_params = adaptation(
+        initial_state,
+        params,
+        num_steps=30,
+        rng_key=random.PRNGKey(1),
+    )
+    assert mclmc_module._state_is_finite(state)
+    assert jnp.isfinite(new_params.L) and (new_params.L > 0)
+    assert jnp.isfinite(new_params.step_size) and (new_params.step_size > 0)
+    assert new_params.inverse_mass_matrix.shape == (dim,)
+
+
+def test_make_adaptation_l_nominal_case_updates_l():
+    state = _make_test_state(random.PRNGKey(0))
+    params = mclmc_module.MCLMCAdaptationState(
+        L=jnp.array(1.0),
+        step_size=jnp.array(0.1),
+        inverse_mass_matrix=jnp.ones((2,)),
+    )
+    kernel = mclmc_module._build_kernel(_gaussian_logdensity, jnp.ones((2,)))
+    adaptation_l = mclmc_module._make_adaptation_l(kernel=kernel, frac=0.5, lfactor=0.4)
+    _, new_params = adaptation_l(
+        state=state,
+        params=params,
+        num_steps=12,
+        rng_key=random.PRNGKey(2),
+    )
+    assert jnp.isfinite(new_params.L)
+    assert new_params.L > 0
+
+
+def test_mclmc_find_l_and_step_size_returns_expected_phase_accounting():
+    state = _make_test_state(random.PRNGKey(0))
+    state, params, total_steps = mclmc_module._mclmc_find_l_and_step_size(
+        mclmc_kernel=lambda imm: mclmc_module._build_kernel(_gaussian_logdensity, imm),
+        num_steps=20,
+        state=state,
+        rng_key=random.PRNGKey(1),
+        frac_tune1=0.2,
+        frac_tune2=0.2,
+        frac_tune3=0.2,
+        diagonal_preconditioning=True,
+    )
+    expected_num_steps1 = round(20 * 0.2)
+    expected_num_steps2 = round(20 * 0.2) + (round(20 * 0.2) // 3)
+    expected_num_steps3 = round(20 * 0.2)
+    assert (
+        total_steps == expected_num_steps1 + expected_num_steps2 + expected_num_steps3
+    )
+    assert mclmc_module._state_is_finite(state)
+    assert jnp.isfinite(params.L) and (params.L > 0)
+    assert jnp.isfinite(params.step_size) and (params.step_size > 0)
+    assert params.inverse_mass_matrix.shape == (2,)
+
+
+def test_mclmc_model_required():
+    """Test that ValueError is raised when model is None."""
+    with pytest.raises(ValueError, match="Model must be specified"):
+        MCLMC(model=None)
+
+
+def test_mclmc_normal():
+    """Test MCLMC with a 2D normal distribution."""
+    true_mean = jnp.array([1.0, 2.0])
+    true_std = jnp.array([0.5, 1.0])
+    num_warmup, num_samples = 1000, 2000
+
+    def model():
+        numpyro.sample("x", dist.Normal(true_mean, true_std).to_event(1))
+
+    kernel = MCLMC(model=model)
+    mcmc = MCMC(
+        kernel,
+        num_warmup=num_warmup,
+        num_samples=num_samples,
+        num_chains=1,
+        progress_bar=False,
+    )
+    mcmc.run(random.PRNGKey(0))
+    samples = mcmc.get_samples()
+
+    assert "x" in samples
+    assert samples["x"].shape == (num_samples, 2)
+    assert_allclose(jnp.mean(samples["x"], axis=0), true_mean, atol=0.1)
+    assert_allclose(jnp.std(samples["x"], axis=0), true_std, atol=0.2)
+
+
+def test_mclmc_gaussian_2d():
+    """Test MCLMC with a 2D Gaussian model with an observation."""
+    num_warmup, num_samples = 1000, 1000
+
+    def model():
+        x = numpyro.sample("x", dist.Normal(0.0, 1.0))
+        y = numpyro.sample("y", dist.Normal(0.0, 1.0))
+        numpyro.sample("obs", dist.Normal(x + y, 0.5), obs=jnp.array(0.0))
+
+    kernel = MCLMC(
+        model=model,
+        diagonal_preconditioning=True,
+        desired_energy_var=5e-4,
+    )
+    mcmc = MCMC(
+        kernel,
+        num_warmup=num_warmup,
+        num_samples=num_samples,
+        num_chains=1,
+        progress_bar=False,
+    )
+    mcmc.run(random.PRNGKey(0))
+    samples = mcmc.get_samples()
+
+    assert "x" in samples
+    assert "y" in samples
+    assert samples["x"].shape == (num_samples,)
+    assert samples["y"].shape == (num_samples,)
+    # With obs=0, x + y should be close to 0, so the means should be near 0.
+    assert_allclose(jnp.mean(samples["x"]) + jnp.mean(samples["y"]), 0.0, atol=0.2)
+
+
+def test_mclmc_logistic_regression():
+    """Test MCLMC with a logistic regression model."""
+    N, dim = 1000, 3
+    num_warmup, num_samples = 1000, 2000
+
+    key1, key2, key3 = random.split(random.PRNGKey(0), 3)
+    data = random.normal(key1, (N, dim))
+    true_coefs = jnp.arange(1.0, dim + 1.0)
+    logits = jnp.sum(true_coefs * data, axis=-1)
+    labels = dist.Bernoulli(logits=logits).sample(key2)
+
+    # The model closes over `data` and `labels` for compactness.
+    def model():
+        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
+        logits = jnp.sum(coefs * data, axis=-1)
+        numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)
+
+    kernel = MCLMC(model=model)
+    mcmc = MCMC(
+        kernel,
+        num_warmup=num_warmup,
+        num_samples=num_samples,
+        num_chains=1,
+        progress_bar=False,
+    )
+    mcmc.run(key3)
+    samples = mcmc.get_samples()
+
+    assert "coefs" in samples
+    assert samples["coefs"].shape == (num_samples, dim)
+    assert_allclose(jnp.mean(samples["coefs"], 0), true_coefs, atol=0.5)
+
+
+def test_mclmc_sample_shape():
+    """Test that MCLMC produces samples with expected shapes."""
+    num_warmup, num_samples = 500, 500
+
+    def model():
+        numpyro.sample("a", dist.Normal(0, 1))
+        numpyro.sample("b", dist.Normal(0, 1).expand([3]))
+        numpyro.sample("c", dist.Normal(0, 1).expand([2, 4]))
+
+    kernel = MCLMC(model=model)
+    mcmc = MCMC(
+        kernel,
+        num_warmup=num_warmup,
+        num_samples=num_samples,
+        num_chains=1,
+        progress_bar=False,
+    )
+    mcmc.run(random.PRNGKey(0))
+    samples = mcmc.get_samples()
+
+    assert samples["a"].shape == (num_samples,)
+    assert samples["b"].shape == (num_samples, 3)
+    assert samples["c"].shape == (num_samples, 2, 4)
+
+
+def test_mclmc_model_args_and_kwargs():
+    """Test that model_args/model_kwargs are respected during inference."""
+    true_mean = jnp.array([1.5, -0.5])
+    true_scale = 0.8
+    num_warmup, num_samples = 500, 1000
+
+    kernel = MCLMC(model=_model_with_args)
+    mcmc = MCMC(
+        kernel,
+        num_warmup=num_warmup,
+        num_samples=num_samples,
+        num_chains=1,
+        progress_bar=False,
+    )
+    mcmc.run(random.PRNGKey(1), true_mean, scale=true_scale)
+    samples = mcmc.get_samples()["x"]
+
+    assert samples.shape == (num_samples, 2)
+    assert_allclose(jnp.mean(samples, axis=0), true_mean, atol=0.2)
+    assert_allclose(jnp.std(samples, axis=0), true_scale, atol=0.2)
+
+
+def test_mclmc_rejects_one_dimensional_latent_space():
+    """Test that MCLMC rejects models with fewer than 2 latent dimensions."""
+
+    def one_dim_model():
+        numpyro.sample("x", dist.Normal(0.0, 1.0))
+
+    kernel = MCLMC(model=one_dim_model)
+    mcmc = MCMC(
+        kernel,
+        num_warmup=10,
+        num_samples=10,
+        num_chains=1,
+        progress_bar=False,
+    )
+    with pytest.raises(
+        ValueError,
+        match="target distribution must have more than 1 dimension",
+    ):
+        mcmc.run(random.PRNGKey(0))
+
+
+def test_mclmc_small_warmup_runs():
+    """Test the small-warmup edge case where adaptation phases are tiny."""
+    kernel = MCLMC(model=_two_dim_model)
+    mcmc = MCMC(
+        kernel,
+        num_warmup=3,
+        num_samples=20,
+        num_chains=1,
+        progress_bar=False,
+    )
+    mcmc.run(random.PRNGKey(2))
+    samples = mcmc.get_samples()["x"]
+    assert samples.shape == (20, 2)
+
+
+def test_mclmc_public_properties_and_diagnostics():
+    kernel = MCLMC(model=_two_dim_model)
+    assert kernel.model is _two_dim_model
+    assert kernel.sample_field == "position"
+    assert kernel.default_fields == ("position",)
+    assert kernel.get_diagnostics_str(None) == ""
+    kernel.adapt_state = mclmc_module.MCLMCAdaptationState(
+        L=jnp.array(1.2), step_size=jnp.array(0.05), inverse_mass_matrix=jnp.ones((2,))
+    )
+    assert "step_size=" in kernel.get_diagnostics_str(None)
+    assert "L=" in kernel.get_diagnostics_str(None)
+
+
+def test_mclmc_postprocess_fn_identity_when_uninitialized():
+    kernel = MCLMC(model=_two_dim_model)
+    fn = kernel.postprocess_fn((), {})
+    x = {"z": jnp.array([1.0, 2.0])}
+    out = fn(x)
+    assert out is x
+
+
+def test_mclmc_sample_raises_if_not_initialized():
+    kernel = MCLMC(model=_two_dim_model)
+    state = mclmc_module.FullState(
+        position=jnp.array([0.0, 0.0]),
+        momentum=jnp.array([1.0, 0.0]),
+        logdensity=jnp.array(0.0),
+        logdensity_grad=jnp.array([0.0, 0.0]),
+        rng_key=random.PRNGKey(0),
+    )
+    with pytest.raises(RuntimeError, match="must be initialized"):
+        kernel.sample(state, (), {})
+
+
+def test_mclmc_init_and_sample_direct_api():
+    kernel = MCLMC(model=_two_dim_model)
+    state = kernel.init(
+        rng_key=random.PRNGKey(0),
+        num_warmup=30,
+        init_params=None,
+        model_args=(),
+        model_kwargs={},
+    )
+    assert isinstance(state, mclmc_module.FullState)
+    next_state = kernel.sample(state, (), {})
+    assert isinstance(next_state, mclmc_module.FullState)
+    assert next_state.position["x"].shape == (2,)
+
+
+def test_mclmc_postprocess_fn_after_init_returns_callable():
+    kernel = MCLMC(model=_two_dim_model)
+    kernel.init(
+        rng_key=random.PRNGKey(1),
+        num_warmup=10,
+        init_params=None,
+        model_args=(),
+        model_kwargs={},
+    )
+    fn = kernel.postprocess_fn((), {})
+    assert callable(fn)
+
+
+def test_mclmc_getstate_clears_postprocess_fn():
+    kernel = MCLMC(model=_two_dim_model)
+    kernel._postprocess_fn = lambda *args, **kwargs: lambda x: x
+    state = kernel.__getstate__()
+    assert state["_postprocess_fn"] is None
+    assert state["_model"] is _two_dim_model
+
+
+def test_mclmc_adaptation_l_handles_bad_ess(monkeypatch):
+    """Test that the ESS guard keeps L finite for degenerate ESS estimates."""
+    state = mclmc_module.MCLMCState(
+        position=jnp.array([0.0, 0.0]),
+        momentum=jnp.array([1.0, 0.0]),
+        logdensity=jnp.array(0.0),
+        logdensity_grad=jnp.array([0.0, 0.0]),
+    )
+    params = mclmc_module.MCLMCAdaptationState(
+        L=jnp.array(1.0),
+        step_size=jnp.array(0.1),
+        inverse_mass_matrix=jnp.ones((2,)),
+    )
+
+    def dummy_kernel(rng_key, state, L, step_size):
+        del rng_key, L, step_size
+        return state, mclmc_module.MCLMCInfo(
+            logdensity=state.logdensity,
+            kinetic_change=jnp.array(0.0),
+            energy_change=jnp.array(0.0),
+        )
+
+    monkeypatch.setattr(
+        mclmc_module,
+        "effective_sample_size",
+        lambda _: jnp.array([0.0, jnp.nan, jnp.inf]),
+    )
+
+    adaptation_l = mclmc_module._make_adaptation_l(
+        kernel=dummy_kernel,
+        frac=0.5,
+        lfactor=0.4,
+    )
+    _, new_params = adaptation_l(
+        state=state,
+        params=params,
+        num_steps=10,
+        rng_key=random.PRNGKey(0),
+    )
+    assert jnp.isfinite(new_params.L)
+    assert new_params.L > 0
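+
+
+def test_mclachlan_coefficients_are_palindromic_and_normalized():
+    # Invariant-check sketch: the five-stage coefficients should read the same
+    # forwards and backwards, and the momentum stages (even indices) and
+    # position stages (odd indices) should each sum to one full step.
+    coefs = mclmc_module._MCLACHLAN_COEFS
+    assert coefs == tuple(reversed(coefs))
+    assert_allclose(sum(coefs[::2]), 1.0, atol=1e-12)
+    assert_allclose(sum(coefs[1::2]), 1.0, atol=1e-12)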