From 0fec842003dd331a148c6273e8cf01bb78c2ff3b Mon Sep 17 00:00:00 2001 From: Will Handley Date: Mon, 4 Mar 2024 17:15:32 +0000 Subject: [PATCH] Updated to include chain reading --- anesthetic/__init__.py | 2 +- anesthetic/read/chain.py | 10 ++++++---- anesthetic/read/csv.py | 15 +++++++++++++++ anesthetic/samples.py | 12 +----------- tests/test_reader.py | 23 +++++++++++++++++++++++ tests/test_samples.py | 16 +--------------- 6 files changed, 47 insertions(+), 31 deletions(-) create mode 100644 anesthetic/read/csv.py diff --git a/anesthetic/__init__.py b/anesthetic/__init__.py index c2ea0d63..9f4747cb 100644 --- a/anesthetic/__init__.py +++ b/anesthetic/__init__.py @@ -49,4 +49,4 @@ def wrapper(backend=None): read_hdf = anesthetic.read.hdf.read_hdf read_chains = anesthetic.read.chain.read_chains -read_csv = anesthetic.samples.read_csv +read_csv = anesthetic.read.csv.read_csv diff --git a/anesthetic/read/chain.py b/anesthetic/read/chain.py index 3b28a743..5dc55b5f 100644 --- a/anesthetic/read/chain.py +++ b/anesthetic/read/chain.py @@ -5,6 +5,7 @@ from anesthetic.read.multinest import read_multinest from anesthetic.read.ultranest import read_ultranest from anesthetic.read.nestedfit import read_nestedfit +from anesthetic.read.csv import read_csv def read_chains(root, *args, **kwargs): @@ -18,8 +19,8 @@ def read_chains(root, *args, **kwargs): * `Nested_fit `_, * `CosmoMC `_, * `Cobaya `_, - * or anything `GetDist `_ - compatible. + * anything `GetDist `_ compatible, + * files produced using ``DataFrame.to_csv()`` from anesthetic. Note that in order to optimally read chains from Cobaya you need to have `GetDist `__ installed. @@ -40,6 +41,7 @@ def read_chains(root, *args, **kwargs): """ root = str(root) + # TODO: remove this in version >= 2.1 if 'burn_in' in kwargs: raise KeyError( "This is anesthetic 1.0 syntax. The `burn_in` keyword is no " @@ -51,8 +53,8 @@ def read_chains(root, *args, **kwargs): ) errors = [] readers = [ - read_polychord, read_multinest, read_cobaya, - read_ultranest, read_nestedfit, read_getdist + read_polychord, read_multinest, read_cobaya, read_ultranest, + read_nestedfit, read_getdist, read_csv ] for read in readers: try: diff --git a/anesthetic/read/csv.py b/anesthetic/read/csv.py new file mode 100644 index 00000000..3dd2713e --- /dev/null +++ b/anesthetic/read/csv.py @@ -0,0 +1,15 @@ +"""Read and write CSV files for anesthetic.""" +from anesthetic.weighted_labelled_pandas import read_csv as wl_read_csv +from anesthetic.samples import MCMCSamples, NestedSamples +from pathlib import Path + + +def read_csv(filename, *args, **kwargs): + """Read a CSV file into a :class:`Samples` object.""" + filename = Path(filename) + kwargs['label'] = kwargs.get('label', filename.stem) + wldf = wl_read_csv(filename.with_suffix('.csv')) + if 'nlive' in wldf.columns: + return NestedSamples(wldf, *args, **kwargs) + else: + return MCMCSamples(wldf, *args, **kwargs) diff --git a/anesthetic/samples.py b/anesthetic/samples.py index 0e018efc..182a75ff 100644 --- a/anesthetic/samples.py +++ b/anesthetic/samples.py @@ -14,22 +14,12 @@ from anesthetic.utils import (compute_nlive, compute_insertion_indexes, is_int, logsumexp) from anesthetic.gui.plot import RunPlotter -from anesthetic.weighted_labelled_pandas import (WeightedLabelledDataFrame, - read_csv as wl_read_csv) +from anesthetic.weighted_labelled_pandas import WeightedLabelledDataFrame from anesthetic.plot import (make_1d_axes, make_2d_axes, AxesSeries, AxesDataFrame) from anesthetic.utils import adjust_docstrings -def read_csv(filename, *args, **kwargs): - """Read a CSV file into a :class:`Samples` object.""" - wldf = wl_read_csv(filename, *args, **kwargs) - if 'nlive' in wldf.columns: - return NestedSamples(wldf) - else: - return MCMCSamples(wldf) - - class Samples(WeightedLabelledDataFrame): """Storage and plotting tools for general samples. diff --git a/tests/test_reader.py b/tests/test_reader.py index 134b2cac..e711dad1 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -15,6 +15,7 @@ from anesthetic.read.ultranest import read_ultranest from anesthetic.read.nestedfit import read_nestedfit from anesthetic.read.hdf import HDFStore, read_hdf +from anesthetic.read.csv import read_csv from utils import pytables_mark_xfail, h5py_mark_xfail, getdist_mark_skip @@ -307,3 +308,25 @@ def test_hdf5(tmp_path, root): def test_path(root): base_dir = Path("./tests/example_data") read_chains(base_dir / root) + + +@pytest.mark.parametrize('root', ['pc', 'gd']) +def test_read_csv(root): + samples = read_chains(f'./tests/example_data/{root}') + samples.to_csv(f'{root}.csv') + + samples_ = read_csv(f'{root}.csv') + samples_.root = samples.root + assert_frame_equal(samples, samples_) + + samples_ = read_csv(f'{root}') + samples_.root = samples.root + assert_frame_equal(samples, samples_) + + samples_ = read_chains(f'{root}.csv') + samples_.root = samples.root + assert_frame_equal(samples, samples_) + + samples_ = read_chains(f'{root}') + samples_.root = samples.root + assert_frame_equal(samples, samples_) diff --git a/tests/test_samples.py b/tests/test_samples.py index 1b4b1280..a40b85e5 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -13,8 +13,7 @@ Samples, MCMCSamples, NestedSamples, make_1d_axes, make_2d_axes, read_chains ) -from anesthetic.samples import (merge_nested_samples, merge_samples_weighted, - read_csv) +from anesthetic.samples import merge_nested_samples, merge_samples_weighted from anesthetic.weighted_labelled_pandas import (WeightedLabelledSeries, WeightedLabelledDataFrame) from numpy.testing import (assert_array_equal, assert_array_almost_equal, @@ -1981,16 +1980,3 @@ def test_axes_limits_2d(kind, kwargs): assert 3 < xmax < 3.9 assert -3.9 < ymin < -3 assert 3 < ymax < 3.9 - - -def test_read_csv(): - np.random.seed(3) - pc = read_chains('./tests/example_data/pc') - pc.to_csv('pc.csv') - pc_ = read_csv('pc.csv') - assert_frame_equal(pc, pc_) - - mcmc = read_chains('./tests/example_data/gd') - mcmc.to_csv('gd.csv') - mcmc_ = read_csv('gd.csv') - assert_frame_equal(mcmc, mcmc_)