Skip to content

Commit 4355ab8

Browse files
committed
MAINT: link to pytorch issue for negative indices
1 parent 92662a6 commit 4355ab8

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,8 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje
815815
if x.ndim != 1:
816816
raise ValueError("axis must be specified when ndim > 1")
817817
axis = 0
818+
# torch does not support negative indices,
819+
# see https://github.com/pytorch/pytorch/issues/146211
818820
return torch.index_select(
819821
x,
820822
axis,
@@ -824,6 +826,8 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje
824826

825827

826828
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
829+
# torch does not support negative indices,
830+
# see https://github.com/pytorch/pytorch/issues/146211
827831
return torch.take_along_dim(
828832
x,
829833
torch.where(indices < 0, indices + x.shape[axis], indices),

0 commit comments

Comments
 (0)