Skip to content

Commit

Permalink
Merge pull request cupy#8898 from asi1024/sph-harm
Browse files Browse the repository at this point in the history
Add `special.sph_harm_y` and deprecate `special.sph_harm`
  • Loading branch information
asi1024 authored and chainer-ci committed Jan 24, 2025
1 parent f72cee4 commit fc96072
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
1 change: 1 addition & 0 deletions cupyx/scipy/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
# Legendre functions
from cupyx.scipy.special._lpmv import lpmv # NOQA
from cupyx.scipy.special._sph_harm import sph_harm # NOQA
from cupyx.scipy.special._sph_harm import sph_harm_y # NOQA

# Other special functions
from cupyx.scipy.special._binom import binom # NOQA
Expand Down
29 changes: 28 additions & 1 deletion cupyx/scipy/special/_sph_harm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
https://github.com/scipy/scipy/blob/master/scipy/special/sph_harm.pxd
"""

import warnings

from cupy import _core

from cupyx.scipy.special._poch import poch_definition
Expand Down Expand Up @@ -71,7 +73,7 @@
)


sph_harm = _core.create_ufunc(
_sph_harm = _core.create_ufunc(
"cupyx_scipy_lpmv",
("iiff->F", "iidd->D", "llff->F", "lldd->D"),
"out0 = out0_type(sph_harmonic(in0, in1, in2, in3));",
Expand All @@ -82,3 +84,28 @@
""",
)


def sph_harm(m, n, theta, phi, out=None):
"""Spherical Harmonic.
.. seealso:: :meth:`scipy.special.sph_harm`
"""

warnings.warn(DeprecationWarning(
"`cupyx.scipy.special.sph_harm` is deprecated in CuPy v14 "
"and are planned to be removed in the future."))

return _sph_harm(m, n, theta, phi, out=out)


def sph_harm_y(n, m, theta, phi, *, diff_n=0):
"""Spherical Harmonic.
.. seealso:: :meth:`scipy.special.sph_harm`
"""
if diff_n != 0:
raise NotImplementedError("Derivatives not implemented.")

return _sph_harm(m, n, phi, theta)
13 changes: 12 additions & 1 deletion tests/cupyx_tests/scipy_tests/special_tests/test_sph_harm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,22 @@ def _get_harmonic_list(degree_max):
@testing.with_requires("scipy")
class TestBasic():

@pytest.mark.filterwarnings('ignore::DeprecationWarning')
@pytest.mark.parametrize("m, n", _get_harmonic_list(degree_max=5))
@testing.for_dtypes(["e", "f", "d"])
@numpy_cupy_allclose(scipy_name="scp")
@numpy_cupy_allclose(scipy_name="scp", rtol=1e-7, atol=1e-10)
def test_sph_harm(self, xp, scp, dtype, m, n):
theta = xp.linspace(0, 2 * cp.pi)
phi = xp.linspace(0, cp.pi)
theta, phi = xp.meshgrid(theta, phi)
return scp.special.sph_harm(m, n, theta, phi)

@testing.with_requires("scipy>=1.15.0")
@pytest.mark.parametrize("m, n", _get_harmonic_list(degree_max=5))
@testing.for_dtypes(["e", "f", "d"])
@numpy_cupy_allclose(scipy_name="scp", rtol=1e-7, atol=1e-10)
def test_sph_harm_y(self, xp, scp, dtype, m, n):
theta = xp.linspace(0, cp.pi)
phi = xp.linspace(0, 2 * cp.pi)
theta, phi = xp.meshgrid(theta, phi)
return scp.special.sph_harm_y(n, m, theta, phi)

0 comments on commit fc96072

Please sign in to comment.