Skip to content

Commit

Permalink
Updated to include chain reading
Browse files Browse the repository at this point in the history
  • Loading branch information
williamjameshandley committed Mar 4, 2024
1 parent 0ec4a2c commit 0fec842
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 31 deletions.
2 changes: 1 addition & 1 deletion anesthetic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ def wrapper(backend=None):

read_hdf = anesthetic.read.hdf.read_hdf
read_chains = anesthetic.read.chain.read_chains
read_csv = anesthetic.samples.read_csv
read_csv = anesthetic.read.csv.read_csv
10 changes: 6 additions & 4 deletions anesthetic/read/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from anesthetic.read.multinest import read_multinest
from anesthetic.read.ultranest import read_ultranest
from anesthetic.read.nestedfit import read_nestedfit
from anesthetic.read.csv import read_csv


def read_chains(root, *args, **kwargs):
Expand All @@ -18,8 +19,8 @@ def read_chains(root, *args, **kwargs):
* `Nested_fit <https://github.com/martinit18/Nested_Fit>`_,
* `CosmoMC <https://github.com/cmbant/CosmoMC>`_,
* `Cobaya <https://github.com/CobayaSampler/cobaya>`_,
* or anything `GetDist <https://github.com/cmbant/getdist>`_
compatible.
* anything `GetDist <https://github.com/cmbant/getdist>`_ compatible,
* files produced using ``DataFrame.to_csv()`` from anesthetic.
Note that in order to optimally read chains from Cobaya you need to have
`GetDist <https://getdist.readthedocs.io/en/latest/>`__ installed.
Expand All @@ -40,6 +41,7 @@ def read_chains(root, *args, **kwargs):
"""
root = str(root)
# TODO: remove this in version >= 2.1
if 'burn_in' in kwargs:
raise KeyError(
"This is anesthetic 1.0 syntax. The `burn_in` keyword is no "
Expand All @@ -51,8 +53,8 @@ def read_chains(root, *args, **kwargs):
)
errors = []
readers = [
read_polychord, read_multinest, read_cobaya,
read_ultranest, read_nestedfit, read_getdist
read_polychord, read_multinest, read_cobaya, read_ultranest,
read_nestedfit, read_getdist, read_csv
]
for read in readers:
try:
Expand Down
15 changes: 15 additions & 0 deletions anesthetic/read/csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Read and write CSV files for anesthetic."""
from anesthetic.weighted_labelled_pandas import read_csv as wl_read_csv
from anesthetic.samples import MCMCSamples, NestedSamples
from pathlib import Path


def read_csv(filename, *args, **kwargs):
"""Read a CSV file into a :class:`Samples` object."""
filename = Path(filename)
kwargs['label'] = kwargs.get('label', filename.stem)
wldf = wl_read_csv(filename.with_suffix('.csv'))
if 'nlive' in wldf.columns:
return NestedSamples(wldf, *args, **kwargs)
else:
return MCMCSamples(wldf, *args, **kwargs)
12 changes: 1 addition & 11 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,12 @@
from anesthetic.utils import (compute_nlive, compute_insertion_indexes,
is_int, logsumexp)
from anesthetic.gui.plot import RunPlotter
from anesthetic.weighted_labelled_pandas import (WeightedLabelledDataFrame,
read_csv as wl_read_csv)
from anesthetic.weighted_labelled_pandas import WeightedLabelledDataFrame
from anesthetic.plot import (make_1d_axes, make_2d_axes,
AxesSeries, AxesDataFrame)
from anesthetic.utils import adjust_docstrings


def read_csv(filename, *args, **kwargs):
"""Read a CSV file into a :class:`Samples` object."""
wldf = wl_read_csv(filename, *args, **kwargs)
if 'nlive' in wldf.columns:
return NestedSamples(wldf)
else:
return MCMCSamples(wldf)


class Samples(WeightedLabelledDataFrame):
"""Storage and plotting tools for general samples.
Expand Down
23 changes: 23 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from anesthetic.read.ultranest import read_ultranest
from anesthetic.read.nestedfit import read_nestedfit
from anesthetic.read.hdf import HDFStore, read_hdf
from anesthetic.read.csv import read_csv
from utils import pytables_mark_xfail, h5py_mark_xfail, getdist_mark_skip


Expand Down Expand Up @@ -307,3 +308,25 @@ def test_hdf5(tmp_path, root):
def test_path(root):
base_dir = Path("./tests/example_data")
read_chains(base_dir / root)


@pytest.mark.parametrize('root', ['pc', 'gd'])
def test_read_csv(root):
samples = read_chains(f'./tests/example_data/{root}')
samples.to_csv(f'{root}.csv')

samples_ = read_csv(f'{root}.csv')
samples_.root = samples.root
assert_frame_equal(samples, samples_)

samples_ = read_csv(f'{root}')
samples_.root = samples.root
assert_frame_equal(samples, samples_)

samples_ = read_chains(f'{root}.csv')
samples_.root = samples.root
assert_frame_equal(samples, samples_)

samples_ = read_chains(f'{root}')
samples_.root = samples.root
assert_frame_equal(samples, samples_)
16 changes: 1 addition & 15 deletions tests/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
Samples, MCMCSamples, NestedSamples, make_1d_axes, make_2d_axes,
read_chains
)
from anesthetic.samples import (merge_nested_samples, merge_samples_weighted,
read_csv)
from anesthetic.samples import merge_nested_samples, merge_samples_weighted
from anesthetic.weighted_labelled_pandas import (WeightedLabelledSeries,
WeightedLabelledDataFrame)
from numpy.testing import (assert_array_equal, assert_array_almost_equal,
Expand Down Expand Up @@ -1981,16 +1980,3 @@ def test_axes_limits_2d(kind, kwargs):
assert 3 < xmax < 3.9
assert -3.9 < ymin < -3
assert 3 < ymax < 3.9


def test_read_csv():
np.random.seed(3)
pc = read_chains('./tests/example_data/pc')
pc.to_csv('pc.csv')
pc_ = read_csv('pc.csv')
assert_frame_equal(pc, pc_)

mcmc = read_chains('./tests/example_data/gd')
mcmc.to_csv('gd.csv')
mcmc_ = read_csv('gd.csv')
assert_frame_equal(mcmc, mcmc_)

0 comments on commit 0fec842

Please sign in to comment.