diff --git a/plot.py b/plot.py new file mode 100644 index 0000000..be1efc0 --- /dev/null +++ b/plot.py @@ -0,0 +1,92 @@ +import matplotlib.pyplot as plt +import numpy as np +from anesthetic import make_2d_axes +from anesthetic.plot import basic_cmap +from matplotlib.colors import LinearSegmentedColormap + +from lsbi.stats import mixture_normal, multivariate_normal + +np.random.seed(0) + +dim = 3 +shape = () +logA = np.random.randn(*shape, dim) +mean = np.random.randn(*shape, dim) +cov = np.random.randn(*shape, dim, dim) +cov = np.einsum("...ij,...kj->...ik", cov, cov) +dist = multivariate_normal(mean, cov) + + +def plot(dist, ax=None, *args, **kwargs): + if dist.dim > 2: + raise ValueError("dist must be 2D or 1D") + if ax is None: + ax = plt.gca() + + N = 10000 + x = dist.rvs(N) + logpdf = dist.logpdf(x, broadcast=True) + logpdfmin = np.sort(logpdf, axis=0)[::-1][int(0.997 * N)] + + xs = np.atleast_1d(x[..., 0]) + shape = xs.shape[1:] + xs = xs.reshape(N, -1) + logpdfs = np.atleast_1d(logpdf).reshape(N, -1) + logpdfmins = np.atleast_1d(logpdfmin).reshape(-1) + + ans = [] + if dist.dim == 1: + i = np.argsort(xs, axis=0) + for j in range(xs.shape[1]): + x = xs[i[:, j], j] + logpdf = logpdfs[i[:, j], j] + logpdf[logpdf < logpdfmins[j]] = np.nan + ans.append(ax.plot(x, np.exp(logpdf), *args, **kwargs)) + elif dist.dim == 2: + contours = [0.95, 0.67, 0] + contours = np.array(contours) + + ys = np.atleast_1d(x[..., 1]).reshape(N, -1) + for j in range(xs.shape[1]): + logpdf = logpdfs[:, j] + levels = np.sort(logpdf)[::-1][np.array(contours * N, dtype=int)] + x_ = xs[:, j] + y_ = ys[:, j] + color = kwargs.pop("color", ax._get_lines.get_next_color()) + cmap = kwargs.pop("cmap", basic_cmap(color)) + ans.append( + ax.tricontourf( + x_, y_, logpdf, levels=levels, cmap=cmap, *args, **kwargs + ) + ) + return np.array(ans).reshape(shape) + + +np.random.seed(0) +# k = 30 +k = 1 +dim = 5 +shape = 2, 2 +logA = np.random.randn(*shape, k) +mean = np.random.randn(*shape, k, dim) * 10 +cov = np.random.randn(*shape, k, dim, dim) +cov = np.einsum("...ij,...kj->...ik", cov, cov) +# dist = mixture_normal(logA, mean, cov) +dist = multivariate_normal(mean, cov) + +# dist = dist.marginalise([0,1,2,3]) +# fig, ax = plt.subplots() + + +cols = list(range(dist.dim)) +fig, axes = make_2d_axes(cols) +rvs = dist.rvs(1000).reshape(1000, -1, dist.dim) +for x in cols: + for y in cols: + if x == y: + plot(dist.marginalise(list(set(cols) - {x})), axes.loc[x, x].twin) + elif x < y: + plot(dist.marginalise(list(set(cols) - {x, y})), axes.loc[y, x], alpha=0.5) + else: + for k in range(rvs.shape[1]): + axes.loc[y, x].scatter(rvs[:, k, x], rvs[:, k, y], s=1)