Skip to content

Commit

Permalink
Merge pull request cupy#5801 from kmaehashi/fix-test-failure
Browse files Browse the repository at this point in the history
Fix test skip issue
  • Loading branch information
kmaehashi authored Sep 29, 2021
2 parents 2db8567 + f491108 commit 737621e
Showing 1 changed file with 27 additions and 39 deletions.
66 changes: 27 additions & 39 deletions tests/cupyx_tests/scipy_tests/fft_tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
import cupyx.scipy.fft as cp_fft


_irfft_skip_condition = (
int(cp.cuda.device.get_compute_capability()) < 70 and
10020 >= cp.cuda.runtime.runtimeGetVersion() >= 10010)


def _fft_module(xp):
if xp is not np:
return cp_fft
Expand Down Expand Up @@ -1141,8 +1146,7 @@ def test_rfft2_backend(self, xp, dtype):
testing.assert_array_equal(x, x_orig)
return _correct_np_dtype(xp, dtype, out)

@pytest.mark.skipif(int(cp.cuda.device.get_compute_capability()) < 70 and
10020 >= cp.cuda.runtime.runtimeGetVersion() >= 10010,
@pytest.mark.skipif(_irfft_skip_condition,
reason="Known to fail with Pascal or older")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=1e-4, atol=1e-7, accept_error=ValueError,
Expand All @@ -1156,8 +1160,7 @@ def test_irfft2(self, xp, dtype):

return _correct_np_dtype(xp, dtype, out)

@pytest.mark.skipif(int(cp.cuda.device.get_compute_capability()) < 70 and
10020 >= cp.cuda.runtime.runtimeGetVersion() >= 10010,
@pytest.mark.skipif(_irfft_skip_condition,
reason="Known to fail with Pascal or older")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=1e-4, atol=1e-7, accept_error=ValueError,
Expand Down Expand Up @@ -1252,8 +1255,7 @@ def test_irfft2_plan_manager(self, xp, dtype):

return _correct_np_dtype(xp, dtype, out)

@pytest.mark.skipif(int(cp.cuda.device.get_compute_capability()) < 70 and
10020 >= cp.cuda.runtime.runtimeGetVersion() >= 10010,
@pytest.mark.skipif(_irfft_skip_condition,
reason="Known to fail with Pascal or older")
@testing.with_requires('scipy>=1.4.0')
@testing.for_all_dtypes(no_complex=True)
Expand Down Expand Up @@ -1402,8 +1404,7 @@ def test_rfftn_backend(self, xp, dtype):
testing.assert_array_equal(x, x_orig)
return _correct_np_dtype(xp, dtype, out)

@pytest.mark.skipif(int(cp.cuda.device.get_compute_capability()) < 70 and
10020 >= cp.cuda.runtime.runtimeGetVersion() >= 10010,
@pytest.mark.skipif(_irfft_skip_condition,
reason="Known to fail with Pascal or older")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=1e-4, atol=1e-7, accept_error=ValueError,
Expand All @@ -1417,8 +1418,7 @@ def test_irfftn(self, xp, dtype):

return _correct_np_dtype(xp, dtype, out)

@pytest.mark.skipif(int(cp.cuda.device.get_compute_capability()) < 70 and
10020 >= cp.cuda.runtime.runtimeGetVersion() >= 10010,
@pytest.mark.skipif(_irfft_skip_condition,
reason="Known to fail with Pascal or older")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=1e-4, atol=1e-7, accept_error=ValueError,
Expand Down Expand Up @@ -1513,8 +1513,7 @@ def test_irfftn_plan_manager(self, xp, dtype):

return _correct_np_dtype(xp, dtype, out)

@pytest.mark.skipif(int(cp.cuda.device.get_compute_capability()) < 70 and
10020 >= cp.cuda.runtime.runtimeGetVersion() >= 10010,
@pytest.mark.skipif(_irfft_skip_condition,
reason="Known to fail with Pascal or older")
@testing.with_requires('scipy>=1.4.0')
@testing.for_all_dtypes(no_complex=True)
Expand Down Expand Up @@ -1651,35 +1650,30 @@ class TestHfft2:
def setUp(self):
_skip_forward_backward(self.norm)

@pytest.mark.skipif(_irfft_skip_condition,
reason="Known to fail with Pascal or older")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=4e-4, atol=1e-7, accept_error=ValueError,
contiguous_check=False)
def test_hfft2(self, xp, dtype):
x = testing.shaped_random(self.shape, xp, dtype)
x_orig = x.copy()
with pytest.warns(None) as record:
out = _fft_module(xp).hfft2(x, s=self.s, axes=self.axes,
norm=self.norm)
if len(record) == 1 and 'issue of cuFFT' in record[0].message:
# CUDA 10.2 bug
pytest.skip(record[0].message)
out = _fft_module(xp).hfft2(x, s=self.s, axes=self.axes,
norm=self.norm)
testing.assert_array_equal(x, x_orig)
return _correct_np_dtype(xp, dtype, out)

@pytest.mark.skipif(_irfft_skip_condition,
reason="Known to fail with Pascal or older")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=4e-4, atol=1e-7, accept_error=ValueError,
contiguous_check=False)
def test_hfft2_backend(self, xp, dtype):
x = testing.shaped_random(self.shape, xp, dtype)
x_orig = x.copy()
backend = 'scipy' if xp is np else cp_fft
with pytest.warns(None) as record:
with scipy_fft.set_backend(backend):
out = scipy_fft.hfft2(
x, s=self.s, axes=self.axes, norm=self.norm)
if len(record) == 1 and 'issue of cuFFT' in record[0].message:
# CUDA 10.2 bug
pytest.skip(record[0].message)
with scipy_fft.set_backend(backend):
out = scipy_fft.hfft2(x, s=self.s, axes=self.axes, norm=self.norm)
testing.assert_array_equal(x, x_orig)
return _correct_np_dtype(xp, dtype, out)

Expand Down Expand Up @@ -1732,36 +1726,30 @@ class TestHfftn:
def setUp(self):
_skip_forward_backward(self.norm)

@pytest.mark.skipif(_irfft_skip_condition,
reason="Known to fail with Pascal or older")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=4e-4, atol=1e-5, accept_error=ValueError,
contiguous_check=False)
def test_hfftn(self, xp, dtype):
x = testing.shaped_random(self.shape, xp, dtype)
x_orig = x.copy()

with pytest.warns(None) as record:
out = _fft_module(xp).hfftn(x, s=self.s, axes=self.axes,
norm=self.norm)
if len(record) == 1 and 'issue of cuFFT' in record[0].message:
# CUDA 10.2 bug
pytest.skip(record[0].message)
out = _fft_module(xp).hfftn(x, s=self.s, axes=self.axes,
norm=self.norm)
testing.assert_array_equal(x, x_orig)
return _correct_np_dtype(xp, dtype, out)

@pytest.mark.skipif(_irfft_skip_condition,
reason="Known to fail with Pascal or older")
@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=4e-4, atol=1e-5, accept_error=ValueError,
contiguous_check=False)
def test_hfftn_backend(self, xp, dtype):
x = testing.shaped_random(self.shape, xp, dtype)
x_orig = x.copy()
backend = 'scipy' if xp is np else cp_fft
with pytest.warns(None) as record:
with scipy_fft.set_backend(backend):
out = scipy_fft.hfftn(
x, s=self.s, axes=self.axes, norm=self.norm)
if len(record) == 1 and 'issue of cuFFT' in record[0].message:
# CUDA 10.2 bug
pytest.skip(record[0].message)
with scipy_fft.set_backend(backend):
out = scipy_fft.hfftn(x, s=self.s, axes=self.axes, norm=self.norm)
testing.assert_array_equal(x, x_orig)
return _correct_np_dtype(xp, dtype, out)

Expand Down

0 comments on commit 737621e

Please sign in to comment.