Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add reader for dnest4 #391

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
anesthetic: nested sampling post-processing
===========================================
:Authors: Will Handley and Lukas Hergt
:Version: 2.8.14
:Version: 2.9.0
:Homepage: https://github.com/handley-lab/anesthetic
:Documentation: http://anesthetic.readthedocs.io/

Expand Down
1 change: 1 addition & 0 deletions anesthetic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def wrapper(backend=None):
Samples = anesthetic.samples.Samples
MCMCSamples = anesthetic.samples.MCMCSamples
NestedSamples = anesthetic.samples.NestedSamples
DiffusiveNestedSamples = anesthetic.samples.DiffusiveNestedSamples
make_2d_axes = anesthetic.plot.make_2d_axes
make_1d_axes = anesthetic.plot.make_1d_axes

Expand Down
2 changes: 1 addition & 1 deletion anesthetic/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.8.14'
__version__ = '2.9.0'
34 changes: 21 additions & 13 deletions anesthetic/gui/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,14 @@ class RunPlotter(object):
param_choice : :class:`anesthetic.gui.widgets.CheckButtons`
Checkbox that selects which parameters to plot.

plot_color : str
color for plots and color maps

"""

def __init__(self, samples, params=None):
def __init__(self, samples, params=None, color='k'):
self.samples = samples
self.color = color

if params:
self.params = np.array(params)
Expand Down Expand Up @@ -199,7 +203,8 @@ def _set_up(self):
gs0 = sGS(1, 2, width_ratios=[19, 1], subplot_spec=gs[0])
gs1 = sGS(1, 3, width_ratios=[4, 1, 1], subplot_spec=gs[1])
gs10 = sGS(2, 1, height_ratios=[1, 4], subplot_spec=gs1[0])
gs11 = sGS(3, 1, height_ratios=[1, 1, 2], subplot_spec=gs1[1])
gs11 = sGS(3, 1, height_ratios=[1, 1, len(self.samples.plot_types())],
subplot_spec=gs1[1])

self.triangle = TrianglePlot(self.fig, gs0[0])
beta = np.logspace(-10, 10, 101)
Expand All @@ -213,14 +218,15 @@ def _set_up(self):
self.reload = Button(self.fig, gs11[1],
self.reload_file, 'Reload File')
self.type = RadioButtons(self.fig, gs11[2],
('live', 'posterior'), self.update)
self.samples.plot_types(), self.update)
self.param_choice = CheckButtons(self.fig, gs1[2],
self.params, self.redraw)

def redraw(self, _):
"""Redraw the triangle plot upon parameter updating."""
self.triangle.draw(self.param_choice(),
self.samples.get_labels_map())
self.samples.get_labels_map(),
self.color)
self.update(None)
self.reset_range(None)
self.fig.tight_layout()
Expand All @@ -238,25 +244,27 @@ def points(self, label):
-------
array-like:
sample 'label'-coordinates.
array-like:
colors to use for plotting

"""
if self.type() == 'posterior':
beta = self.beta()
return self.samples.posterior_points(beta)[label]
else:
i = self.evolution()
logL = self.samples.logL.iloc[i]
return self.samples.live_points(logL)[label]
return self.samples.points_to_plot(
plot_type=self.type(),
label=label,
evolution=self.evolution(),
beta=self.beta(),
base_color=self.color
)

def update(self, _):
"""Update all the plots upon slider changes."""
logX = np.log(self.samples.nlive / (self.samples.nlive+1)).cumsum()
beta = self.beta()
LX = self.samples.logL*beta + logX
LX = self.samples.LX(beta, logX)
LX = np.exp(LX-LX.max())
i = self.evolution()
logL = self.samples.logL.iloc[i]
n = self.samples.nlive.iloc[i]
n = self.samples.n_live(i)

self.triangle.update(self.points)

Expand Down
37 changes: 28 additions & 9 deletions anesthetic/gui/widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(self, fig, gridspec):
self.fig.delaxes(self.ax)
_, self.ax = make_2d_axes([], fig=self.fig, subplot_spec=self.gridspec)

