@@ -95,19 +95,19 @@ def test_mutator_descriptors():
9595 with pytest .raises (TypeError ) as excinfo :
9696 m .fixed_mutator_r (zc )
9797 assert (
98- '(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[5, 6]",'
98+ '(arg0: typing.Annotated[numpy.typing.NDArray[ numpy.float32] , "[5, 6]",'
9999 ' "flags.writeable", "flags.c_contiguous"]) -> None' in str (excinfo .value )
100100 )
101101 with pytest .raises (TypeError ) as excinfo :
102102 m .fixed_mutator_c (zr )
103103 assert (
104- '(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[5, 6]",'
104+ '(arg0: typing.Annotated[numpy.typing.NDArray[ numpy.float32] , "[5, 6]",'
105105 ' "flags.writeable", "flags.f_contiguous"]) -> None' in str (excinfo .value )
106106 )
107107 with pytest .raises (TypeError ) as excinfo :
108108 m .fixed_mutator_a (np .array ([[1 , 2 ], [3 , 4 ]], dtype = "float32" ))
109109 assert (
110- '(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[5, 6]", "flags.writeable"]) -> None'
110+ '(arg0: typing.Annotated[numpy.typing.NDArray[ numpy.float32] , "[5, 6]", "flags.writeable"]) -> None'
111111 in str (excinfo .value )
112112 )
113113 zr .flags .writeable = False
@@ -202,7 +202,7 @@ def test_negative_stride_from_python(msg):
202202 msg (excinfo .value )
203203 == """
204204 double_threer(): incompatible function arguments. The following argument types are supported:
205- 1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[1, 3]", "flags.writeable"]) -> None
205+ 1. (arg0: typing.Annotated[numpy.typing.NDArray[ numpy.float32] , "[1, 3]", "flags.writeable"]) -> None
206206
207207 Invoked with: """
208208 + repr (np .array ([5.0 , 4.0 , 3.0 ], dtype = "float32" ))
@@ -214,7 +214,7 @@ def test_negative_stride_from_python(msg):
214214 msg (excinfo .value )
215215 == """
216216 double_threec(): incompatible function arguments. The following argument types are supported:
217- 1. (arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[3, 1]", "flags.writeable"]) -> None
217+ 1. (arg0: typing.Annotated[numpy.typing.NDArray[ numpy.float32] , "[3, 1]", "flags.writeable"]) -> None
218218
219219 Invoked with: """
220220 + repr (np .array ([7.0 , 4.0 , 1.0 ], dtype = "float32" ))
@@ -818,3 +818,22 @@ def test_custom_operator_new():
818818 o = m .CustomOperatorNew ()
819819 np .testing .assert_allclose (o .a , 0.0 )
820820 np .testing .assert_allclose (o .b .diagonal (), 1.0 )
821+
822+
823+ def test_arraylike_signature (doc ):
824+ assert doc (m .round_trip_vector ) == (
825+ 'round_trip_vector(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, 1]"])'
826+ ' -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, 1]"]'
827+ )
828+ assert doc (m .round_trip_dense ) == (
829+ 'round_trip_dense(arg0: typing.Annotated[numpy.typing.ArrayLike, numpy.float32, "[m, n]"])'
830+ ' -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]"]'
831+ )
832+ assert doc (m .round_trip_dense_ref ) == (
833+ 'round_trip_dense_ref(arg0: typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", "flags.c_contiguous"])'
834+ ' -> typing.Annotated[numpy.typing.NDArray[numpy.float32], "[m, n]", "flags.writeable", "flags.c_contiguous"]'
835+ )
836+ m .round_trip_vector ([1.0 , 2.0 ])
837+ m .round_trip_dense ([[1.0 , 2.0 ], [3.0 , 4.0 ]])
838+ with pytest .raises (TypeError , match = "incompatible function arguments" ):
839+ m .round_trip_dense_ref ([[1.0 , 2.0 ], [3.0 , 4.0 ]])
0 commit comments