Closed
Labels: bug (Something isn't working)
Bug Description
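A mixture of Uniform distributions behaves differently from other mixtures. When the components are given as a single batched Uniform, the sampled value gets the component batch shape `(2,)` instead of being a scalar.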
Steps to Reproduce
Minimal Working Example with numpyro 0.19.0, jax 0.8.2
```python
from numpyro import sample, distributions as dist, render_model
from jax import random as jr, numpy as jnp


def model():
    mixture = dist.Mixture(
        dist.Categorical(probs=jnp.array([1 / 3, 2 / 3])),
        # Batched Uniform components: theta gets shape (2,)
        dist.Uniform(low=jnp.array([-1.0, 0.0]), high=jnp.array([0.0, 1.0])),
        # Swap in either alternative below instead; both give shape ():
        # dist.Normal(loc=jnp.array([-1., 1.]), scale=jnp.ones(2))
        # [dist.Uniform(low=-1., high=0.), dist.Uniform(low=0., high=1.)]
    )
    theta = sample("theta", mixture)
    print(theta.shape)
    return theta


render_model(model)
```

This example outputs
```
(2,)
(2,)
```

when the components are a batched Uniform, whereas it outputs the expected

```
()
()
```

when the components are a list of Uniform distributions or another batched distribution such as Normal.
Expected Behavior
Constructing a mixture of two Uniform distributions should not turn `theta` into a 2-dimensional variable: as with the other mixtures, `theta` should be a scalar (shape `()`).
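For reference, the expected behavior matches ordinary ancestral sampling from a mixture: pick a component index from the categorical weights, then draw from that component's interval, which yields one scalar per draw. The sketch below does this in plain JAX. `sample_uniform_mixture` is a hypothetical helper for illustration, not a NumPyro API, and it assumes 1-D Uniform components whose `probs`, `low` and `high` arrays all have shape `(K,)`.

```python
from jax import random, numpy as jnp


def sample_uniform_mixture(key, probs, low, high, sample_shape=()):
    """Hypothetical helper: ancestral sampling from a mixture of 1-D Uniforms.

    probs, low, high have shape (K,). Each draw is a scalar, so the result
    has shape `sample_shape`, with no trailing component axis.
    """
    key_idx, key_u = random.split(key)
    # One component index per draw, shape == sample_shape
    idx = random.categorical(key_idx, jnp.log(probs), shape=sample_shape)
    # u ~ Uniform[0, 1), rescaled to the interval of the chosen component
    u = random.uniform(key_u, shape=sample_shape)
    return low[idx] + u * (high[idx] - low[idx])


probs = jnp.array([1 / 3, 2 / 3])
low = jnp.array([-1.0, 0.0])
high = jnp.array([0.0, 1.0])

x = sample_uniform_mixture(random.PRNGKey(0), probs, low, high)
xs = sample_uniform_mixture(random.PRNGKey(1), probs, low, high, sample_shape=(1000,))
print(x.shape, xs.shape)
```

Indexing `low[idx]` and `high[idx]` with the sampled index is what removes the component axis. A correct `Mixture` sample has the same shape semantics.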
Many thanks.