def draw(self, params, labels={}):
def draw(self, params, labels={}, color='k'):
"""Draw a new triangular grid for list of parameters.

Parameters
Expand All @@ -252,9 +252,9 @@ def draw(self, params, labels={}):
for x, ax in row.items():
if ax is not None:
if x == y:
ax.twin.plot([None], [None], 'k-')
ax.twin.plot([None], [None], '-', color=color)
else:
ax.plot([None], [None], 'k.')
ax.plot([None], [None], '.', color=color)

def update(self, f):
"""Update the points in the triangle plot using f function.
Expand All @@ -269,13 +269,32 @@ def update(self, f):
for y, row in self.ax.iterrows():
for x, ax in row.items():
if ax is not None:
if x == y:
datx, daty = histogram(f(x), bins='auto')
ax.twin.lines[0].set_xdata(datx)
ax.twin.lines[0].set_ydata(daty)
x_values, colors = f(x)
if x == y and len(ax.twin.lines) == len(colors):
# if there are as many lines as colors,
# the lines can be updated efficiently
for i, color in enumerate(colors):
datx, daty = histogram(x_values[i], bins='auto')
ax.twin.lines[i].set_xdata(datx)
ax.twin.lines[i].set_ydata(daty)
ax.twin.lines[i].set_color(color)
elif x == y and len(ax.twin.lines) != len(colors):
# if there are NOT as many lines as colors,
# the lines need to be replaced
ax.twin.clear()
for i, color in enumerate(colors):
ax.twin.hist(x_values[i], color=color)
elif len(ax.lines) == len(colors):
y_values, _ = f(y)
for i, color in enumerate(colors):
ax.lines[i].set_xdata(x_values[i])
ax.lines[i].set_ydata(y_values[i])
ax.lines[i].set_color(color)
else:
ax.lines[0].set_xdata(f(x))
ax.lines[0].set_ydata(f(y))
ax.clear()
y_values, _ = f(y)
for i, color in enumerate(colors):
ax.plot(x_values[i], y_values[i], 'o', color=color)

def reset_range(self):
"""Reset the range of each grid."""
Expand Down
4 changes: 3 additions & 1 deletion anesthetic/read/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from anesthetic.read.multinest import read_multinest
from anesthetic.read.ultranest import read_ultranest
from anesthetic.read.nestedfit import read_nestedfit
from anesthetic.read.dnest4 import read_dnest4
from anesthetic.read.csv import read_csv


Expand All @@ -19,6 +20,7 @@ def read_chains(root, *args, **kwargs):
* `Nested_fit <https://github.com/martinit18/Nested_Fit>`_,
* `CosmoMC <https://github.com/cmbant/CosmoMC>`_,
* `Cobaya <https://github.com/CobayaSampler/cobaya>`_,
* `DNest4 <https://github.com/eggplantbren/DNest4/>`_,
* anything `GetDist <https://github.com/cmbant/getdist>`_ compatible,
* files produced using ``DataFrame.to_csv()`` from anesthetic.

Expand Down Expand Up @@ -54,7 +56,7 @@ def read_chains(root, *args, **kwargs):
errors = []
readers = [
read_polychord, read_multinest, read_cobaya, read_ultranest,
read_nestedfit, read_getdist, read_csv
read_dnest4, read_nestedfit, read_getdist, read_csv
]
for read in readers:
try:
Expand Down
80 changes: 80 additions & 0 deletions anesthetic/read/dnest4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Read NestedSamples from dnest4 output files."""
import os
import numpy as np
from anesthetic.samples import DiffusiveNestedSamples


def _determine_columns(n_params, header, delim=' '):
"""
Determine column names from DNest4 output.

If none are given by DNest4, parameters are named x_i.
' ' is the default delimiter in DNest4.
"""
dnest4_column_descriptions = header[1:].lstrip().split(delim)
if len(dnest4_column_descriptions) != n_params:
# header can not have contained column names
columns = [f'x_{i}' for i in range(n_params)]
else:
columns = [d.strip() for d in dnest4_column_descriptions]
return columns


def read_dnest4(root,
levels_file='levels.txt',
sample_file='sample.txt',
sample_info_file='sample_info.txt',
*args,
**kwargs):
"""
Read dnest4 output files.

Parameters
----------
root : str
root specify the directory only, no specific roots,
The files read files are levels_file, sample_file and sample_info.
levels_file: str
output name from DNest4
sample_file: str
output name from DNest4
sample_info_file: str
output name from DNest4
"""
levels = np.loadtxt(os.path.join(root, levels_file),
dtype=float,
delimiter=' ',
comments='#')
samples = np.genfromtxt(os.path.join(root, sample_file),
dtype=float,
delimiter=' ',
comments='#')
sample_info = np.loadtxt(os.path.join(root, sample_info_file),
dtype=float,
delimiter=' ',
comments='#')

with open(os.path.join(root, sample_file), 'r') as f:
header = f.readline()

n_params = samples.shape[1]

sample_level = sample_info[:, 0].astype(int)
logL = sample_info[:, 1]
logL_birth = levels[sample_level, 1]

kwargs['label'] = kwargs.get('label', os.path.basename(root))
columns_ = _determine_columns(n_params, header)
columns = kwargs.pop('columns', columns_)
labels_ = {c: '$' + c + '$' for c in columns}
labels = kwargs.pop('labels', labels_)

return DiffusiveNestedSamples(sample_info=sample_info,
levels=levels,
samples=samples,
logL=logL,
logL_birth=logL_birth,
columns=columns,
labels=labels,
*args,
**kwargs)
Loading
Loading