Skip to content

Commit

Permalink
simplify tension tests and add test for direct input of nested sampli…
Browse files Browse the repository at this point in the history
…ng stats to tension stats
  • Loading branch information
lukashergt committed Sep 27, 2024
1 parent 88b7ec6 commit 51c0975
Showing 1 changed file with 40 additions and 35 deletions.
75 changes: 40 additions & 35 deletions tests/test_tension.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pytest import approx
from pytest import approx, raises
from anesthetic.examples.perfect_ns import correlated_gaussian
import numpy as np
from numpy.linalg import inv, slogdet
from pandas.testing import assert_series_equal
from anesthetic.tension import stats


Expand Down Expand Up @@ -37,30 +38,32 @@ def test_tension_stats_compatible_gaussian():

nsamples = 10
beta = 1
samples_stats = stats(samplesA, samplesB, samplesAB, nsamples, beta)
s = stats(samplesA, samplesB, samplesAB, nsamples, beta)

logR_std = samples_stats.logR.std()
logR_mean = samples_stats.logR.mean()
logR_exact = logV - dmu_cov_dmu_AB/2 - slogdet(2*np.pi*(covA+covB))[1]/2
assert logR_mean == approx(logR_exact, abs=3*logR_std)
assert s.logR.mean() == approx(logR_exact, abs=3*s.logR.std())

logS_std = samples_stats.logS.std()
logS_mean = samples_stats.logS.mean()
logS_exact = d / 2 - dmu_cov_dmu_AB / 2
assert logS_mean == approx(logS_exact, abs=3*logS_std)
assert s.logS.mean() == approx(logS_exact, abs=3*s.logS.std())

logI_std = samples_stats.logI.std()
logI_mean = samples_stats.logI.mean()
logI_exact = logV - d / 2 - slogdet(2*np.pi*(covA+covB))[1] / 2
assert logI_mean == approx(logI_exact, abs=3*logI_std)
assert s.logI.mean() == approx(logI_exact, abs=3*s.logI.std())

assert logS_mean == approx(logR_mean - logI_mean, abs=3*logS_std)
assert s.logS.mean() == approx(s.logR.mean() - s.logI.mean(),
abs=3*s.logS.std())

assert samples_stats.get_labels().tolist() == ([r'$\ln\mathcal{R}$',
r'$\ln\mathcal{I}$',
r'$\ln\mathcal{S}$',
r'$d_\mathrm{G}$',
r'$p$'])
assert s.get_labels().tolist() == ([r'$\ln\mathcal{R}$',
r'$\ln\mathcal{I}$',
r'$\ln\mathcal{S}$',
r'$d_\mathrm{G}$',
r'$p$'])

with raises(ValueError):
stats(samplesA.stats(nsamples=5), samplesB, samplesAB, nsamples)
s2 = stats(samplesA.stats(nsamples=nsamples),
samplesB.stats(nsamples=nsamples),
samplesAB.stats(nsamples=nsamples))
assert_series_equal(s2.mean(), s.mean(), atol=s2.std().max())


def test_tension_stats_incompatible_gaussian():
Expand Down Expand Up @@ -95,27 +98,29 @@ def test_tension_stats_incompatible_gaussian():

nsamples = 10
beta = 1
samples_stats = stats(samplesA, samplesB, samplesAB, nsamples, beta)
s = stats(samplesA, samplesB, samplesAB, nsamples, beta)

logR_std = samples_stats.logR.std()
logR_mean = samples_stats.logR.mean()
logR_exact = logV - dmu_cov_dmu_AB/2 - slogdet(2*np.pi*(covA+covB))[1]/2
assert logR_mean == approx(logR_exact, abs=3*logR_std)
assert s.logR.mean() == approx(logR_exact, abs=3*s.logR.std())

logS_std = samples_stats.logS.std()
logS_mean = samples_stats.logS.mean()
logS_exact = d / 2 - dmu_cov_dmu_AB / 2
assert logS_mean == approx(logS_exact, abs=3*logS_std)
assert s.logS.mean() == approx(logS_exact, abs=3*s.logS.std())

logI_std = samples_stats.logI.std()
logI_mean = samples_stats.logI.mean()
logI_exact = logV - d / 2 - slogdet(2*np.pi*(covA+covB))[1] / 2
assert logI_mean == approx(logI_exact, abs=3*logI_std)

assert logS_mean == approx(logR_mean - logI_mean, abs=3*logS_std)

assert samples_stats.get_labels().tolist() == ([r'$\ln\mathcal{R}$',
r'$\ln\mathcal{I}$',
r'$\ln\mathcal{S}$',
r'$d_\mathrm{G}$',
r'$p$'])
assert s.logI.mean() == approx(logI_exact, abs=3*s.logI.std())

assert s.logS.mean() == approx(s.logR.mean() - s.logI.mean(),
abs=3*s.logS.std())

assert s.get_labels().tolist() == ([r'$\ln\mathcal{R}$',
r'$\ln\mathcal{I}$',
r'$\ln\mathcal{S}$',
r'$d_\mathrm{G}$',
r'$p$'])

with raises(ValueError):
stats(samplesA.stats(nsamples=5), samplesB, samplesAB, nsamples)
s2 = stats(samplesA.stats(nsamples=nsamples),
samplesB.stats(nsamples=nsamples),
samplesAB.stats(nsamples=nsamples))
assert_series_equal(s2.mean(), s.mean(), atol=s2.std().max())

0 comments on commit 51c0975

Please sign in to comment.