Skip to content

Commit 7d2ae7a

Browse files
committed
ENH: return tuples not lists from functions changed in 2025.12
- broadcast_arrays - meshgrid - __array_namespace_info__().devices()
1 parent c303adc commit 7d2ae7a

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

array_api_strict/_creation_functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def linspace(
304304
)
305305

306306

307-
def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> list[Array]:
307+
def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> tuple[Array, ...]:
308308
"""
309309
Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
310310
@@ -327,10 +327,12 @@ def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> list[Array
327327
else:
328328
device = None
329329

330-
return [
330+
typ = list if get_array_api_strict_flags()['api_version'] < '2025.12' else tuple
331+
332+
return typ(
331333
Array._new(array, device=device)
332334
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
333-
]
335+
)
334336

335337

336338
def ones(

array_api_strict/_data_type_functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,19 @@ def astype(
4949
return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device)
5050

5151

52-
def broadcast_arrays(*arrays: Array) -> list[Array]:
52+
def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]:
5353
"""
5454
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
5555
5656
See its docstring for more information.
5757
"""
5858
from ._array_object import Array
5959

60-
return [
60+
typ = list if get_array_api_strict_flags()['api_version'] < '2025.12' else tuple
61+
62+
return typ(
6163
Array._new(array, device=arrays[0].device) for array in np.broadcast_arrays(*[a._array for a in arrays])
62-
]
64+
)
6365

6466

6567
def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array:

array_api_strict/_info.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,8 @@ def dtypes(
130130
raise ValueError(f"unsupported kind: {kind!r}")
131131

132132
@requires_api_version('2023.12')
133-
def devices(self) -> list[Device]:
134-
return list(ALL_DEVICES)
133+
def devices(self) -> tuple[Device]:
134+
if get_array_api_strict_flags()['api_version'] < '2025.12':
135+
return list(ALL_DEVICES)
136+
else:
137+
return tuple(ALL_DEVICES)

0 commit comments

Comments
 (0)