Skip to content

Commit

Permalink
Merge branch 'master' into ndes
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamOrmondroyd authored Dec 12, 2023
2 parents 9a7c3e9 + 7c30e72 commit ecdd875
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
anesthetic: nested sampling post-processing
===========================================
:Authors: Will Handley and Lukas Hergt
:Version: 2.5.1
:Version: 2.6.0
:Homepage: https://github.com/handley-lab/anesthetic
:Documentation: http://anesthetic.readthedocs.io/

Expand Down
2 changes: 1 addition & 1 deletion anesthetic/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.5.1'
__version__ = '2.6.0'
10 changes: 6 additions & 4 deletions anesthetic/kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def fastkde_1d(d, xmin=None, xmax=None):
d_ = mirror_1d(d, xmin, xmax)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
p, x = fastKDE.pdf(d_, axisExpansionFactor=f,
numPointsPerSigma=10*(2-f))
p, x = fastKDE.pdf(d_, axis_expansion_factor=f,
num_points_per_sigma=10*(2-f),
use_xarray=False)
p *= 2-f

if xmin is not None:
Expand Down Expand Up @@ -79,8 +80,9 @@ def fastkde_2d(d_x, d_y, xmin=None, xmax=None, ymin=None, ymax=None):

with warnings.catch_warnings():
warnings.simplefilter("ignore")
p, (x, y) = fastKDE.pdf(d_x_, d_y_, axisExpansionFactor=f,
numPointsPerSigma=10*(2-f[0])*(2-f[1]))
p, (x, y) = fastKDE.pdf(d_x_, d_y_, axis_expansion_factor=f,
num_points_per_sigma=10*(2-f[0])*(2-f[1]),
use_xarray=False)

p *= (2-f[0])
p *= (2-f[1])
Expand Down
87 changes: 79 additions & 8 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,8 +1171,38 @@ def logL_P(self, nsamples=None, beta=None):

logL_P.__doc__ += _logZ_function_shape

def contour(self, logL=None):
"""Convert contour from (index or None) to a float loglikelihood.
Convention is that live points are inclusive of the contour.
Helper function for:
- NestedSamples.live_points,
- NestedSamples.dead_points,
- NestedSamples.truncate.
Parameters
----------
logL : float or int, optional
Loglikelihood or iteration number
If not provided, return the contour containing the last set of
live points.
Returns
-------
logL : float
Loglikelihood of contour
"""
if logL is None:
logL = self.loc[self.logL > self.logL_birth.max()].logL.iloc[0]
elif isinstance(logL, float):
pass
else:
logL = float(self.logL[logL])
return logL

def live_points(self, logL=None):
"""Get the live points within logL.
"""Get the live points within a contour.
Parameters
----------
Expand All @@ -1188,16 +1218,57 @@ def live_points(self, logL=None):
- ith iteration (if input is integer)
- last set of live points if no argument provided
"""
if logL is None:
logL = self.logL_birth.max()
else:
try:
logL = float(self.logL[logL])
except KeyError:
pass
logL = self.contour(logL)
i = ((self.logL >= logL) & (self.logL_birth < logL)).to_numpy()
return Samples(self[i]).set_weights(None)

def dead_points(self, logL=None):
"""Get the dead points at a given contour.
Convention is that dead points are exclusive of the contour.
Parameters
----------
logL : float or int, optional
Loglikelihood or iteration number to return dead points.
If not provided, return the last set of dead points.
Returns
-------
dead_points : Samples
Dead points at either:
- contour logL (if input is float)
- ith iteration (if input is integer)
- last set of dead points if no argument provided
"""
logL = self.contour(logL)
i = ((self.logL < logL)).to_numpy()
return Samples(self[i]).set_weights(None)

def truncate(self, logL=None):
"""Truncate the run at a given contour.
Returns the union of the live_points and dead_points.
Parameters
----------
logL : float or int, optional
Loglikelihood or iteration number to truncate run.
If not provided, truncate at the last set of dead points.
Returns
-------
truncated_run : NestedSamples
Run truncated at either:
- contour logL (if input is float)
- ith iteration (if input is integer)
- last set of dead points if no argument provided
"""
dead_points = self.dead_points(logL)
live_points = self.live_points(logL)
index = np.concatenate([dead_points.index, live_points.index])
return self.loc[index].recompute()

def posterior_points(self, beta=1):
"""Get equally weighted posterior points at temperature beta."""
return self.set_beta(beta).compress('equal')
Expand Down
62 changes: 62 additions & 0 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,68 @@ def test_live_points():
assert not live_points.isweighted()


def test_dead_points():
np.random.seed(4)
pc = read_chains("./tests/example_data/pc")

for i, logL in pc.logL.iloc[::49].items():
dead_points = pc.dead_points(logL)
assert len(dead_points) == int(len(pc[:i[0]]))

dead_points_from_int = pc.dead_points(i[0])
assert_array_equal(dead_points_from_int, dead_points)

dead_points_from_index = pc.dead_points(i)
assert_array_equal(dead_points_from_index, dead_points)

assert pc.dead_points(1).index[0] == 0

last_dead_points = pc.dead_points()
logL = pc.logL_birth.max()
assert (last_dead_points.logL <= logL).all()
assert len(last_dead_points) == len(pc) - pc.nlive.mode().to_numpy()[0]
assert not dead_points.isweighted()


def test_contour():
np.random.seed(4)
pc = read_chains("./tests/example_data/pc")

cut_float = 30.0
assert cut_float == pc.contour(cut_float)

cut_int = 0
assert pc.logL.min() == pc.contour(cut_int)

cut_none = None
nlive = pc.nlive.mode().to_numpy()[0]
assert sorted(pc.logL)[-nlive] == pc.contour(cut_none)


@pytest.mark.parametrize("cut", [200, 0.0, None])
def test_truncate(cut):
np.random.seed(4)
pc = read_chains("./tests/example_data/pc")
truncated_run = pc.truncate(cut)
assert not truncated_run.index.duplicated().any()
if cut is None:
assert_array_equal(pc, truncated_run)


def test_hist_range_1d():
"""Test to provide a solution to #89"""
np.random.seed(3)
ns = read_chains('./tests/example_data/pc')
ax = ns.plot_1d('x0', kind='hist_1d')
x1, x2 = ax['x0'].get_xlim()
assert x1 > -1
assert x2 < +1
ax = ns.plot_1d('x0', kind='hist_1d', bins=np.linspace(-1, 1, 11))
x1, x2 = ax['x0'].get_xlim()
assert x1 <= -1
assert x2 >= +1


def test_contour_plot_2d_nan():
"""Contour plots with nans arising from issue #96"""
np.random.seed(3)
Expand Down

0 comments on commit ecdd875

Please sign in to comment.