Skip to content

Commit a60869a

Browse files
committed
ENH: torch: allow negative indices in take_along_axis
1 parent 35d333a commit a60869a

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,11 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje
819819

820820

821821
def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
822-
return torch.take_along_dim(x, indices, dim=axis)
822+
return torch.take_along_dim(
823+
x,
824+
torch.where(indices < 0, indices + x.shape[axis], indices),
825+
dim=axis
826+
)
823827

824828

825829
def sign(x: Array, /) -> Array:

0 commit comments

Comments
 (0)