Skip to content

Commit 9842040

Browse files
authored
Merge pull request #179 from ev-br/tuples_not_lists
ENH: return tuples not lists from functions changed in 2025.12
2 parents d1de467 + 7d2ae7a commit 9842040

File tree

3 files changed

+15
-8
lines changed

3 files changed

+15
-8
lines changed

array_api_strict/_creation_functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def linspace(
309309
)
310310

311311

312-
def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> list[Array]:
312+
def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> tuple[Array, ...]:
313313
"""
314314
Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
315315
@@ -332,10 +332,12 @@ def meshgrid(*arrays: Array, indexing: Literal["xy", "ij"] = "xy") -> list[Array
332332
else:
333333
device = None
334334

335-
return [
335+
typ = list if get_array_api_strict_flags()['api_version'] < '2025.12' else tuple
336+
337+
return typ(
336338
Array._new(array, device=device)
337339
for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
338-
]
340+
)
339341

340342

341343
def ones(

array_api_strict/_data_type_functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,19 @@ def astype(
4949
return Array._new(x._array.astype(dtype=dtype._np_dtype, copy=copy), device=device)
5050

5151

52-
def broadcast_arrays(*arrays: Array) -> list[Array]:
52+
def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]:
5353
"""
5454
Array API compatible wrapper for :py:func:`np.broadcast_arrays <numpy.broadcast_arrays>`.
5555
5656
See its docstring for more information.
5757
"""
5858
from ._array_object import Array
5959

60-
return [
60+
typ = list if get_array_api_strict_flags()['api_version'] < '2025.12' else tuple
61+
62+
return typ(
6163
Array._new(array, device=arrays[0].device) for array in np.broadcast_arrays(*[a._array for a in arrays])
62-
]
64+
)
6365

6466

6567
@requires_api_version("2025.12")

array_api_strict/_info.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,8 @@ def dtypes(
130130
raise ValueError(f"unsupported kind: {kind!r}")
131131

132132
@requires_api_version('2023.12')
133-
def devices(self) -> list[Device]:
134-
return list(ALL_DEVICES)
133+
def devices(self) -> tuple[Device]:
134+
if get_array_api_strict_flags()['api_version'] < '2025.12':
135+
return list(ALL_DEVICES)
136+
else:
137+
return tuple(ALL_DEVICES)

0 commit comments

Comments
 (0)