@@ -71,13 +71,13 @@ def test_take(x, data):
7171)
7272def test_take_along_axis (x , data ):
7373 # TODO
74- # 1. negative axis
7574 # 2. negative indices
7675 # 3. different dtypes for indices
77- axis = data .draw (st .integers (0 , max (x .ndim - 1 , 0 )), label = "axis" )
76+ axis = data .draw (st .integers (- x . ndim , max (x .ndim - 1 , 0 )), label = "axis" )
7877 len_axis = data .draw (st .integers (0 , 2 * x .shape [axis ]), label = "len_axis" )
7978
80- idx_shape = x .shape [:axis ] + (len_axis ,) + x .shape [axis + 1 :]
79+ n_axis = axis + x .ndim if axis < 0 else axis
80+ idx_shape = x .shape [:n_axis ] + (len_axis ,) + x .shape [n_axis + 1 :]
8181 indices = data .draw (
8282 hh .arrays (
8383 shape = idx_shape ,
@@ -94,7 +94,7 @@ def test_take_along_axis(x, data):
9494 ph .assert_shape (
9595 "take_along_axis" ,
9696 out_shape = out .shape ,
97- expected = x .shape [:axis ] + (len_axis ,) + x .shape [axis + 1 :],
97+ expected = x .shape [:n_axis ] + (len_axis ,) + x .shape [n_axis + 1 :],
9898 kw = dict (
9999 x = x ,
100100 indices = indices ,
@@ -103,12 +103,11 @@ def test_take_along_axis(x, data):
103103 )
104104
105105 # value test: notation is from `np.take_along_axis` docstring
106- Ni , Nk = x .shape [:axis ], x .shape [axis + 1 :]
106+ Ni , Nk = x .shape [:n_axis ], x .shape [n_axis + 1 :]
107107 for ii in sh .ndindex (Ni ):
108108 for kk in sh .ndindex (Nk ):
109109 a_1d = x [ii + (slice (None ),) + kk ]
110110 i_1d = indices [ii + (slice (None ),) + kk ]
111111 o_1d = out [ii + (slice (None ),) + kk ]
112112 for j in range (len_axis ):
113113 assert o_1d [j ] == a_1d [i_1d [j ]], f'{ ii = } , { kk = } , { j = } '
114-
0 commit comments