88proper equality and inequality comparisons, including type safety.
99"""
1010
11+ import pytest
1112from cuda .core .experimental import Device , Stream
1213
1314# ============================================================================
@@ -21,8 +22,7 @@ def test_stream_equality_same_handle(init_cuda):
2122 s1 = device .create_stream ()
2223 s2 = Stream .from_handle (int (s1 .handle ))
2324
24- assert s1 == s2 , "Streams with same handle should be equal"
25- assert not (s1 != s2 ), "Equal streams should not be not-equal"
25+ assert s1 == s2 , "Equal streams should be equal"
2626
2727
2828def test_stream_inequality_different_handles (init_cuda ):
@@ -32,7 +32,6 @@ def test_stream_inequality_different_handles(init_cuda):
3232 s2 = device .create_stream ()
3333
3434 assert s1 != s2 , "Different streams should not be equal"
35- assert not (s1 == s2 ), "Different streams should not be equal"
3635
3736
3837def test_stream_equality_reflexive (init_cuda ):
@@ -58,7 +57,7 @@ def test_stream_type_safety(init_cuda):
5857 # These should not raise exceptions
5958 assert (stream == "not a stream" ) is False
6059 assert (stream == 123 ) is False
61- assert (stream == None ) is False
60+ assert (stream is None ) is False
6261 assert (stream == Device ()) is False
6362
6463
@@ -70,7 +69,7 @@ def test_stream_not_equal_operator(init_cuda):
7069 s3 = Stream .from_handle (int (s1 .handle ))
7170
7271 assert s1 != s2 , "Different streams should be not-equal"
73- assert not ( s1 != s3 ) , "Same handle streams should not be not- equal"
72+ assert s1 == s3 , "Same handle streams should be equal"
7473
7574
7675# ============================================================================
@@ -96,7 +95,6 @@ def test_event_inequality_different_events(init_cuda):
9695 e2 = stream .record ()
9796
9897 assert e1 != e2 , "Different events should not be equal"
99- assert not (e1 == e2 ), "Different events should not be equal"
10098
10199
102100def test_event_type_safety (init_cuda ):
@@ -107,7 +105,7 @@ def test_event_type_safety(init_cuda):
107105
108106 assert (event == "not an event" ) is False
109107 assert (event == 123 ) is False
110- assert (event == None ) is False
108+ assert (event is None ) is False
111109
112110
113111# ============================================================================
@@ -145,7 +143,7 @@ def test_context_type_safety(init_cuda):
145143
146144 assert (context == "not a context" ) is False
147145 assert (context == 123 ) is False
148- assert (context == None ) is False
146+ assert (context is None ) is False
149147
150148
151149# ============================================================================
@@ -172,14 +170,13 @@ def test_device_equality_reflexive(init_cuda):
172170def test_device_inequality_different_id (init_cuda ):
173171 """Devices with different device_id should not be equal."""
174172 try :
175- # Only runs on when two devices are available
176173 dev0 = Device (0 )
177174 dev1 = Device (1 )
178175
179176 assert dev0 != dev1 , "Different devices should not be equal"
180- assert not ( dev0 == dev1 ) , "Different devices should not be equal"
177+ assert dev0 != dev1 , "Different devices should be not- equal"
181178 except (ValueError , Exception ):
182- pass
179+ pytest . skip ( "Test requires at least 2 CUDA devices" )
183180
184181
185182def test_device_type_safety (init_cuda ):
@@ -188,7 +185,7 @@ def test_device_type_safety(init_cuda):
188185
189186 assert (device == "not a device" ) is False
190187 assert (device == 123 ) is False
191- assert (device == None ) is False
188+ assert (device is None ) is False
192189
193190
194191# ============================================================================
0 commit comments