Skip to content

Commit d2aa92e

Browse files
committed
Implement metalog cdf
1 parent 0988d64 commit d2aa92e

File tree

1 file changed

+45
-15
lines changed

1 file changed

+45
-15
lines changed

epochutils/stats/distributions.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22

3+
from scipy.optimize import bisect
34
from scipy.stats import rv_continuous
45
import numpy as np
56

@@ -167,12 +168,10 @@ def __init__(self, quantiles, n_terms=None):
167168
n_terms = len(self.quantiles)
168169
self.n_terms = n_terms
169170

170-
self.a = self._fit_metalog(self.quantiles, n_terms, self.transform)
171-
if self.a is None:
171+
self.metalog_a = self._fit_metalog(self.quantiles, n_terms, self.transform)
172+
if self.metalog_a is None:
172173
raise ValueError('Failed to fit metalog. The Y^T Y matrix is not invertible.')
173174

174-
self._check_feasibility(self.a, self.raw_quantiles)
175-
176175
def _compute_transforms(self, lower_bound, upper_bound):
177176
if (lower_bound is not None) and (upper_bound is not None):
178177
# TODO: Handle y = 0 and y = 1 cases
@@ -205,10 +204,10 @@ def _check_feasibility(self, a, raw_quantiles):
205204
feasible = (a[1] > 0 and abs(a[2])/a[1] <= 1.66711)
206205
else:
207206
feasible = True
208-
print('Warning: Feasibility check not implemented for more than 3 quantiles')
207+
warnings.warn('Warning: Feasibility check not implemented for more than 3 quantiles')
209208

210209
if not feasible:
211-
warnings.warn(f'Failed feasibility check for quantiles {raw_quantiles}')
210+
raise ValueError(f'Failed feasibility check for quantiles {raw_quantiles}')
212211

213212
# Equations 7 and 8
214213
def _fit_metalog(self, quantiles, n_terms, transform):
@@ -249,8 +248,11 @@ def _fit_metalog(self, quantiles, n_terms, transform):
249248
return a
250249

251250
# Equation 1, 2 and 3
252-
def ppf(self, y, _apply_transform=True):
253-
a = self.a
251+
def _ppf(self, y, _apply_transform=True):
252+
input_is_array = isinstance(y, np.ndarray)
253+
y = np.atleast_1d(y)
254+
255+
a = self.metalog_a
254256
n = self.n_terms
255257

256258
mu_coeff_indices = np.concatenate(([1, 4, 5], np.arange(7, n + 1, 2)))
@@ -262,14 +264,42 @@ def ppf(self, y, _apply_transform=True):
262264
mu = np.dot(a[mu_coeff_indices-1], np.power(y - 0.5, np.vstack(np.arange(len(mu_coeff_indices)))))
263265
s = np.dot(a[s_coeff_indices-1], np.power(y - 0.5, np.vstack(np.arange(len(s_coeff_indices)))))
264266

265-
result = mu + s * np.log(y / (1 - y))
267+
safe_mask = (y > 0) & (y < 1)
268+
269+
result = np.full_like(y, np.nan)
270+
result[y == 0] = self.lower_bound if self.lower_bound is not None else -np.inf
271+
result[y == 1] = self.upper_bound if self.upper_bound is not None else np.inf
272+
273+
result[safe_mask] = mu + s * np.log(y[safe_mask] / (1 - y[safe_mask]))
266274
if _apply_transform:
267-
result = self.inverse_transform(result)
268-
return result if isinstance(y, np.ndarray) else result[0]
275+
result[safe_mask] = self.inverse_transform(result[safe_mask])
276+
277+
return result if input_is_array else result[0]
278+
279+
def cdf(self, x, _apply_transform=True):
280+
# For now, use a simple root-finding algorithm
281+
282+
if np.isscalar(x):
283+
return self._cdf_scalar(x)
284+
else:
285+
return np.array([self._cdf_scalar(a) for a in np.atleast_1d(x)])
286+
287+
def _cdf_scalar(self, x):
288+
if self.lower_bound is not None and x <= self.lower_bound:
289+
return 0
290+
291+
if self.upper_bound is not None and x >= self.upper_bound:
292+
return 1
293+
294+
def objective(p):
295+
return self._ppf(p) - x
296+
297+
p_estimate = bisect(objective, 0, 1) # Find root in [0, 1]
298+
return p_estimate
269299

270300
# Equation 9
271301
def pdf_from_cum_prob(self, y):
272-
a = self.a
302+
a = self.metalog_a
273303
n = self.n_terms
274304

275305
denominator_terms = []
@@ -289,7 +319,7 @@ def pdf_from_cum_prob(self, y):
289319

290320
# Adjust for the transformation
291321
if (self.lower_bound is not None) or (self.upper_bound is not None):
292-
ppf = self.ppf(y, _apply_transform=False)
322+
ppf = self._ppf(y, _apply_transform=False)
293323
exp_ppf = np.exp(ppf)
294324

295325
# TODO Handle y = 0 and y = 1 cases
@@ -303,10 +333,10 @@ def pdf_from_cum_prob(self, y):
303333
return pdf
304334

305335
def quantile(self, q):
306-
return self.ppf(q)
336+
return self._ppf(q)
307337

308338
def rvs(self, size=1):
309-
return self.ppf(np.random.uniform(size=size))
339+
return self._ppf(np.random.uniform(size=size))
310340

311341

312342
# Aliases

0 commit comments

Comments
 (0)