Labels: bug
Bug Description
The `exclude_deterministic` argument of the `Predictive` helper class does not exclude deterministic sites from the resulting posterior predictive dictionary.
Steps to Reproduce
Run `uv run test_exclude_deterministic.py`, where `test_exclude_deterministic.py` contains:
```python
#!/usr/bin/env -S uv run
# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "jax",
#     "jaxlib",
#     "numpyro",
# ]
# ///
"""
Test to verify whether exclude_deterministic works in numpyro.infer.util.Predictive
"""
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive


def model(x=None, y=None):
    """Simple model with both stochastic and deterministic sites"""
    # Stochastic sites
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))
    # Deterministic site
    linear_combination = numpyro.deterministic("linear_combination", a + 2 * b)
    # Another deterministic site
    _ = numpyro.deterministic("squared", a**2)
    # Likelihood
    if y is not None:
        sigma = numpyro.sample("sigma", dist.Exponential(1.0))
        numpyro.sample("obs", dist.Normal(linear_combination, sigma), obs=y)


# Generate some synthetic data
key = jax.random.PRNGKey(0)
n_samples = 100
x_data = jnp.linspace(0, 1, n_samples)
true_a = 1.5
true_b = 2.0
y_data = true_a + 2 * true_b + jax.random.normal(key, (n_samples,)) * 0.5

# Run MCMC
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(key, y=y_data)
samples = mcmc.get_samples()

# Test Predictive with exclude_deterministic=True
print("\nTest 1: Predictive with exclude_deterministic=True")
print("-" * 70)
predictive_exclude = Predictive(model, samples, exclude_deterministic=True)
predictions_exclude = predictive_exclude(key)
print("Keys with exclude_deterministic=True:", predictions_exclude.keys())
print("Expected: Should NOT include 'linear_combination' or 'squared'")

# Test Predictive with exclude_deterministic=False (default)
print("\nTest 2: Predictive with exclude_deterministic=False")
print("-" * 70)
predictive_include = Predictive(model, samples, exclude_deterministic=False)
predictions_include = predictive_include(key)
print("Keys with exclude_deterministic=False:", predictions_include.keys())
print("Expected: SHOULD include 'linear_combination' and 'squared'")
```

Expected Behavior
The `squared` and `linear_combination` sites should be present only in the output of `predictive_include(key)`, not in the output of `predictive_exclude(key)`.
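The expected behavior can be written down as a small key-set check. This is a minimal sketch, not part of NumPyro: `expected_behavior_violations` and `DETERMINISTIC_SITES` are hypothetical names, and the function only compares site names, so in the script above it could be called as `expected_behavior_violations(predictions_exclude.keys(), predictions_include.keys())`.

```python
# Deterministic sites declared in `model` in the reproduction script
# (hypothetical constant, copied by hand from the model definition).
DETERMINISTIC_SITES = frozenset({"linear_combination", "squared"})


def expected_behavior_violations(keys_exclude, keys_include, deterministic=DETERMINISTIC_SITES):
    """Return a list describing every way the returned keys deviate from the
    expected behavior; an empty list means the behavior matches."""
    problems = []
    # With exclude_deterministic=True, no deterministic site should be returned.
    leaked = deterministic & set(keys_exclude)
    if leaked:
        problems.append(f"exclude_deterministic=True still returned {sorted(leaked)}")
    # With exclude_deterministic=False, every deterministic site should be returned.
    missing = deterministic - set(keys_include)
    if missing:
        problems.append(f"exclude_deterministic=False omitted {sorted(missing)}")
    return problems


# Hypothetical key sets matching the reported symptom: both calls return the
# deterministic sites, so only the first check fails.
print(expected_behavior_violations(
    ["linear_combination", "squared"], ["linear_combination", "squared"]
))
```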
Notes
A fix would presumably be correct but breaking, unless I'm missing something. I'm happy to work on one, but there may be some context I'm missing.
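In the meantime, a user-side workaround that changes nothing in the library is to drop deterministic sites from the returned dictionary after the fact. This is a minimal sketch under stated assumptions: `drop_deterministic_sites` is a hypothetical helper, it assumes the site types come from a NumPyro execution trace (site name mapped to a dict with a `"type"` key), and the trace and predictions below are made-up stand-ins so the snippet runs without JAX.

```python
from typing import Any, Mapping


def drop_deterministic_sites(
    predictions: Mapping[str, Any],
    trace: Mapping[str, Mapping[str, Any]],
) -> dict:
    """Return a copy of `predictions` without deterministic sites.

    Assumption: `trace` has the shape of a NumPyro execution trace, e.g. as
    obtained with
        numpyro.handlers.trace(numpyro.handlers.seed(model, rng_seed=0)).get_trace()
    """
    # Collect the names of all sites recorded as deterministic in the trace.
    deterministic = {
        name for name, site in trace.items() if site.get("type") == "deterministic"
    }
    # Keep every returned site that is not deterministic.
    return {k: v for k, v in predictions.items() if k not in deterministic}


# Made-up stand-ins for a trace of `model` and a predictions dictionary.
fake_trace = {
    "a": {"type": "sample"},
    "b": {"type": "sample"},
    "linear_combination": {"type": "deterministic"},
    "squared": {"type": "deterministic"},
}
fake_predictions = {"linear_combination": [0.1], "squared": [0.2], "obs": [0.3]}
print(sorted(drop_deterministic_sites(fake_predictions, fake_trace)))  # prints ['obs']
```

Filtering on the site type recorded in the trace, rather than on a hard-coded list of names, keeps the workaround valid if deterministic sites are added to or renamed in the model.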