Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zaxtax committed Jan 19, 2025
1 parent e33e517 commit e895a5c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ def sample_posterior_predictive(
else:
vars_ = model.observed_RVs + observed_dependent_deterministics(model)
if observed_data is not None:
vars_ += [model[x] for x in observed_data if x in model]
vars_ += [model[x] for x in observed_data if x in model and x not in vars_]

vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))

Expand Down
55 changes: 45 additions & 10 deletions tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,16 +481,6 @@ def test_normal_scalar(self):
chains=nchains,
)

# test that trace is used in ppc
with pm.Model() as model_ppc:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1)

ppc = pm.sample_posterior_predictive(
trace=trace, model=model_ppc, return_inferencedata=False
)
assert "a" in ppc

with model:
# test list input
ppc0 = pm.sample_posterior_predictive(
Expand Down Expand Up @@ -550,6 +540,51 @@ def test_normal_scalar_idata(self):
ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False)
assert ppc["a"].shape == (nchains, ndraws)

def test_external_trace(self):
nchains = 2
ndraws = 500
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
trace = pm.sample(
draws=ndraws,
chains=nchains,
)

# test that trace is used in ppc
with pm.Model() as model_ppc:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1)

ppc = pm.sample_posterior_predictive(
trace=trace, model=model_ppc, return_inferencedata=False
)
assert list(ppc.keys()) == ["a"]

@pytest.mark.xfail(reason="Auto-imputation of variables not supported in this setting")
def test_external_trace_det(self):
nchains = 2
ndraws = 500
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
b = pm.Deterministic("b", a + 1)
trace = pm.sample(
draws=ndraws,
chains=nchains,
)

# test that trace is used in ppc
with pm.Model() as model_ppc:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1)
b = pm.Deterministic("b", a + 1)

ppc = pm.sample_posterior_predictive(
trace=trace, model=model_ppc, return_inferencedata=False
)
assert list(ppc.keys()) == ["a", "b"]

def test_normal_vector(self):
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0)
Expand Down

0 comments on commit e895a5c

Please sign in to comment.