Skip to content

Commit 63fbfc6

Browse files
authored
Merge pull request #25 from pymc-devs/logitnormal_gh
Use Gauss-Hermite quadrature for logitnormal moments and entropy
2 parents 913aba2 + da45c34 commit 63fbfc6

3 files changed

Lines changed: 83 additions & 25 deletions

File tree

distributions/logitnormal.py

Lines changed: 78 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,118 @@
1+
import numpy as np
12
import pytensor.tensor as pt
23

34
from distributions.helper import (
45
cdf_bounds,
5-
continuous_entropy,
6-
continuous_kurtosis,
7-
continuous_mean,
8-
continuous_skewness,
9-
continuous_variance,
106
ppf_bounds_cont,
117
)
128
from distributions.normal import ppf as normal_ppf
139

14-
# Support bounds for logitnormal (open interval (0, 1))
15-
_LOWER = 0.001
16-
_UPPER = 0.999
17-
1810

1911
def _logit(x):
    """Return the log-odds of ``x``, computed as ``log(x) - log1p(-x)``."""
    log_x = pt.log(x)
    log_one_minus_x = pt.log1p(-x)
    return log_x - log_one_minus_x
2113

2214

23-
def _expit(y):
24-
return pt.sigmoid(y)
15+
def _ghq_moments(mu, sigma, order=1, mean_val=None, n_points=70):
    """
    Approximate a moment of the logit-normal with Gauss-Hermite quadrature.

    The expectation of ``sigmoid(Z) ** order`` (or of
    ``(sigmoid(Z) - mean_val) ** order`` when ``mean_val`` is given) for
    ``Z ~ Normal(mu, sigma)`` is evaluated as a weighted sum over the
    Gauss-Hermite nodes mapped through ``sqrt(2) * sigma * node + mu``.

    Parameters
    ----------
    mu : tensor
        Mean of underlying normal distribution
    sigma : tensor
        Standard deviation of underlying normal distribution
    order : int
        Order of the moment
    mean_val : tensor, optional
        If provided, the moment is taken around this value (central moment)
    n_points : int
        Number of Gauss-Hermite nodes

    Returns
    -------
    tensor
        Estimated moment; the node axis (axis 0) is summed out.
    """
    nodes, weights = np.polynomial.hermite.hermgauss(n_points)

    # Nodes/weights live on a new leading axis; the trailing singleton axes
    # let them broadcast against the broadcast of ``mu`` and ``sigma``.
    n_param_dims = pt.broadcast_arrays(mu, sigma)[0].ndim
    column_shape = (-1,) + (1,) * n_param_dims
    nodes = pt.as_tensor_variable(nodes).reshape(column_shape)
    weights = pt.as_tensor_variable(weights).reshape(column_shape)

    # Change of variables from the Hermite weight exp(-t**2) to Normal(mu, sigma).
    samples = pt.sigmoid(pt.sqrt(2.0) * sigma * nodes + mu)

    if mean_val is None:
        integrand = samples**order
    else:
        integrand = (samples - mean_val) ** order

    return pt.sum(weights * integrand, axis=0) / pt.sqrt(pt.pi)
2560

2661

2762
def mean(mu, sigma):
    """Return the first raw moment, estimated by ``_ghq_moments``."""
    first_moment = _ghq_moments(mu, sigma, order=1)
    return first_moment
2964

3065

3166
def mode(mu, sigma):
    """Return ``sigmoid(mu)``; ``sigma`` is not used."""
    # NOTE(review): this is the same value ``median`` fills its result with.
    # The logit-normal mode generally differs from its median -- confirm that
    # returning sigmoid(mu) here is intended.
    centre = pt.sigmoid(mu)
    return centre
3368

3469

3570
def median(mu, sigma):
    """Return ``sigmoid(mu)`` filled into the broadcast of ``mu`` and ``sigma``."""
    # The first broadcast array serves only as a shape/dtype template.
    template = pt.broadcast_arrays(mu, sigma)[0]
    centre = pt.sigmoid(mu)
    return pt.full_like(template, centre)
3873

3974

4075
def var(mu, sigma):
    """Return the second central moment, estimated by ``_ghq_moments``."""
    first_moment = _ghq_moments(mu, sigma, order=1)
    second_central = _ghq_moments(mu, sigma, order=2, mean_val=first_moment)
    return second_central
4278

4379

4480
def std(mu, sigma):
    """Return the square root of ``var(mu, sigma)``."""
    variance = var(mu, sigma)
    return pt.sqrt(variance)
4682

4783

4884
def skewness(mu, sigma):
    """Return the third central moment divided by the cubed standard deviation."""
    first_moment = _ghq_moments(mu, sigma, order=1)
    second_central = _ghq_moments(mu, sigma, order=2, mean_val=first_moment)
    third_central = _ghq_moments(mu, sigma, order=3, mean_val=first_moment)
    std_dev = pt.sqrt(second_central)
    return third_central / (std_dev**3)
5089

5190

5291
def kurtosis(mu, sigma):
    """Return the excess kurtosis: fourth central moment over variance squared, minus 3."""
    first_moment = _ghq_moments(mu, sigma, order=1)
    second_central = _ghq_moments(mu, sigma, order=2, mean_val=first_moment)
    fourth_central = _ghq_moments(mu, sigma, order=4, mean_val=first_moment)
    standardized_fourth = fourth_central / (second_central**2)
    return standardized_fourth - 3
