
Predictive.exclude_deterministic does not filter deterministic sites #2086

@brendancooley


Bug Description

The exclude_deterministic argument of the Predictive helper class does not exclude deterministic sites from the resulting posterior predictive dictionary.

Steps to Reproduce

uv run test_exclude_deterministic.py

where test_exclude_deterministic.py contains:

```python
#!/usr/bin/env -S uv run
# /// script
# requires-python = ">=3.9"
# dependencies = [
#     "jax",
#     "jaxlib",
#     "numpyro",
# ]
# ///
"""
Test to verify whether exclude_deterministic works in numpyro.infer.util.Predictive
"""

import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive


def model(x=None, y=None):
    """Simple model with both stochastic and deterministic sites"""
    # Stochastic sites
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))

    # Deterministic site
    linear_combination = numpyro.deterministic("linear_combination", a + 2 * b)

    # Another deterministic site
    _ = numpyro.deterministic("squared", a**2)

    # Likelihood (only reached when observations are passed in)
    if y is not None:
        sigma = numpyro.sample("sigma", dist.Exponential(1.0))
        numpyro.sample("obs", dist.Normal(linear_combination, sigma), obs=y)


# Use independent PRNG keys for data generation, MCMC and prediction
data_key, mcmc_key, pred_key = jax.random.split(jax.random.PRNGKey(0), 3)

# Generate some synthetic data
n_samples = 100
true_a = 1.5
true_b = 2.0
y_data = true_a + 2 * true_b + jax.random.normal(data_key, (n_samples,)) * 0.5

# Run MCMC (posterior samples also contain the deterministic sites)
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(mcmc_key, y=y_data)
samples = mcmc.get_samples()

# Test Predictive with exclude_deterministic=True
print("\nTest 1: Predictive with exclude_deterministic=True")
print("-" * 70)
predictive_exclude = Predictive(model, samples, exclude_deterministic=True)
predictions_exclude = predictive_exclude(pred_key)
print("Keys with exclude_deterministic=True:", sorted(predictions_exclude.keys()))
print("Expected: Should NOT include 'linear_combination' or 'squared'")

# Test Predictive with exclude_deterministic=False
print("\nTest 2: Predictive with exclude_deterministic=False")
print("-" * 70)
predictive_include = Predictive(model, samples, exclude_deterministic=False)
predictions_include = predictive_include(pred_key)
print("Keys with exclude_deterministic=False:", sorted(predictions_include.keys()))
print("Expected: SHOULD include 'linear_combination' and 'squared'")
```

Expected Behavior

The "squared" and "linear_combination" sites should be present only in predictions_include (the output of predictive_include), and absent from predictions_exclude.
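The two printed expectations in the script can be turned into a mechanical check. A minimal sketch; check_expected_keys is a hypothetical helper (not part of numpyro) that compares the key sets of the two Predictive outputs against the model's deterministic site names and reports any mismatch:

```python
def check_expected_keys(keys_exclude, keys_include, deterministic_names):
    """Return a list of violations of the expected exclude_deterministic behavior.

    keys_exclude: keys returned with exclude_deterministic=True
    keys_include: keys returned with exclude_deterministic=False
    deterministic_names: names of the model's numpyro.deterministic sites
    (hypothetical helper, not numpyro API)
    """
    problems = []
    det = set(deterministic_names)

    # With exclusion on, no deterministic site should appear in the output.
    leaked = set(keys_exclude) & det
    if leaked:
        problems.append(f"deterministic sites present despite exclusion: {sorted(leaked)}")

    # With exclusion off, every deterministic site should appear.
    missing = det - set(keys_include)
    if missing:
        problems.append(f"deterministic sites missing when included: {sorted(missing)}")

    return problems
```

In the script this would be called as check_expected_keys(predictions_exclude.keys(), predictions_include.keys(), {"linear_combination", "squared"}); an empty list means the outputs match the expected behavior.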

Notes

Fixing this seems correct but would be a breaking change, unless I'm missing something. I'm happy to work on a fix, but perhaps there is some context I'm not aware of.


Labels: bug
