Skip to content

Commit

Permalink
fix warning
Browse files Browse the repository at this point in the history
  • Loading branch information
qacwnfq committed Aug 26, 2024
1 parent 4867214 commit 19354ca
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 29 deletions.
14 changes: 11 additions & 3 deletions anesthetic/read/dnest4.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def _determine_columns_and_labels(n_params, header, delim=' '):
def read_dnest4(root,
levels_file='levels.txt',
sample_file='sample.txt',
sample_info_file='sample_info.txt'):
sample_info_file='sample_info.txt',
*args,
**kwargs):
"""
Read dnest4 output files.
Expand Down Expand Up @@ -57,16 +59,22 @@ def read_dnest4(root,
header = f.readline()

n_params = samples.shape[1]
columns, labels = _determine_columns_and_labels(n_params, header)

sample_level = sample_info[:, 0].astype(int)
logL = sample_info[:, 1]
logL_birth = levels[sample_level, 1]

kwargs['label'] = kwargs.get('label', os.path.basename(root))
columns, labels = _determine_columns_and_labels(n_params, header)
columns = kwargs.pop('columns', columns)
labels = kwargs.pop('labels', labels)

return DiffusiveNestedSamples(sample_info=sample_info,
levels=levels,
samples=samples,
logL=logL,
logL_birth=logL_birth,
columns=columns,
labels=labels)
labels=labels,
*args,
**kwargs)
37 changes: 11 additions & 26 deletions anesthetic/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,19 +1434,21 @@ def __init__(self, sample_info, levels, samples, *args, **kwargs):
self.logzero = kwargs.get('logzero', -1e300)
sample_info_columns = [kwargs.get('columns') +
['level', 'logL', 'tiebreaker', 'ID']]
self.samples_with_info = pandas.DataFrame(
self.attrs["samples_with_info"] = pandas.DataFrame(
np.concatenate([samples, sample_info], axis=1),
columns=sample_info_columns)
self.samples_with_info[['level', 'ID']] = \
self.samples_with_info[['level', 'ID']].astype(int)
level_columns = ['log_X',
'log_likelihood',
'tiebreaker',
'accepts',
'tries',
'exceeds',
'visits']
self.levels = pandas.DataFrame(levels, columns=level_columns)
self.attrs["levels"] = pandas.DataFrame(levels, columns=level_columns)

@property
def _constructor(self):
return NestedSamples

def samples_at_level(self, level_index, label):
"""
Expand All @@ -1463,8 +1465,9 @@ def samples_at_level(self, level_index, label):
-------
numpy.ndarray:
"""
selection = (self.samples_with_info.level == level_index).squeeze()
return self.samples_with_info[selection][label].to_numpy()
samples = self.attrs["samples_with_info"]
selection = (samples.level == level_index).squeeze()
return samples[selection][label].to_numpy()

Check warning on line 1470 in anesthetic/samples.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/samples.py#L1468-L1470

Added lines #L1468 - L1470 were not covered by tests

def n_live(self, *args):
"""
Expand All @@ -1485,24 +1488,6 @@ def n_live(self, *args):
"""
return self.num_particles

Check warning on line 1489 in anesthetic/samples.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/samples.py#L1489

Added line #L1489 was not covered by tests

def LX(self, beta, logX):
"""
Get LX, e.g., for Higson plot.
Parameters
----------
beta: float
temperature
logX: np.ndarray
prior volumes
Returns
-------
LX: np.ndarray
"""
LX = super().LX(beta, logX)
return LX

def plot_types(self):
"""
Get types of plots supported by this class.
Expand Down Expand Up @@ -1536,9 +1521,9 @@ def points_to_plot(self, plot_type, label, evolution, beta, base_color):
List[tuple[float]: colors to use
"""
if plot_type == 'visited points':
base_color = 'C0'
logX = self.logX().to_numpy()[evolution]
levels_to_plot = self.levels[self.levels['log_X'] >= logX]
levels = self.attrs["levels"]
levels_to_plot = levels[levels['log_X'] >= logX]
max_level_index = levels_to_plot.tail(1).index.to_list()[0]
colors = [basic_cmap(base_color)(float(j) / (max_level_index + 1))

Check warning on line 1528 in anesthetic/samples.py

View check run for this annotation

Codecov / codecov/patch

anesthetic/samples.py#L1523-L1528

Added lines #L1523 - L1528 were not covered by tests
for j in range(1, max_level_index + 2)]
Expand Down

0 comments on commit 19354ca

Please sign in to comment.