@@ -73,22 +73,30 @@ def test_take_along_axis(x, data):
7373 # TODO
7474 # 2. negative indices
7575 # 3. different dtypes for indices
76- axis = data .draw (st .integers (- x .ndim , max (x .ndim - 1 , 0 )), label = "axis" )
77- len_axis = data .draw (st .integers (0 , 2 * x .shape [axis ]), label = "len_axis" )
76+ axis = data .draw (
77+ st .integers (- x .ndim , max (x .ndim - 1 , 0 )) | st .none (),
78+ label = "axis"
79+ )
80+ if axis is None :
81+ axis_kw = {}
82+ n_axis = x .ndim - 1
83+ else :
84+ axis_kw = {"axis" : axis }
85+ n_axis = axis + x .ndim if axis < 0 else axis
7886
79- n_axis = axis + x . ndim if axis < 0 else axis
87+ len_axis = data . draw ( st . integers ( 0 , 2 * x . shape [ n_axis ]), label = "len_axis" )
8088 idx_shape = x .shape [:n_axis ] + (len_axis ,) + x .shape [n_axis + 1 :]
8189 indices = data .draw (
8290 hh .arrays (
8391 shape = idx_shape ,
8492 dtype = dh .default_int ,
85- elements = {"min_value" : 0 , "max_value" : x .shape [axis ]- 1 }
93+ elements = {"min_value" : 0 , "max_value" : x .shape [n_axis ]- 1 }
8694 ),
8795 label = "indices"
8896 )
8997 note (f"{ indices = } { idx_shape = } " )
9098
91- out = xp .take_along_axis (x , indices , axis = axis )
99+ out = xp .take_along_axis (x , indices , ** axis_kw )
92100
93101 ph .assert_dtype ("take_along_axis" , in_dtype = x .dtype , out_dtype = out .dtype )
94102 ph .assert_shape (
0 commit comments