Skip to content

Commit 92662a6

Browse files
committed
ENH: torch: allow negative indices in take()
1 parent a60869a commit 92662a6

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,12 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje
815815
if x.ndim != 1:
816816
raise ValueError("axis must be specified when ndim > 1")
817817
axis = 0
818-
return torch.index_select(x, axis, indices, **kwargs)
818+
return torch.index_select(
819+
x,
820+
axis,
821+
torch.where(indices < 0, indices + x.shape[axis], indices),
822+
**kwargs
823+
)
819824

820825

821826
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:

0 commit comments

Comments
 (0)