1212import math
1313import sys
1414import warnings
15+ from collections .abc import Collection , Hashable
16+ from functools import lru_cache
1517from types import NoneType
1618from typing import (
1719 TYPE_CHECKING ,
5658_API_VERSIONS : Final = _API_VERSIONS_OLD | frozenset ({"2024.12" })
5759
5860
61+ @lru_cache (100 )
62+ def _issubclass_fast (cls : type , modname : str , clsname : str ) -> bool :
63+ try :
64+ mod = sys .modules [modname ]
65+ except KeyError :
66+ return False
67+ parent_cls = getattr (mod , clsname )
68+ return issubclass (cls , parent_cls )
69+
70+
5971def _is_jax_zero_gradient_array (x : object ) -> TypeGuard [_ZeroGradientArray ]:
6072 """Return True if `x` is a zero-gradient array.
6173
6274 These arrays are a design quirk of Jax that may one day be removed.
6375 See https://github.com/google/jax/issues/20620.
6476 """
65- if "numpy" not in sys .modules or "jax" not in sys .modules :
77+ # Fast exit
78+ try :
79+ dtype = x .dtype # type: ignore[attr-defined]
80+ except AttributeError :
81+ return False
82+ cls = cast (Hashable , type (dtype ))
83+ if not _issubclass_fast (cls , "numpy.dtypes" , "VoidDType" ):
6684 return False
6785
68- import jax
69- import numpy as np
86+ if " jax" not in sys . modules :
87+ return False
7088
71- jax_float0 = cast ("np.dtype[np.void]" , jax .float0 )
72- return (
73- isinstance (x , np .ndarray )
74- and cast ("npt.NDArray[np.void]" , x ).dtype == jax_float0
75- )
89+ import jax
90+ # jax.float0 is a np.dtype([('float0', 'V')])
91+ return dtype == jax .float0
7692
7793
7894def is_numpy_array (x : object ) -> TypeIs [npt .NDArray [Any ]]:
@@ -96,15 +112,12 @@ def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
96112 is_jax_array
97113 is_pydata_sparse_array
98114 """
99- # Avoid importing NumPy if it isn't already
100- if "numpy" not in sys .modules :
101- return False
102-
103- import numpy as np
104-
105115 # TODO: Should we reject ndarray subclasses?
106- return (isinstance (x , (np .ndarray , np .generic ))
107- and not _is_jax_zero_gradient_array (x )) # pyright: ignore[reportUnknownArgumentType] # fmt: skip
116+ cls = cast (Hashable , type (x ))
117+ return (
118+ _issubclass_fast (cls , "numpy" , "ndarray" )
119+ or _issubclass_fast (cls , "numpy" , "generic" )
120+ ) and not _is_jax_zero_gradient_array (x )
108121
109122
110123def is_cupy_array (x : object ) -> bool :
@@ -128,14 +141,8 @@ def is_cupy_array(x: object) -> bool:
128141 is_jax_array
129142 is_pydata_sparse_array
130143 """
131- # Avoid importing CuPy if it isn't already
132- if "cupy" not in sys .modules :
133- return False
134-
135- import cupy as cp
136-
137- # TODO: Should we reject ndarray subclasses?
138- return isinstance (x , cp .ndarray ) # pyright: ignore[reportUnknownMemberType]
144+ cls = cast (Hashable , type (x ))
145+ return _issubclass_fast (cls , "cupy" , "ndarray" )
139146
140147
141148def is_torch_array (x : object ) -> TypeIs [torch .Tensor ]:
@@ -156,14 +163,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
156163 is_jax_array
157164 is_pydata_sparse_array
158165 """
159- # Avoid importing torch if it isn't already
160- if "torch" not in sys .modules :
161- return False
162-
163- import torch
164-
165- # TODO: Should we reject ndarray subclasses?
166- return isinstance (x , torch .Tensor )
166+ cls = cast (Hashable , type (x ))
167+ return _issubclass_fast (cls , "torch" , "Tensor" )
167168
168169
169170def is_ndonnx_array (x : object ) -> TypeIs [ndx .Array ]:
@@ -185,13 +186,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
185186 is_jax_array
186187 is_pydata_sparse_array
187188 """
188- # Avoid importing torch if it isn't already
189- if "ndonnx" not in sys .modules :
190- return False
191-
192- import ndonnx as ndx
193-
194- return isinstance (x , ndx .Array )
189+ cls = cast (Hashable , type (x ))
190+ return _issubclass_fast (cls , "ndonnx" , "Array" )
195191
196192
197193def is_dask_array (x : object ) -> TypeIs [da .Array ]:
@@ -213,13 +209,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]:
213209 is_jax_array
214210 is_pydata_sparse_array
215211 """
216- # Avoid importing dask if it isn't already
217- if "dask.array" not in sys .modules :
218- return False
219-
220- import dask .array
221-
222- return isinstance (x , dask .array .Array )
212+ cls = cast (Hashable , type (x ))
213+ return _issubclass_fast (cls , "dask.array" , "Array" )
223214
224215
225216def is_jax_array (x : object ) -> TypeIs [jax .Array ]:
@@ -242,13 +233,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
242233 is_dask_array
243234 is_pydata_sparse_array
244235 """
245- # Avoid importing jax if it isn't already
246- if "jax" not in sys .modules :
247- return False
248-
249- import jax
250-
251- return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
236+ cls = cast (Hashable , type (x ))
237+ return _issubclass_fast (cls , "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
252238
253239
254240def is_pydata_sparse_array (x : object ) -> TypeIs [sparse .SparseArray ]:
@@ -271,14 +257,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
271257 is_dask_array
272258 is_jax_array
273259 """
274- # Avoid importing jax if it isn't already
275- if "sparse" not in sys .modules :
276- return False
277-
278- import sparse
279-
280260 # TODO: Account for other backends.
281- return isinstance (x , sparse .SparseArray )
261+ cls = cast (Hashable , type (x ))
262+ return _issubclass_fast (cls , "sparse" , "SparseArray" )
282263
283264
284265def is_array_api_obj (x : object ) -> TypeGuard [_ArrayApiObj ]:
@@ -297,13 +278,23 @@ def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
297278 is_jax_array
298279 """
299280 return (
300- is_numpy_array (x )
301- or is_cupy_array (x )
302- or is_torch_array (x )
303- or is_dask_array (x )
304- or is_jax_array (x )
305- or is_pydata_sparse_array (x )
306- or hasattr (x , "__array_namespace__" )
281+ hasattr (x , '__array_namespace__' )
282+ or _is_array_api_cls (cast (Hashable , type (x )))
283+ )
284+
285+
286+ @lru_cache (100 )
287+ def _is_array_api_cls (cls : type ) -> bool :
288+ return (
289+ # TODO: drop support for numpy<2 which didn't have __array_namespace__
290+ _issubclass_fast (cls , "numpy" , "ndarray" )
291+ or _issubclass_fast (cls , "numpy" , "generic" )
292+ or _issubclass_fast (cls , "cupy" , "ndarray" )
293+ or _issubclass_fast (cls , "torch" , "Tensor" )
294+ or _issubclass_fast (cls , "dask.array" , "Array" )
295+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
296+ # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
297+ or _issubclass_fast (cls , "jax" , "Array" )
307298 )
308299
309300
@@ -312,6 +303,7 @@ def _compat_module_name() -> str:
312303 return __name__ .removesuffix (".common._helpers" )
313304
314305
306+ @lru_cache (100 )
315307def is_numpy_namespace (xp : Namespace ) -> bool :
316308 """
317309 Returns True if `xp` is a NumPy namespace.
@@ -333,6 +325,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
333325 return xp .__name__ in {"numpy" , _compat_module_name () + ".numpy" }
334326
335327
328+ @lru_cache (100 )
336329def is_cupy_namespace (xp : Namespace ) -> bool :
337330 """
338331 Returns True if `xp` is a CuPy namespace.
@@ -354,6 +347,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
354347 return xp .__name__ in {"cupy" , _compat_module_name () + ".cupy" }
355348
356349
350+ @lru_cache (100 )
357351def is_torch_namespace (xp : Namespace ) -> bool :
358352 """
359353 Returns True if `xp` is a PyTorch namespace.
@@ -394,6 +388,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
394388 return xp .__name__ == "ndonnx"
395389
396390
391+ @lru_cache (100 )
397392def is_dask_namespace (xp : Namespace ) -> bool :
398393 """
399394 Returns True if `xp` is a Dask namespace.
@@ -934,6 +929,19 @@ def size(x: HasShape[float | None]) -> int | None:
934929 return None if math .isnan (out ) else cast (int , out )
935930
936931
932+ @lru_cache (100 )
933+ def _is_writeable_cls (cls : type ) -> bool | None :
934+ if (
935+ _issubclass_fast (cls , "numpy" , "generic" )
936+ or _issubclass_fast (cls , "jax" , "Array" )
937+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
938+ ):
939+ return False
940+ if _is_array_api_cls (cls ):
941+ return True
942+ return None
943+
944+
937945def is_writeable_array (x : object ) -> TypeGuard [_ArrayApiObj ]:
938946 """
939947 Return False if ``x.__setitem__`` is expected to raise; True otherwise.
@@ -944,11 +952,32 @@ def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
944952 As there is no standard way to check if an array is writeable without actually
945953 writing to it, this function blindly returns True for all unknown array types.
946954 """
947- if is_numpy_array (x ):
948- return x .flags .writeable
949- if is_jax_array (x ) or is_pydata_sparse_array (x ):
955+ cls = cast (Hashable , type (x ))
956+ if _issubclass_fast (cls , "numpy" , "ndarray" ):
957+ return cast ("npt.NDArray" , x ).flags .writeable
958+ res = _is_writeable_cls (cls )
959+ if res is not None :
960+ return res
961+ return hasattr (x , '__array_namespace__' )
962+
963+
964+ @lru_cache (100 )
965+ def _is_lazy_cls (cls : type ) -> bool | None :
966+ if (
967+ _issubclass_fast (cls , "numpy" , "ndarray" )
968+ or _issubclass_fast (cls , "numpy" , "generic" )
969+ or _issubclass_fast (cls , "cupy" , "ndarray" )
970+ or _issubclass_fast (cls , "torch" , "Tensor" )
971+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
972+ ):
950973 return False
951- return is_array_api_obj (x )
974+ if (
975+ _issubclass_fast (cls , "jax" , "Array" )
976+ or _issubclass_fast (cls , "dask.array" , "Array" )
977+ or _issubclass_fast (cls , "ndonnx" , "Array" )
978+ ):
979+ return True
980+ return None
952981
953982
954983def is_lazy_array (x : object ) -> TypeGuard [_ArrayApiObj ]:
@@ -964,14 +993,6 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
964993 This function errs on the side of caution for array types that may or may not be
965994 lazy, e.g. JAX arrays, by always returning True for them.
966995 """
967- if (
968- is_numpy_array (x )
969- or is_cupy_array (x )
970- or is_torch_array (x )
971- or is_pydata_sparse_array (x )
972- ):
973- return False
974-
975996 # **JAX note:** while it is possible to determine if you're inside or outside
976997 # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
977998 # as we do below for unknown arrays, this is not recommended by JAX best practices.
@@ -981,10 +1002,14 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
9811002 # compatibility, is highly detrimental to performance as the whole graph will end
9821003 # up being computed multiple times.
9831004
984- if is_jax_array (x ) or is_dask_array (x ) or is_ndonnx_array (x ):
985- return True
1005+ # Note: skipping reclassification of JAX zero gradient arrays, as one will
1006+ # exclusively get them once they leave a jax.grad JIT context.
1007+ cls = cast (Hashable , type (x ))
1008+ res = _is_lazy_cls (cls )
1009+ if res is not None :
1010+ return res
9861011
987- if not is_array_api_obj ( x ):
1012+ if not hasattr ( x , "__array_namespace__" ):
9881013 return False
9891014
9901015 # Unknown Array API compatible object. Note that this test may have dire consequences
@@ -1037,7 +1062,7 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
10371062 "to_device" ,
10381063]
10391064
1040- _all_ignore = [" sys" , " math" , " inspect" , " warnings" ]
1065+ _all_ignore = ['lru_cache' , ' sys' , ' math' , ' inspect' , ' warnings' ]
10411066
10421067def __dir__ () -> list [str ]:
10431068 return __all__
0 commit comments