Skip to content

Commit

Permalink
Updates post imperial talk
Browse files Browse the repository at this point in the history
These changes allowed slides from these talks to be created:
- [PhyStat](https://github.com/williamjameshandley/talks/tree/imperial_2024)
- [cosmoverse](https://github.com/williamjameshandley/talks/tree/cosmoverse_2024)
  • Loading branch information
williamjameshandley committed Sep 15, 2024
1 parent 704c556 commit 0874796
Showing 1 changed file with 92 additions and 0 deletions.
92 changes: 92 additions & 0 deletions plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import matplotlib.pyplot as plt
import numpy as np
from anesthetic import make_2d_axes
from anesthetic.plot import basic_cmap
from matplotlib.colors import LinearSegmentedColormap

from lsbi.stats import mixture_normal, multivariate_normal

np.random.seed(0)

dim = 3
shape = ()
logA = np.random.randn(*shape, dim)
mean = np.random.randn(*shape, dim)
cov = np.random.randn(*shape, dim, dim)
cov = np.einsum("...ij,...kj->...ik", cov, cov)
dist = multivariate_normal(mean, cov)


def plot(dist, ax=None, *args, **kwargs):
if dist.dim > 2:
raise ValueError("dist must be 2D or 1D")
if ax is None:
ax = plt.gca()

N = 10000
x = dist.rvs(N)
logpdf = dist.logpdf(x, broadcast=True)
logpdfmin = np.sort(logpdf, axis=0)[::-1][int(0.997 * N)]

xs = np.atleast_1d(x[..., 0])
shape = xs.shape[1:]
xs = xs.reshape(N, -1)
logpdfs = np.atleast_1d(logpdf).reshape(N, -1)
logpdfmins = np.atleast_1d(logpdfmin).reshape(-1)

ans = []
if dist.dim == 1:
i = np.argsort(xs, axis=0)
for j in range(xs.shape[1]):
x = xs[i[:, j], j]
logpdf = logpdfs[i[:, j], j]
logpdf[logpdf < logpdfmins[j]] = np.nan
ans.append(ax.plot(x, np.exp(logpdf), *args, **kwargs))
elif dist.dim == 2:
contours = [0.95, 0.67, 0]
contours = np.array(contours)

ys = np.atleast_1d(x[..., 1]).reshape(N, -1)
for j in range(xs.shape[1]):
logpdf = logpdfs[:, j]
levels = np.sort(logpdf)[::-1][np.array(contours * N, dtype=int)]
x_ = xs[:, j]
y_ = ys[:, j]
color = kwargs.pop("color", ax._get_lines.get_next_color())
cmap = kwargs.pop("cmap", basic_cmap(color))
ans.append(
ax.tricontourf(
x_, y_, logpdf, levels=levels, cmap=cmap, *args, **kwargs
)
)
return np.array(ans).reshape(shape)


np.random.seed(0)
# k = 30
k = 1
dim = 5
shape = 2, 2
logA = np.random.randn(*shape, k)
mean = np.random.randn(*shape, k, dim) * 10
cov = np.random.randn(*shape, k, dim, dim)
cov = np.einsum("...ij,...kj->...ik", cov, cov)
# dist = mixture_normal(logA, mean, cov)
dist = multivariate_normal(mean, cov)

# dist = dist.marginalise([0,1,2,3])
# fig, ax = plt.subplots()


cols = list(range(dist.dim))
fig, axes = make_2d_axes(cols)
rvs = dist.rvs(1000).reshape(1000, -1, dist.dim)
for x in cols:
for y in cols:
if x == y:
plot(dist.marginalise(list(set(cols) - {x})), axes.loc[x, x].twin)
elif x < y:
plot(dist.marginalise(list(set(cols) - {x, y})), axes.loc[y, x], alpha=0.5)
else:
for k in range(rvs.shape[1]):
axes.loc[y, x].scatter(rvs[:, k, x], rvs[:, k, y], s=1)

0 comments on commit 0874796

Please sign in to comment.