5496

5597

5698
def entropy(mu, sigma):
    """
    Differential entropy of the logit-normal via Gauss-Hermite quadrature.

    With ``Z ~ Normal(mu, sigma)`` and ``X = sigmoid(Z)``, the negative
    log-density of ``X`` is::

        log(sigma) + 0.5*log(2*pi) + (Z - mu)**2 / (2*sigma**2)
            + log(X) + log(1 - X)

    and ``log(X) + log(1 - X) = Z - 2*softplus(Z)``. Taking expectations
    (``E[(Z - mu)**2] = sigma**2`` and ``E[Z] = mu``) gives::

        H = log(sigma) + 0.5*log(2*pi) + 0.5 + mu - 2*E[softplus(Z)]

    Only ``E[softplus(Z)]`` is integrated numerically. Working in ``Z``
    avoids evaluating the log-density at ``sigmoid(Z)``, which saturates to
    exactly 0 or 1 at the outer quadrature nodes for wide ``sigma`` and
    would turn the weighted sum into inf/NaN.

    Parameters
    ----------
    mu : tensor
        Mean of underlying normal distribution
    sigma : tensor
        Standard deviation of underlying normal distribution

    Returns
    -------
    tensor
        Estimated entropy, with the broadcast shape of ``mu`` and ``sigma``.
    """
    n_points = 70
    nodes, weights = np.polynomial.hermite.hermgauss(n_points)

    # Nodes/weights live on a new leading axis; the trailing singleton axes
    # let them broadcast against the broadcast of ``mu`` and ``sigma``.
    n_param_dims = pt.broadcast_arrays(mu, sigma)[0].ndim
    column_shape = (-1,) + (1,) * n_param_dims
    nodes = pt.as_tensor_variable(nodes).reshape(column_shape)
    weights = pt.as_tensor_variable(weights).reshape(column_shape)

    # Change of variables from the Hermite weight exp(-t**2) to Normal(mu, sigma).
    z = pt.sqrt(2.0) * sigma * nodes + mu
    expected_softplus = pt.sum(weights * pt.softplus(z), axis=0) / np.sqrt(np.pi)

    normal_entropy = pt.log(sigma) + 0.5 * np.log(2.0 * np.pi) + 0.5
    return normal_entropy + mu - 2.0 * expected_softplus
58116

59117

60118
def pdf(x, mu, sigma):
@@ -121,12 +179,12 @@ def logsf(x, mu, sigma):
121179

122180

123181
def ppf(q, mu, sigma):
    """Map the normal quantile through sigmoid, then apply ``ppf_bounds_cont`` with bounds 0 and 1."""
    normal_quantile = normal_ppf(q, mu, sigma)
    quantile = pt.sigmoid(normal_quantile)
    return ppf_bounds_cont(quantile, q, 0, 1)
125183

126184

127185
def isf(q, mu, sigma):
    """Return ``ppf`` evaluated at the complementary probability ``1 - q``."""
    complement = 1 - q
    return ppf(complement, mu, sigma)
129187

130188

131189
def rvs(mu, sigma, size=None, random_state=None):
    """Draw normal variates with ``pt.random.normal`` and map them through sigmoid."""
    normal_draws = pt.random.normal(mu, sigma, rng=random_state, size=size)
    return pt.sigmoid(normal_draws)

tests/helper_empirical.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def run_empirical_tests(
2020
kurtosis_rtol=1e-1,
2121
quantiles_rtol=1e-4,
2222
cdf_rtol=1e-4,
23+
pdf_cdf_rtol=1e-3,
2324
is_discrete=False,
2425
):
2526
"""Test a distribution against empirical samples for distributions not in scipy."""
@@ -205,7 +206,7 @@ def run_empirical_tests(
205206
rel_error = np.abs(numerical_pdf[mask] - pdf_vals[mask]) / (
206207
np.abs(pdf_vals[mask]) + 1e-10
207208
)
208-
assert np.all(rel_error < 1e-3), (
209+
assert np.all(rel_error < pdf_cdf_rtol), (
209210
f"PDF doesn't match CDF derivative. Max rel error: {np.max(rel_error)}"
210211
)
211212

tests/test_logitnormal.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
"params",
1212
[
1313
[0.0, 1.0], # Standard logit-normal (centered)
14-
[0.0, 0.5], # Narrower distribution (sigma=0.001 is too extreme for numerical integration)
15-
[1.0, 1.0], # Shifted right (mode > 0.5)
14+
[0.0, 0.001], # Narrower distribution
1615
[-1.0, 1.0], # Shifted left (mode < 0.5)
1716
[0.0, 2.0], # Wider distribution (approaches U-shape)
1817
[2.0, 0.5], # Strongly shifted right
@@ -28,12 +27,12 @@ def test_logitnormal_vs_random(params):
2827
p_params=p_params,
2928
support=support,
3029
name="logitnormal",
31-
sample_size=500_000,
3230
mean_rtol=1e-2,
3331
var_rtol=1e-2,
3432
std_rtol=1e-2,
3533
skewness_rtol=2e-1,
36-
kurtosis_rtol=2e-1,
34+
kurtosis_rtol=2e-1 if params[1] > 0.01 else 1,
3735
quantiles_rtol=3e-2,
3836
cdf_rtol=5e-2,
37+
pdf_cdf_rtol=1e-2,
3938
)

0 commit comments

Comments
 (0)