Skip to content

Commit

Permalink
Merge pull request cupy#8065 from asi1024/beta-invalid
Browse files Browse the repository at this point in the history
Fix `cupyx.scipy.special.betainc` for invalid inputs
  • Loading branch information
takagi authored and chainer-ci committed Jan 10, 2024
1 parent 0b75202 commit f0ff96c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
31 changes: 19 additions & 12 deletions cupyx/scipy/special/_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,21 +474,20 @@
double a, b, t, x, xc, w, y;
int flag;
if (aa <= 0.0 || bb <= 0.0)
{
if (!isfinite(aa) || !isfinite(bb) || isnan(xx)) {
return CUDART_NAN;
}
if ((xx <= 0.0) || (xx >= 1.0)) {
if (xx == 0.0) {
return 0.0;
}
if (xx == 1.0) {
return 1.0;
}
if (aa <= 0.0 || bb <= 0.0 || xx < 0 || xx > 1) {
return CUDART_NAN;
}
if (xx == 0.0) {
return 0.0;
}
if (xx == 1.0) {
return 1.0;
}
flag = 0;
if ((bb * xx) <= 1.0 && xx <= 0.95) {
t = pseries(aa, bb, xx);
Expand Down Expand Up @@ -829,11 +828,19 @@
double a, b, y0, d, y, x, x0, x1, lgm, yp, di, dithresh, yl, yh, xt;
int i, rflg, dir, nflg;
if (isnan(aa) || isnan(bb) || isnan(yy0)) {
return CUDART_NAN;
}
if (aa <= 0.0 || bb <= 0.0 || yy0 < 0 || yy0 > 1) {
return CUDART_NAN;
}
i = 0;
if (yy0 <= 0) {
if (yy0 == 0) {
return 0.0;
}
if (yy0 >= 1.0) {
if (yy0 == 1.0) {
return 1.0;
}
x0 = 0.0;
Expand Down
2 changes: 2 additions & 0 deletions cupyx/scipy/special/_stats_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,8 @@
}
dk = k + 1;
dn = n;
if (y <= 0.0) return 0.0;
if (y >= 1.0) return 1.0;
w = incbi(dn, dk, y);
return (w);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/cupyx_tests/scipy_tests/special_tests/test_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_betaln_specific_vals(self):
3.1811881124242447, rtol=1e-14, atol=0)


@testing.with_requires('scipy')
@testing.with_requires('scipy>=1.12.0rc1')
class TestBetaInc:

@pytest.mark.parametrize('function', ['betainc', 'betaincinv'])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
cupyx_scipy_ufuncs = set()


@testing.with_requires("scipy")
@testing.with_requires("scipy>=1.12.0rc1")
@pytest.mark.parametrize("ufunc", sorted(cupyx_scipy_ufuncs & scipy_ufuncs))
class TestUfunc:
def _should_skip(self, f):
Expand Down

0 comments on commit f0ff96c

Please sign in to comment.