Skip to content

Commit

Permalink
Prediction methods for stats functions (#31)
Browse files Browse the repository at this point in the history
* Added a predict method to stats functions

* Tests added

* All tests passing

* bump version to 0.11.0

* Added other return docstrings
  • Loading branch information
williamjameshandley authored Nov 26, 2023
1 parent 9432101 commit e9da66d
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ lsbi: Linear Simulation Based Inference
=======================================
:lsbi: Linear Simulation Based Inference
:Author: Will Handley & David Yallup
:Version: 0.10.0
:Version: 0.11.0
:Homepage: https://github.com/handley-lab/lsbi
:Documentation: http://lsbi.readthedocs.io/

Expand Down
7 changes: 5 additions & 2 deletions bin/run_tests
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#!/bin/bash

echo "Running PEP code style tests"
flake8 lsbi tests
echo "Running black check"
black --check .

echo "Running isort check"
isort --check-only --profile black .

echo "Running docstring checks"
pydocstyle --convention=numpy lsbi
Expand Down
2 changes: 1 addition & 1 deletion lsbi/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.10.0"
__version__ = "0.11.0"
106 changes: 106 additions & 0 deletions lsbi/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def marginalise(self, indices):
----------
indices : array_like
Indices to marginalise.
Returns
-------
marginalised distribution: multivariate_normal
"""
i = self._bar(indices)
mean = self.mean[i]
Expand All @@ -31,6 +35,10 @@ def condition(self, indices, values):
Indices to condition over.
values : array_like
Values to condition on.
Returns
-------
conditional distribution: multivariate_normal
"""
i = self._bar(indices)
k = indices
Expand Down Expand Up @@ -66,6 +74,10 @@ def bijector(self, x, inverse=False):
inverse : bool, optional, default=False
If True: compute the inverse transformation from physical to
hypercube space.
Returns
-------
transformed x or theta: array_like, shape (..., d)
"""
L = np.linalg.cholesky(self.cov)
if inverse:
Expand All @@ -76,6 +88,29 @@ def bijector(self, x, inverse=False):
y = scipy.stats.norm.ppf(x)
return self.mean + np.einsum("ij,...j->...i", L, y)

def predict(self, A, b=None):
    """Predict the mean and covariance of a linear transformation.

    if:         x ~ N(mu, Sigma)
    then:  Ax + b ~ N(A mu + b, A Sigma A^T)

    Parameters
    ----------
    A : array_like, shape (q, n)
        Linear transformation matrix.
    b : array_like, shape (q,), optional
        Linear transformation vector. Omitting it is equivalent to b = 0.

    Returns
    -------
    predicted distribution: multivariate_normal
    """
    # Coerce so that nested lists are accepted, as the array_like contract
    # promises; the code below relies on the ndarray attribute ``.T``.
    A = np.asarray(A)
    mean = A @ self.mean
    if b is not None:
        mean = mean + b
    cov = A @ self.cov @ A.T
    # A Sigma A^T is rank-deficient whenever q > n, hence allow_singular.
    return multivariate_normal(mean, cov, allow_singular=True)


class multimultivariate_normal(object):
"""Multivariate normal distribution with multiple means and covariances.
Expand Down Expand Up @@ -133,6 +168,10 @@ def marginalise(self, indices):
----------
indices : array_like
Indices to marginalise.
Returns
-------
marginalised distribution: multimultivariate_normal
"""
i = self._bar(indices)
means = self.means[:, i]
Expand All @@ -148,6 +187,10 @@ def condition(self, indices, values):
Indices to condition over.
values : array_like
Values to condition on.
Returns
-------
conditional distribution: multimultivariate_normal
"""
i = self._bar(indices)
k = indices
Expand Down Expand Up @@ -189,6 +232,10 @@ def bijector(self, x, inverse=False):
inverse : bool, optional, default=False
If True: compute the inverse transformation from physical to
hypercube space.
Returns
-------
transformed x or theta: array_like, shape (..., d)
"""
Ls = np.linalg.cholesky(self.covs)
if inverse:
Expand All @@ -199,6 +246,29 @@ def bijector(self, x, inverse=False):
y = scipy.stats.norm.ppf(x)
return self.means + np.einsum("ijk,...ik->...ij", Ls, y)

def predict(self, A, b=None):
    """Predict the mean and covariance of a linear transformation.

    if:         x ~ N(mu, Sigma)
    then:  Ax + b ~ N(A mu + b, A Sigma A^T)

    The transformation is applied separately to each of the k components:
    component i is mapped through A[i] (and b[i], if given).

    Parameters
    ----------
    A : array_like, shape (k, q, n)
        Linear transformation matrix.
    b : array_like, shape (k, q), optional
        Linear transformation vector. Omitting it is equivalent to b = 0.

    Returns
    -------
    predicted distribution: multimultivariate_normal
    """
    # Coerce so that nested lists are accepted, as array_like promises.
    A = np.asarray(A)
    means = np.einsum("kqn,kn->kq", A, self.means)
    if b is not None:
        means = means + b
    # Per component: A[k] @ covs[k] @ A[k].T
    covs = np.einsum("kqn,knm,kpm->kqp", A, self.covs, A)
    return multimultivariate_normal(means, covs)


class mixture_multivariate_normal(object):
"""Mixture of multivariate normal distributions.
Expand Down Expand Up @@ -254,6 +324,10 @@ def marginalise(self, indices):
----------
indices : array_like
Indices to marginalise.
Returns
-------
marginalised distribution: mixture_multivariate_normal
"""
i = self._bar(indices)
means = self.means[:, i]
Expand All @@ -270,6 +344,10 @@ def condition(self, indices, values):
Indices to condition over.
values : array_like
Values to condition on.
Returns
-------
conditional distribution: mixture_multivariate_normal
"""
i = self._bar(indices)
k = indices
Expand Down Expand Up @@ -315,6 +393,10 @@ def bijector(self, x, inverse=False):
inverse : bool, optional, default=False
If True: compute the inverse transformation from physical to
hypercube space.
Returns
-------
transformed x or theta: array_like, shape (..., d)
"""
theta = np.empty_like(x)
if inverse:
Expand Down Expand Up @@ -374,3 +456,27 @@ def _process_quantiles(self, x, dim):
x = x[np.newaxis, :]

return x

def predict(self, A, b=None):
    """Predict the mean and covariance of a linear transformation.

    if:         x ~ mixN(mu, Sigma, logA)
    then:  Ax + b ~ mixN(A mu + b, A Sigma A^T, logA)

    Each of the k components is mapped through its own A[i] (and b[i], if
    given); the mixing weights logA are passed through unchanged.

    Parameters
    ----------
    A : array_like, shape (k, q, n)
        Linear transformation matrix.
    b : array_like, shape (k, q), optional
        Linear transformation vector. Omitting it is equivalent to b = 0.

    Returns
    -------
    predicted distribution: mixture_multivariate_normal
    """
    # Coerce so that nested lists are accepted, as array_like promises.
    A = np.asarray(A)
    means = np.einsum("kqn,kn->kq", A, self.means)
    if b is not None:
        means = means + b
    # Per component: A[k] @ covs[k] @ A[k].T
    covs = np.einsum("kqn,knm,kpm->kqp", A, self.covs, A)
    return mixture_multivariate_normal(means, covs, self.logA)
61 changes: 59 additions & 2 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,22 @@ def test_bijector(self, k, d):
assert dist.bijector(theta, inverse=True).shape == x.shape

@pytest.mark.parametrize("p", np.arange(1, 5))
def test_marginalise_condition(self, d, k, p):
def test_marginalise_condition(self, k, d, p):
if d <= p:
pytest.skip("d <= p")
i = np.random.choice(d, p, replace=False)
j = np.array([x for x in range(d) if x not in i])
dist = self.random(k, d)
mixture_2 = dist.marginalise(i)
assert isinstance(mixture_2, self.cls)
assert mixture_2.means.shape == (k, d - p)
assert mixture_2.covs.shape == (k, d - p, d - p)
assert_allclose(dist.means[:, j], mixture_2.means)
assert_allclose(dist.covs[:, j][:, :, j], mixture_2.covs)

v = np.random.randn(k, p)
mixture_3 = dist.condition(i, v)
assert isinstance(mixture_3, self.cls)
assert mixture_3.means.shape == (k, d - p)
assert mixture_3.covs.shape == (k, d - p, d - p)

Expand All @@ -118,6 +120,24 @@ def test_marginalise_condition(self, d, k, p):
assert mixture_3.means.shape == (k, d - p)
assert mixture_3.covs.shape == (k, d - p, d - p)

@pytest.mark.parametrize("q", [1, 2, 5, 10])
def test_predict(self, q, k, d):
    """Check type, shapes and values of predictions with and without b."""
    dist = self.random(k, d)
    A = np.random.randn(k, q, d)

    # Reference moments, built component by component so the check does
    # not simply repeat the einsum expressions used inside predict.
    means = np.array([A_i @ mu_i for A_i, mu_i in zip(A, dist.means)])
    covs = np.array(
        [A_i @ cov_i @ A_i.T for A_i, cov_i in zip(A, dist.covs)]
    )

    y = dist.predict(A)
    assert isinstance(y, self.cls)
    assert y.means.shape == (k, q)
    assert y.covs.shape == (k, q, q)
    assert_allclose(y.means, means, atol=1e-12)
    assert_allclose(y.covs, covs, atol=1e-12)

    # A single offset of shape (q,) is broadcast over all k components;
    # it must shift the means and leave the covariances untouched.
    b = np.random.randn(q)
    y = dist.predict(A, b)
    assert isinstance(y, self.cls)
    assert y.means.shape == (k, q)
    assert y.covs.shape == (k, q, q)
    assert_allclose(y.means, means + b, atol=1e-12)
    assert_allclose(y.covs, covs, atol=1e-12)


@pytest.mark.parametrize("d", [1, 2, 5, 10])
class TestMultivariateNormal(object):
Expand Down Expand Up @@ -190,16 +210,33 @@ def test_marginalise_condition_multivariate_normal(self, d, p):
j = np.array([x for x in range(d) if x not in i])
dist_1 = self.random(d)
dist_2 = dist_1.marginalise(i)
assert isinstance(dist_2, self.cls)
assert dist_2.mean.shape == (d - p,)
assert dist_2.cov.shape == (d - p, d - p)
assert_allclose(dist_1.mean[j], dist_2.mean)
assert_allclose(dist_1.cov[j][:, j], dist_2.cov)

v = np.random.randn(p)
dist_3 = dist_1.condition(i, v)
assert isinstance(dist_3, self.cls)
assert dist_3.mean.shape == (d - p,)
assert dist_3.cov.shape == (d - p, d - p)

@pytest.mark.parametrize("q", [1, 2, 5, 10])
def test_predict(self, q, d):
    """Check type, shapes and values of predictions with and without b."""
    dist = self.random(d)
    A = np.random.randn(q, d)

    # Reference moments of A x for x ~ N(mean, cov).
    mean = A @ dist.mean
    cov = A @ dist.cov @ A.T

    y = dist.predict(A)
    assert isinstance(y, self.cls)
    assert y.mean.shape == (q,)
    assert y.cov.shape == (q, q)
    assert_allclose(y.mean, mean, atol=1e-12)
    assert_allclose(y.cov, cov, atol=1e-12)

    # The offset must shift the mean and leave the covariance untouched.
    b = np.random.randn(q)
    y = dist.predict(A, b)
    assert isinstance(y, self.cls)
    assert y.mean.shape == (q,)
    assert y.cov.shape == (q, q)
    assert_allclose(y.mean, mean + b, atol=1e-12)
    assert_allclose(y.cov, cov, atol=1e-12)


@pytest.mark.parametrize("d", [1, 2, 5, 10])
@pytest.mark.parametrize("k", [1, 2, 5, 10])
Expand Down Expand Up @@ -300,19 +337,39 @@ def test_bijector(self, k, d):
assert dist.bijector(theta, inverse=True).shape == xs.shape

@pytest.mark.parametrize("p", np.arange(1, 5))
def test_marginalise_condition(self, k, d, p):
    """Check type, shapes and values after marginalising and conditioning."""
    if d <= p:
        pytest.skip("d <= p")
    # i: the p indices removed; j: the d - p indices that remain (sorted).
    i = np.random.choice(d, p, replace=False)
    j = np.setdiff1d(np.arange(d), i)
    dist = self.random(k, d)

    mixture_2 = dist.marginalise(i)
    assert isinstance(mixture_2, self.cls)
    assert mixture_2.means.shape == (k, d - p)
    assert mixture_2.covs.shape == (k, d - p, d - p)
    # Marginalising must simply select the surviving rows/columns.
    assert_allclose(dist.means[:, j], mixture_2.means)
    assert_allclose(dist.covs[:, j][:, :, j], mixture_2.covs)

    v = np.random.randn(k, p)
    mixture_3 = dist.condition(i, v)
    assert isinstance(mixture_3, self.cls)
    assert mixture_3.means.shape == (k, d - p)
    assert mixture_3.covs.shape == (k, d - p, d - p)

@pytest.mark.parametrize("q", [1, 2, 5, 10])
def test_predict(self, q, k, d):
    """Check type, shapes and values of predictions with and without b."""
    dist = self.random(k, d)
    A = np.random.randn(k, q, d)

    # Reference moments, built component by component so the check does
    # not simply repeat the einsum expressions used inside predict.
    means = np.array([A_i @ mu_i for A_i, mu_i in zip(A, dist.means)])
    covs = np.array(
        [A_i @ cov_i @ A_i.T for A_i, cov_i in zip(A, dist.covs)]
    )

    y = dist.predict(A)
    assert isinstance(y, self.cls)
    assert y.means.shape == (k, q)
    assert y.covs.shape == (k, q, q)
    assert_allclose(y.means, means, atol=1e-12)
    assert_allclose(y.covs, covs, atol=1e-12)

    # A single offset of shape (q,) is broadcast over all k components;
    # it must shift the means and leave the covariances untouched.
    b = np.random.randn(q)
    y = dist.predict(A, b)
    assert isinstance(y, self.cls)
    assert y.means.shape == (k, q)
    assert y.covs.shape == (k, q, q)
    assert_allclose(y.means, means + b, atol=1e-12)
    assert_allclose(y.covs, covs, atol=1e-12)

0 comments on commit e9da66d

Please sign in to comment.