Skip to content

Commit

Permalink
Add logjac to logdensity_fn (#751)
Browse files Browse the repository at this point in the history
* Add logjac to logdensity_fn

* Refactor logprior_fn in SMCLinearRegressionTestCase
  • Loading branch information
junpenglao authored Oct 30, 2024
1 parent b107f9f commit 65ae00e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
3 changes: 2 additions & 1 deletion docs/examples/howto_sample_multiple_chains.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ observed = np.random.normal(loc, scale, size=1_000)
def logdensity_fn(loc, log_scale, observed=observed):
"""Univariate Normal"""
scale = jnp.exp(log_scale)
logjac = log_scale
logpdf = stats.norm.logpdf(observed, loc, scale)
return jnp.sum(logpdf)
return logjac + jnp.sum(logpdf)
def logdensity(x):
Expand Down
3 changes: 2 additions & 1 deletion docs/examples/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ observed = np.random.normal(loc, scale, size=1_000)
def logdensity_fn(loc, log_scale, observed=observed):
"""Univariate Normal"""
scale = jnp.exp(log_scale)
logjac = log_scale
logpdf = stats.norm.logpdf(observed, loc, scale)
return jnp.sum(logpdf)
return logjac + jnp.sum(logpdf)
logdensity = lambda x: logdensity_fn(**x)
Expand Down
11 changes: 5 additions & 6 deletions tests/smc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def logdensity_fn(self, log_scale, coefs, preds, x):
logpdf = self.logdensity_by_observation(log_scale, coefs, preds, x)
return jnp.sum(logpdf)

def logprior_fn(self, log_scale, coefs):
return log_scale + stats.norm.logpdf(log_scale) + stats.norm.logpdf(coefs)

def observations(self):
num_particles = 100

Expand All @@ -27,9 +30,7 @@ def observations(self):
def particles_prior_loglikelihood(self):
observations, num_particles = self.observations()

logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf(
x["coefs"]
)
logprior_fn = lambda x: self.logprior_fn(**x)
loglikelihood_fn = lambda x: self.logdensity_fn(**x, **observations)

log_scale_init = np.random.randn(num_particles)
Expand All @@ -45,9 +46,7 @@ def partial_posterior_test_case(self):
y_data = 3 * x_data + np.random.normal(size=x_data.shape)
observations = {"x": x_data, "preds": y_data}

logprior_fn = lambda x: stats.norm.logpdf(x["log_scale"]) + stats.norm.logpdf(
x["coefs"]
)
logprior_fn = lambda x: self.logprior_fn(**x)

log_scale_init = np.random.randn(num_particles)
coeffs_init = np.random.randn(num_particles)
Expand Down

0 comments on commit 65ae00e

Please sign in to comment.