@@ -59,95 +59,92 @@ def test_xp(self, xp: ModuleType):
5959 xp_assert_equal (actual , expected )
6060
6161
62- @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
63- @pytest .mark .parametrize (
64- ("dtype" , "b" , "defined" ),
65- [
66- # Well-defined cases of dtype promotion from Python scalar to Array
67- # bool vs. bool
68- ("bool" , True , True ),
69- # int vs. xp.*int*, xp.float*, xp.complex*
70- ("int16" , 1 , True ),
71- ("uint8" , 1 , True ),
72- ("float32" , 1 , True ),
73- ("float64" , 1 , True ),
74- ("complex64" , 1 , True ),
75- ("complex128" , 1 , True ),
76- # float vs. xp.float, xp.complex
77- ("float32" , 1.0 , True ),
78- ("float64" , 1.0 , True ),
79- ("complex64" , 1.0 , True ),
80- ("complex128" , 1.0 , True ),
81- # complex vs. xp.complex
82- ("complex64" , 1.0j , True ),
83- ("complex128" , 1.0j , True ),
84- # Undefined cases
85- ("bool" , 1 , False ),
86- ("int64" , 1.0 , False ),
87- ("float64" , 1.0j , False ),
88- ],
89- )
90- def test_asarrays_array_vs_scalar (
91- dtype : str , b : int | float | complex , defined : bool , xp : ModuleType
92- ):
93- a = xp .asarray (1 , dtype = getattr (xp , dtype ))
94-
95- xa , xb = asarrays (a , b , xp )
96- assert xa .dtype == a .dtype
97- if defined :
98- assert xb .dtype == a .dtype
99- else :
100- assert xb .dtype == xp .asarray (b ).dtype
101-
102- xbr , xar = asarrays (b , a , xp )
103- assert xar .dtype == xa .dtype
104- assert xbr .dtype == xb .dtype
105-
106-
107- def test_asarrays_scalar_vs_scalar (xp : ModuleType ):
108- a , b = asarrays (1 , 2.2 , xp = xp )
109- assert a .dtype == xp .asarray (1 ).dtype # Default dtype
110- assert b .dtype == xp .asarray (2.2 ).dtype # Default dtype; not broadcasted
111-
112-
113- ALL_TYPES = (
114- "int8" ,
115- "int16" ,
116- "int32" ,
117- "int64" ,
118- "uint8" ,
119- "uint16" ,
120- "uint32" ,
121- "uint64" ,
122- "float32" ,
123- "float64" ,
124- "complex64" ,
125- "complex128" ,
126- "bool" ,
127- )
128-
129-
130- @pytest .mark .parametrize ("a_type" , ALL_TYPES )
131- @pytest .mark .parametrize ("b_type" , ALL_TYPES )
132- def test_asarrays_array_vs_array (a_type : str , b_type : str , xp : ModuleType ):
133- """
134- Test that when both inputs of asarray are already Array API objects,
135- they are returned unchanged.
136- """
137- a = xp .asarray (1 , dtype = getattr (xp , a_type ))
138- b = xp .asarray (1 , dtype = getattr (xp , b_type ))
139- xa , xb = asarrays (a , b , xp )
140- assert xa .dtype == a .dtype
141- assert xb .dtype == b .dtype
142-
143-
144- @pytest .mark .parametrize ("dtype" , [np .float64 , np .complex128 ])
145- def test_asarrays_numpy_generics (dtype : type ):
146- """
147- Test special case of np.float64 and np.complex128,
148- which are subclasses of float and complex.
149- """
150- a = dtype (0 )
151- xa , xb = asarrays (a , 0 , xp = np )
152- assert xa .dtype == dtype
153- assert xb .dtype == dtype
62+ class TestAsArrays :
63+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
64+ @pytest .mark .parametrize (
65+ ("dtype" , "b" , "defined" ),
66+ [
67+ # Well-defined cases of dtype promotion from Python scalar to Array
68+ # bool vs. bool
69+ ("bool" , True , True ),
70+ # int vs. xp.*int*, xp.float*, xp.complex*
71+ ("int16" , 1 , True ),
72+ ("uint8" , 1 , True ),
73+ ("float32" , 1 , True ),
74+ ("float64" , 1 , True ),
75+ ("complex64" , 1 , True ),
76+ ("complex128" , 1 , True ),
77+ # float vs. xp.float, xp.complex
78+ ("float32" , 1.0 , True ),
79+ ("float64" , 1.0 , True ),
80+ ("complex64" , 1.0 , True ),
81+ ("complex128" , 1.0 , True ),
82+ # complex vs. xp.complex
83+ ("complex64" , 1.0j , True ),
84+ ("complex128" , 1.0j , True ),
85+ # Undefined cases
86+ ("bool" , 1 , False ),
87+ ("int64" , 1.0 , False ),
88+ ("float64" , 1.0j , False ),
89+ ],
90+ )
91+ def test_array_vs_scalar (
92+ self , dtype : str , b : int | float | complex , defined : bool , xp : ModuleType
93+ ):
94+ a = xp .asarray (1 , dtype = getattr (xp , dtype ))
95+
96+ xa , xb = asarrays (a , b , xp )
97+ assert xa .dtype == a .dtype
98+ if defined :
99+ assert xb .dtype == a .dtype
100+ else :
101+ assert xb .dtype == xp .asarray (b ).dtype
102+
103+ xbr , xar = asarrays (b , a , xp )
104+ assert xar .dtype == xa .dtype
105+ assert xbr .dtype == xb .dtype
106+
107+ def test_scalar_vs_scalar (self , xp : ModuleType ):
108+ a , b = asarrays (1 , 2.2 , xp = xp )
109+ assert a .dtype == xp .asarray (1 ).dtype # Default dtype
110+ assert b .dtype == xp .asarray (2.2 ).dtype # Default dtype; not broadcasted
111+
112+ ALL_TYPES : tuple [str , ...] = (
113+ "int8" ,
114+ "int16" ,
115+ "int32" ,
116+ "int64" ,
117+ "uint8" ,
118+ "uint16" ,
119+ "uint32" ,
120+ "uint64" ,
121+ "float32" ,
122+ "float64" ,
123+ "complex64" ,
124+ "complex128" ,
125+ "bool" ,
126+ )
127+
128+ @pytest .mark .parametrize ("a_type" , ALL_TYPES )
129+ @pytest .mark .parametrize ("b_type" , ALL_TYPES )
130+ def test_array_vs_array (self , a_type : str , b_type : str , xp : ModuleType ):
131+ """
132+ Test that when both inputs of asarray are already Array API objects,
133+ they are returned unchanged.
134+ """
135+ a = xp .asarray (1 , dtype = getattr (xp , a_type ))
136+ b = xp .asarray (1 , dtype = getattr (xp , b_type ))
137+ xa , xb = asarrays (a , b , xp )
138+ assert xa .dtype == a .dtype
139+ assert xb .dtype == b .dtype
140+
141+ @pytest .mark .parametrize ("dtype" , [np .float64 , np .complex128 ])
142+ def test_numpy_generics (self , dtype : type ):
143+ """
144+ Test special case of np.float64 and np.complex128,
145+ which are subclasses of float and complex.
146+ """
147+ a = dtype (0 )
148+ xa , xb = asarrays (a , 0 , xp = np )
149+ assert xa .dtype == dtype
150+ assert xb .dtype == dtype
0 commit comments