Skip to content

Commit

Permalink
Fix for logL_birth (handley-lab#324)
Browse files Browse the repository at this point in the history
* Fix for logL_birth

* bump version to 2.1.5

* bump version to 2.1.6

* now avoiding dropna where possible and reducing code repetition

* version bump to 2.2.2

* Dropping all infs with warnings

* Corrected bump_version script

* Added a scatter kind

* fix mistakenly changed DOI in README

* check for inf directly instead of indirectly using isfinite, since nans are actually ok here

* add `pytest.warns` to tests to explicitly check for the existance of the newly added warnings and to keep our pytest output clean

* Update samples.py for pep8

* version bump to 2.5.2

* bump version to 2.7.4

* Removed obselete test for handley-lab#96

* Update README.rst version to 2.8.5

* Update _version.py to 2.8.5

* bump version to 2.8.6

* bump version to 2.8.7

* bump version to 2.8.8

* bump version to 2.8.9

---------

Co-authored-by: Lukas Hergt <[email protected]>
Co-authored-by: Lukas Hergt <[email protected]>
  • Loading branch information
3 people authored Apr 9, 2024
1 parent 16083c4 commit 79c7bb6
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 28 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
anesthetic: nested sampling post-processing
===========================================
:Authors: Will Handley and Lukas Hergt
:Version: 2.8.8
:Version: 2.8.9
:Homepage: https://github.com/handley-lab/anesthetic
:Documentation: http://anesthetic.readthedocs.io/

Expand Down
2 changes: 1 addition & 1 deletion anesthetic/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.8.8'
__version__ = '2.8.9'
32 changes: 24 additions & 8 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,11 @@ def plot_1d(self, axes=None, *args, **kwargs):
for x, ax in axes.items():
if x in self and kwargs['kind'] is not None:
xlabel = self.get_label(x)
self[x].plot(ax=ax, xlabel=xlabel, logx=x in logx,
*args, **kwargs)
if np.isinf(self[x]).any():
warnings.warn(f"column {x} has inf values.")
selfx = self[x].replace([-np.inf, np.inf], np.nan)
selfx.plot(ax=ax, xlabel=xlabel, logx=x in logx,
*args, **kwargs)
ax.set_xlabel(xlabel)
else:
ax.plot([], [])
Expand Down Expand Up @@ -239,6 +242,9 @@ def plot_2d(self, axes=None, *args, **kwargs):
- 'hist_1d': 1d histograms down the diagonal
- 'hist_2d': 2d histograms in lower triangle
- 'hist': 1d & 2d histograms in lower & diagonal
- 'scatter_2d': 2d scatter in lower triangle
- 'scatter': 1d histograms down diagonal
& 2d scatter in lower triangle
Feel free to add your own to this list!
Default:
Expand Down Expand Up @@ -337,16 +343,24 @@ def plot_2d(self, axes=None, *args, **kwargs):
if x in self and y in self and lkwargs['kind'] is not None:
xlabel = self.get_label(x)
ylabel = self.get_label(y)
if np.isinf(self[x]).any():
warnings.warn(f"column {x} has inf values.")
if x == y:
self[x].plot(ax=ax.twin, xlabel=xlabel,
logx=x in logx,
*args, **lkwargs)
selfx = self[x].replace([-np.inf, np.inf], np.nan)
selfx.plot(ax=ax.twin, xlabel=xlabel,
logx=x in logx,
*args, **lkwargs)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
else:
self.plot(x, y, ax=ax, xlabel=xlabel,
logx=x in logx, logy=y in logy,
ylabel=ylabel, *args, **lkwargs)
if np.isinf(self[x]).any():
warnings.warn(f"column {y} has inf values.")
selfxy = self[[x, y]]
selfxy = self.replace([-np.inf, np.inf], np.nan)
selfxy = selfxy.dropna(axis=0)
selfxy.plot(x, y, ax=ax, xlabel=xlabel,
logx=x in logx, logy=y in logy,
ylabel=ylabel, *args, **lkwargs)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
else:
Expand All @@ -370,6 +384,8 @@ def plot_2d(self, axes=None, *args, **kwargs):
'hist': {'diagonal': 'hist_1d', 'lower': 'hist_2d'},
'hist_1d': {'diagonal': 'hist_1d'},
'hist_2d': {'lower': 'hist_2d'},
'scatter': {'diagonal': 'hist_1d', 'lower': 'scatter_2d'},
'scatter_2d': {'lower': 'scatter_2d'},
}

def importance_sample(self, logL_new, action='add', inplace=False):
Expand Down
5 changes: 3 additions & 2 deletions bin/bump_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

current_version = run("cat", vfile)
current_version = current_version.split("=")[-1].strip().strip("'")
escaped_version = current_version.replace(".", "\.")
current_version = version.parse(current_version)

if len(sys.argv) > 1:
Expand All @@ -33,9 +34,9 @@

for f in [vfile, README]:
if sys.platform == "darwin": # macOS sed requires empty string for backup
run("sed", "-i", "", f"s/{current_version}/{new_version}/g", f)
run("sed", "-i", "", f"s/{escaped_version}/{new_version}/g", f)
else:
run("sed", "-i", f"s/{current_version}/{new_version}/g", f)
run("sed", "-i", f"s/{escaped_version}/{new_version}/g", f)

run("git", "add", vfile, README)
run("git", "commit", "-m", f"bump version to {new_version}")
25 changes: 9 additions & 16 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ def test_plot_2d_no_axes():
assert axes.iloc[-1, 1].get_xlabel() == 'x1'
assert axes.iloc[-1, 2].get_xlabel() == 'x2'

with pytest.warns(UserWarning):
axes = ns[['x0', 'logL_birth']].plot_2d()
axes = ns.drop_labels()[['x0', 'logL_birth']].plot_2d()


def test_plot_1d_no_axes():
np.random.seed(3)
Expand All @@ -431,6 +435,11 @@ def test_plot_1d_no_axes():
assert axes.iloc[1].get_xlabel() == 'x1'
assert axes.iloc[2].get_xlabel() == 'x2'

with pytest.warns(UserWarning):
axes = ns.plot_1d()
axes = ns[['x0', 'logL_birth']].plot_1d()
axes = ns.drop_labels()[['x0', 'logL_birth']].plot_1d()


@pytest.mark.parametrize('kind', ['kde', 'hist', skipif_no_fastkde('fastkde')])
def test_plot_logscale_1d(kind):
Expand Down Expand Up @@ -1213,22 +1222,6 @@ def test_hist_range_1d():
assert x2 >= +1


def test_contour_plot_2d_nan():
"""Contour plots with nans arising from issue #96"""
np.random.seed(3)
ns = read_chains('./tests/example_data/pc')

ns.loc[:9, ('x0', '$x_0$')] = np.nan
with pytest.raises((np.linalg.LinAlgError, RuntimeError, ValueError)):
ns.plot_2d(['x0', 'x1'])

# Check this error is removed in the case of zero weights
weights = ns.get_weights()
weights[:10] = 0
ns.set_weights(weights, inplace=True)
ns.plot_2d(['x0', 'x1'])


def test_compute_insertion():
np.random.seed(3)
ns = read_chains('./tests/example_data/pc')
Expand Down

0 comments on commit 79c7bb6

Please sign in to comment.