From 88b7ec68527681bf8de7c88f883ae5932f210722 Mon Sep 17 00:00:00 2001 From: lukashergt Date: Fri, 27 Sep 2024 02:20:19 -0700 Subject: [PATCH] optionally allow for passing a pre-computed stats instance to tension computation to save computing time for high-nsamples runs --- anesthetic/tension.py | 45 +++++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/anesthetic/tension.py b/anesthetic/tension.py index ef52779e..fb9c8590 100644 --- a/anesthetic/tension.py +++ b/anesthetic/tension.py @@ -35,27 +35,31 @@ def stats(A, B, AB, nsamples=None, beta=None): # noqa: D301 Parameters ---------- - A : :class:`anesthetic.samples.NestedSamples` - NestedSamples object from a sampling run using only dataset A. - Alternatively, you can pass the precomputed stats object returned from + A : :class:`anesthetic.samples.Samples` + :class:`anesthetic.samples.NestedSamples` object from a sampling run + using only dataset A. + Alternatively, you can pass a precomputed stats object returned from :meth:`anesthetic.samples.NestedSamples.stats`. - B : :class:`anesthetic.samples.NestedSamples` - NestedSamples object from a sampling run using only dataset B. + B : :class:`anesthetic.samples.Samples` + :class:`anesthetic.samples.NestedSamples` object from a sampling run + using only dataset B. Alternatively, you can pass the precomputed stats object returned from :meth:`anesthetic.samples.NestedSamples.stats`. - AB : :class:`anesthetic.samples.NestedSamples` - NestedSamples object from a sampling run using both datasets A and B - jointly. + AB : :class:`anesthetic.samples.Samples` + :class:`anesthetic.samples.NestedSamples` object from a sampling run + using both datasets A and B jointly. + Alternatively, you can pass the precomputed stats object returned from + :meth:`anesthetic.samples.NestedSamples.stats`. nsamples : int, optional - - If nsamples is not supplied, calculate mean value + - If nsamples is not supplied, calculate mean value. - If nsamples is integer, draw nsamples from the distribution of - values inferred by nested sampling + values inferred by nested sampling. beta : float, array-like, default=1 - Inverse temperature(s) beta=1/kT. + Inverse temperature(s) `beta=1/kT`. Returns ------- @@ -63,9 +67,22 @@ def stats(A, B, AB, nsamples=None, beta=None): # noqa: D301 DataFrame containing the following tension statistics in columns: ['logR', 'logI', 'logS', 'd_G', 'p'] """ - statsA = A.stats(nsamples=nsamples, beta=beta) - statsB = B.stats(nsamples=nsamples, beta=beta) - statsAB = AB.stats(nsamples=nsamples, beta=beta) + columns = ['logZ', 'D_KL', 'logL_P', 'd_G'] + if set(columns).issubset(A.drop_labels().columns): + statsA = A + else: + statsA = A.stats(nsamples=nsamples, beta=beta) + if set(columns).issubset(B.drop_labels().columns): + statsB = B + else: + statsB = B.stats(nsamples=nsamples, beta=beta) + if set(columns).issubset(AB.drop_labels().columns): + statsAB = AB + else: + statsAB = AB.stats(nsamples=nsamples, beta=beta) + if statsA.shape != statsAB.shape or statsB.shape != statsAB.shape: + raise ValueError("Shapes of stats_A, stats_B, and stats_AB do not " + "match. Make sure to pass consistent `nsamples`.") samples = Samples(index=statsA.index)