Skip to content

Commit

Permalink
use inbuilts
Browse files Browse the repository at this point in the history
  • Loading branch information
yallup committed Mar 6, 2024
1 parent e414992 commit a34b3d6
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 34 deletions.
73 changes: 47 additions & 26 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,42 +1192,63 @@ def truncate(self, logL=None):
index = np.concatenate([dead_points.index, live_points.index])
return self.loc[index].recompute()

def terminate(self, eps=1e-3, logL=None, n=None):
"""Check if a set of samples has reached a termination criterion.
Uses the termination criterion of
[Handley et al. 2015](https://arxiv.org/abs/1506.00171).
computes if the ratio of evidence in the live to dead points is less
than some precision.
def critical_ratio(
self, nsamples=None, logL=None, beta=1.0, criteria="logZ"
):
"""Compute a critical ratio between the live and dead points.
Parameters
----------
eps : float, optional
The precision of the criteria.
default: 1e-3
nsamples : int, optional
Number of samples to draw when computing the volume estimates
default: None
logL : float or int, optional
Loglikelihood or iteration number to truncate run.
If not provided, truncate at the last set of dead points.
default: None
n : int, optional
Number of samples to draw when computing the volume estimates
default: None
beta : float, optional
Inverse temperature to set.
default: 1.0
criteria : str, optional
Criteria to compute the critical ratio. Can be one of
{'logZ', 'D_KL'}.
default: 'logZ'
"""
available_criteria = {
"logZ": NestedSamples.logZ,
"D_KL": NestedSamples.D_KL,
}
if criteria not in available_criteria.keys():
raise KeyError(
f"Criteria must be one of {list(available_criteria.keys())}"
)
else:
criteria = available_criteria[criteria]
logL = self.contour(logL)
i_live = ((self.logL >= logL) & (self.logL_birth < logL)).to_numpy()
i_dead = ((self.logL < logL)).to_numpy()
if np.any(i_dead):
logZ_dead = self[i_dead].recompute().logZ(n).mean()
logX_dead = self[i_dead].recompute().logX(n).iloc[-1]
i_live = self.live_points(logL).index
i_dead = self.dead_points(logL).index
if len(i_dead) > 0:
logX_dead = self.iloc[i_dead].recompute().logX(nsamples).iloc[-1]
else:
# logZ if no dead points
logZ_dead = -1e30
logX_dead = -1e30
logZ_live = self[i_live].recompute().logZ(n).mean() + logX_dead
return logZ_live - np.logaddexp(logZ_live, logZ_dead) < np.log(eps)
logX_dead = 0.0
criteria_dead = criteria(self.set_beta(beta), nsamples)
criteria_live = (
criteria(self.iloc[i_live].set_beta(beta), nsamples) + logX_dead
)
return criteria_live - np.logaddexp(criteria_live, criteria_dead)

def is_terminated(self, eps=1e-3, **kwargs):
"""Check if a simulated run has terminated. Computes a critical ratio.
Parameters
----------
eps : float, optional
The precision of the criteria.
default: 1e-3
kwargs : dict, optional (see NestedSamples.critical_ratio)
"""
crit = self.critical_ratio(**kwargs).mean()
return crit < np.log(eps)

def posterior_points(self, beta=1):
"""Get equally weighted posterior points at temperature beta."""
Expand Down
44 changes: 36 additions & 8 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,19 +1180,47 @@ def test_truncate(cut):
assert_array_equal(pc, truncated_run)


def test_terminate():
@pytest.mark.parametrize("crit", ["logZ", "D_KL"])
def test_critical_ratio(crit):
pc = read_chains("./tests/example_data/pc")
with pytest.raises(KeyError):
pc.critical_ratio(criteria="badarg")

default_crit = pc.critical_ratio(criteria=crit)
logl_crit = pc.critical_ratio(logL=-1.0, criteria=crit)
beta_crit = pc.critical_ratio(beta=0.5, criteria=crit)

assert isinstance(default_crit, float)
assert default_crit != logl_crit
assert default_crit != beta_crit

default_crit_ens = pc.critical_ratio(nsamples=100, criteria=crit)
logl_crit_ens = pc.critical_ratio(nsamples=100, logL=-1.0, criteria=crit)
beta_crit_ens = pc.critical_ratio(nsamples=100, beta=0.5, criteria=crit)

assert isinstance(default_crit_ens, WeightedSeries)
assert np.isclose(
default_crit_ens.mean(), default_crit, atol=2 * default_crit_ens.std()
)
assert (default_crit_ens != logl_crit_ens).all()
assert (default_crit_ens != beta_crit_ens).all()


@pytest.mark.parametrize("crit", ["logZ", "D_KL"])
def test_is_terminated(crit):
np.random.seed(4)
pc = read_chains("./tests/example_data/pc")

assert not pc.terminate(logL=0)
# assert not pc.is_terminated(logL=0,criteria = crit)
# assert not pc.is_terminated(logL=1,criteria = crit)

assert not pc.terminate(logL=200)
assert not pc.terminate(logL=0.0)
assert pc.terminate(logL=None)
assert not pc.is_terminated(logL=200, criteria=crit)
assert not pc.is_terminated(logL=0.0, criteria=crit)
assert pc.is_terminated(logL=None, criteria=crit)

assert pc.terminate(logL=200, eps=1.0)
assert pc.terminate(logL=0.0, eps=1.0)
assert pc.terminate(logL=None, eps=1.0)
assert pc.is_terminated(logL=200, eps=1.0, criteria=crit)
assert pc.is_terminated(logL=0.0, eps=1.0, criteria=crit)
assert pc.is_terminated(logL=None, eps=1.0, criteria=crit)


def test_hist_range_1d():
Expand Down

0 comments on commit a34b3d6

Please sign in to comment.