diff --git a/docs/examples/howto_sample_multiple_chains.md b/docs/examples/howto_sample_multiple_chains.md index a5b6566f8..c2947e29f 100644 --- a/docs/examples/howto_sample_multiple_chains.md +++ b/docs/examples/howto_sample_multiple_chains.md @@ -57,8 +57,9 @@ observed = np.random.normal(loc, scale, size=1_000) def logdensity_fn(loc, log_scale, observed=observed): """Univariate Normal""" scale = jnp.exp(log_scale) + logjac = log_scale logpdf = stats.norm.logpdf(observed, loc, scale) - return jnp.sum(logpdf) + return logjac + jnp.sum(logpdf) def logdensity(x): diff --git a/docs/examples/quickstart.md b/docs/examples/quickstart.md index 870e5df9a..a290bfdad 100644 --- a/docs/examples/quickstart.md +++ b/docs/examples/quickstart.md @@ -48,8 +48,9 @@ observed = np.random.normal(loc, scale, size=1_000) def logdensity_fn(loc, log_scale, observed=observed): """Univariate Normal""" scale = jnp.exp(log_scale) + logjac = log_scale logpdf = stats.norm.logpdf(observed, loc, scale) - return jnp.sum(logpdf) + return logjac + jnp.sum(logpdf) logdensity = lambda x: logdensity_fn(**x) diff --git a/tests/smc/__init__.py b/tests/smc/__init__.py index 006d7ba38..8f5a39d6f 100644 --- a/tests/smc/__init__.py +++ b/tests/smc/__init__.py @@ -16,6 +16,9 @@ def logdensity_fn(self, log_scale, coefs, preds, x): logpdf = self.logdensity_by_observation(log_scale, coefs, preds, x) return jnp.sum(logpdf) + def logprior_fn(self, log_scale, coefs): + return log_scale + stats.norm.logpdf(log_scale) + stats.norm.logpdf(coefs) + def observations(self): num_particles = 100 @@ -27,9 +30,7 @@ def observations(self): def particles_prior_loglikelihood(self): observations, num_particles = self.observations() - logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf( - x["coefs"] - ) + logprior_fn = lambda x: self.logprior_fn(**x) loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations) log_scale_init = np.random.randn(num_particles) @@ -45,9 +46,7 @@ def partial_posterior_test_case(self): y_data = 3 * x_data + np.random.normal(size=x_data.shape) observations = {"x": x_data, "preds": y_data} - logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf( - x["coefs"] - ) + logprior_fn = lambda x: self.logprior_fn(**x) log_scale_init = np.random.randn(num_particles) coeffs_init = np.random.randn(num_particles)