Skip to content

Commit 015aab9

Browse files
Merge pull request #77 from IntelPython/fix-float-utils
Fix float utils
2 parents a03458e + 2faba1b commit 015aab9

File tree

4 files changed

+37
-25
lines changed

4 files changed

+37
-25
lines changed

conda-recipe/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{% set version = "1.3.1" %}
1+
{% set version = "1.3.3" %}
22
{% set buildnumber = 0 %}
33

44
package:

mkl_fft/_float_utils.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27-
from numpy import (half, float32, asarray, ndarray,
28-
longdouble, float64, longcomplex, complex_, float128, complex256)
27+
import numpy as np
2928

3029
__all__ = ['__upcast_float16_array', '__downcast_float128_array', '__supported_array_or_not_implemented']
3130

@@ -35,18 +34,18 @@ def __upcast_float16_array(x):
3534
instead of float64, as mkl_fft would do"""
3635
if hasattr(x, "dtype"):
3736
xdt = x.dtype
38-
if xdt == half:
37+
if xdt == np.half:
3938
# no half-precision routines, so convert to single precision
40-
return asarray(x, dtype=float32)
41-
if xdt == longdouble and not xdt == float64:
39+
return np.asarray(x, dtype=np.float32)
40+
if xdt == np.longdouble and not xdt == np.float64:
4241
raise ValueError("type %s is not supported" % xdt)
43-
if not isinstance(x, ndarray):
44-
__x = asarray(x)
42+
if not isinstance(x, np.ndarray):
43+
__x = np.asarray(x)
4544
xdt = __x.dtype
46-
if xdt == half:
45+
if xdt == np.half:
4746
# no half-precision routines, so convert to single precision
48-
return asarray(__x, dtype=float32)
49-
if xdt == longdouble and not xdt == float64:
47+
return np.asarray(__x, dtype=np.float32)
48+
if xdt == np.longdouble and not xdt == np.float64:
5049
raise ValueError("type %s is not supported" % xdt)
5150
return __x
5251
return x
@@ -58,17 +57,17 @@ def __downcast_float128_array(x):
5857
complex128, instead of raising an error"""
5958
if hasattr(x, "dtype"):
6059
xdt = x.dtype
61-
if xdt == longdouble and not xdt == float64:
62-
return asarray(x, dtype=float64)
63-
elif xdt == longcomplex and not xdt == complex_:
64-
return asarray(x, dtype=complex_)
65-
if not isinstance(x, ndarray):
66-
__x = asarray(x)
60+
if xdt == np.longdouble and not xdt == np.float64:
61+
return np.asarray(x, dtype=np.float64)
62+
elif xdt == np.longcomplex and not xdt == np.complex_:
63+
return np.asarray(x, dtype=np.complex_)
64+
if not isinstance(x, np.ndarray):
65+
__x = np.asarray(x)
6766
xdt = __x.dtype
68-
if xdt == longdouble and not xdt == float64:
69-
return asarray(x, dtype=float64)
70-
elif xdt == longcomplex and not xdt == complex_:
71-
return asarray(x, dtype=complex_)
67+
if xdt == np.longdouble and not xdt == np.float64:
68+
return np.asarray(x, dtype=np.float64)
69+
elif xdt == np.longcomplex and not xdt == np.complex_:
70+
return np.asarray(x, dtype=np.complex_)
7271
return __x
7372
return x
7473

@@ -78,7 +77,12 @@ def __supported_array_or_not_implemented(x):
7877
Used in _scipy_fft_backend to convert array to float32,
7978
float64, complex64, or complex128 type or return NotImplemented
8079
"""
81-
__x = asarray(x)
82-
if __x.dtype in [half, float128, complex256]:
80+
__x = np.asarray(x)
81+
black_list = [np.half]
82+
if hasattr(np, 'float128'):
83+
black_list.append(np.float128)
84+
if hasattr(np, 'complex256'):
85+
black_list.append(np.complex256)
86+
if __x.dtype in black_list:
8387
return NotImplemented
8488
return __x

mkl_fft/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.3.1'
1+
__version__ = '1.3.3'

mkl_fft/tests/test_interfaces.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,15 @@ def test_numpy_rftn(norm, dtype):
117117
assert np.allclose(x, xx, atol=tol, rtol=tol)
118118

119119

120-
@pytest.mark.parametrize('dtype', [np.float16, np.float128, np.complex256])
120+
def _get_blacklisted_dtypes():
121+
bl_list = []
122+
for dt in ['float16', 'float128', 'complex256']:
123+
if hasattr(np, dt):
124+
bl_list.append(getattr(np, dt))
125+
return bl_list
126+
127+
128+
@pytest.mark.parametrize('dtype', _get_blacklisted_dtypes())
121129
def test_scipy_no_support_for(dtype):
122130
x = np.ones(16, dtype=dtype)
123131
w = mfi.scipy_fft.fft(x)

0 commit comments

Comments
 (0)