Skip to content

Commit 2b92f13

Browse files
Numpy does not have float128 on all platforms
1 parent a03458e commit 2b92f13

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

mkl_fft/_float_utils.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@
2424
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2525
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

27-
from numpy import (half, float32, asarray, ndarray,
28-
longdouble, float64, longcomplex, complex_, float128, complex256)
27+
import numpy as np
2928

3029
__all__ = ['__upcast_float16_array', '__downcast_float128_array', '__supported_array_or_not_implemented']
3130

@@ -35,18 +34,18 @@ def __upcast_float16_array(x):
3534
instead of float64, as mkl_fft would do"""
3635
if hasattr(x, "dtype"):
3736
xdt = x.dtype
38-
if xdt == half:
37+
if xdt == np.half:
3938
# no half-precision routines, so convert to single precision
40-
return asarray(x, dtype=float32)
41-
if xdt == longdouble and not xdt == float64:
39+
return np.asarray(x, dtype=np.float32)
40+
if xdt == np.longdouble and not xdt == np.float64:
4241
raise ValueError("type %s is not supported" % xdt)
43-
if not isinstance(x, ndarray):
44-
__x = asarray(x)
42+
if not isinstance(x, np.ndarray):
43+
__x = np.asarray(x)
4544
xdt = __x.dtype
46-
if xdt == half:
45+
if xdt == np.half:
4746
# no half-precision routines, so convert to single precision
48-
return asarray(__x, dtype=float32)
49-
if xdt == longdouble and not xdt == float64:
47+
return np.asarray(__x, dtype=np.float32)
48+
if xdt == np.longdouble and not xdt == np.float64:
5049
raise ValueError("type %s is not supported" % xdt)
5150
return __x
5251
return x
@@ -58,17 +57,17 @@ def __downcast_float128_array(x):
5857
complex128, instead of raising an error"""
5958
if hasattr(x, "dtype"):
6059
xdt = x.dtype
61-
if xdt == longdouble and not xdt == float64:
62-
return asarray(x, dtype=float64)
63-
elif xdt == longcomplex and not xdt == complex_:
64-
return asarray(x, dtype=complex_)
65-
if not isinstance(x, ndarray):
66-
__x = asarray(x)
60+
if xdt == np.longdouble and not xdt == np.float64:
61+
return np.asarray(x, dtype=np.float64)
62+
elif xdt == np.longcomplex and not xdt == np.complex_:
63+
return np.asarray(x, dtype=np.complex_)
64+
if not isinstance(x, np.ndarray):
65+
__x = np.asarray(x)
6766
xdt = __x.dtype
68-
if xdt == longdouble and not xdt == float64:
69-
return asarray(x, dtype=float64)
70-
elif xdt == longcomplex and not xdt == complex_:
71-
return asarray(x, dtype=complex_)
67+
if xdt == np.longdouble and not xdt == np.float64:
68+
return np.asarray(x, dtype=np.float64)
69+
elif xdt == np.longcomplex and not xdt == np.complex_:
70+
return np.asarray(x, dtype=np.complex_)
7271
return __x
7372
return x
7473

@@ -78,7 +77,12 @@ def __supported_array_or_not_implemented(x):
7877
Used in _scipy_fft_backend to convert array to float32,
7978
float64, complex64, or complex128 type or return NotImplemented
8079
"""
81-
__x = asarray(x)
82-
if __x.dtype in [half, float128, complex256]:
80+
__x = np.asarray(x)
81+
black_list = [np.half]
82+
if hasattr(np, 'float128'):
83+
black_list.append(np.float128)
84+
if hasattr(np, 'complex256'):
85+
black_list.append(np.complex256)
86+
if __x.dtype in black_list:
8387
return NotImplemented
8488
return __x

0 commit comments

Comments
 (0)