Skip to content

Commit c15f738

Browse files
vtavanaantonwolfy
andauthored
update fft tests (#2502)
update FFT tests. --------- Co-authored-by: Anton <[email protected]>
1 parent 78597f9 commit c15f738

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

dpnp/tests/test_fft.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -563,9 +563,13 @@ def test_basic(self, dtype, n, norm):
563563

564564
result = dpnp.fft.hfft(ia, n=n, norm=norm)
565565
expected = numpy.fft.hfft(a, n=n, norm=norm)
566-
# check_only_type_kind=True since NumPy always returns float64
567-
# but dpnp return float32 if input is float32
568-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
566+
# TODO: change to the commented line when mkl_fft-2.0.0 is released
567+
# and being used with Intel NumPy >= 2.0.0
568+
flag = True
569+
# flag = True if numpy_version() < "2.0.0" else False
570+
assert_dtype_allclose(
571+
result, expected, factor=24, check_only_type_kind=flag
572+
)
569573

570574
@pytest.mark.parametrize(
571575
"dtype", get_all_dtypes(no_none=True, no_complex=True)
@@ -579,7 +583,7 @@ def test_inverse(self, dtype, n, norm):
579583
result = dpnp.fft.ihfft(ia, n=n, norm=norm)
580584
expected = numpy.fft.ihfft(a, n=n, norm=norm)
581585
flag = True if numpy_version() < "2.0.0" else False
582-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
586+
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
583587

584588
def test_error(self):
585589
a = dpnp.ones(11)
@@ -600,14 +604,18 @@ class TestIrfft:
600604
@pytest.mark.parametrize("n", [None, 5, 18])
601605
@pytest.mark.parametrize("norm", [None, "backward", "forward", "ortho"])
602606
def test_basic(self, dtype, n, norm):
603-
a = generate_random_numpy_array(11)
607+
a = generate_random_numpy_array(11, dtype=dtype)
604608
ia = dpnp.array(a)
605609

606610
result = dpnp.fft.irfft(ia, n=n, norm=norm)
607611
expected = numpy.fft.irfft(a, n=n, norm=norm)
608-
# check_only_type_kind=True since NumPy always returns float64
609-
# but dpnp return float32 if input is float32
610-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
612+
# TODO: change to the commented line when mkl_fft-2.0.0 is released
613+
# and being used with Intel NumPy >= 2.0.0
614+
flag = True
615+
# flag = True if numpy_version() < "2.0.0" else False
616+
assert_dtype_allclose(
617+
result, expected, factor=24, check_only_type_kind=flag
618+
)
611619

612620
@pytest.mark.parametrize("dtype", get_complex_dtypes())
613621
@pytest.mark.parametrize("n", [None, 5, 8])
@@ -771,8 +779,11 @@ def test_float16(self):
771779

772780
expected = numpy.fft.rfft(a)
773781
result = dpnp.fft.rfft(ia)
774-
# check_only_type_kind=True since Intel NumPy returns complex128
775-
assert_dtype_allclose(result, expected, check_only_type_kind=True)
782+
# TODO: change to the commented line when mkl_fft-2.0.0 is released
783+
# and being used with Intel NumPy >= 2.0.0
784+
flag = True
785+
# flag = True if numpy_version() < "2.0.0" else False
786+
assert_dtype_allclose(result, expected, check_only_type_kind=flag)
776787

777788
@testing.with_requires("numpy>=2.0.0")
778789
@pytest.mark.parametrize("xp", [numpy, dpnp])
@@ -954,7 +965,8 @@ def test_1d_array(self):
954965

955966
result = dpnp.fft.irfftn(ia)
956967
expected = numpy.fft.irfftn(a)
957-
# TODO: change to the commented line when mkl_fft-gh-180 is merged
968+
# TODO: change to the commented line when mkl_fft-2.0.0 is released
969+
# and being used with Intel NumPy >= 2.0.0
958970
flag = True
959971
# flag = True if numpy_version() < "2.0.0" else False
960972
assert_dtype_allclose(result, expected, check_only_type_kind=flag)

0 commit comments

Comments
 (0)