2
2
3
3
import errno
4
4
import socket
5
+ import threading
6
+ import time
5
7
import unittest
6
8
from unittest import mock
7
9
try :
@@ -337,18 +339,6 @@ def test__sock_sendall_none(self):
337
339
(10 , self .loop ._sock_sendall , f , True , sock , b'data' ),
338
340
self .loop .add_writer .call_args [0 ])
339
341
340
- def test_sock_connect (self ):
341
- sock = test_utils .mock_nonblocking_socket ()
342
- self .loop ._sock_connect = mock .Mock ()
343
-
344
- f = self .loop .sock_connect (sock , ('127.0.0.1' , 8080 ))
345
- self .assertIsInstance (f , asyncio .Future )
346
- self .loop ._run_once ()
347
- future_in , sock_in , address_in = self .loop ._sock_connect .call_args [0 ]
348
- self .assertEqual (future_in , f )
349
- self .assertEqual (sock_in , sock )
350
- self .assertEqual (address_in , ('127.0.0.1' , 8080 ))
351
-
352
342
def test_sock_connect_timeout (self ):
353
343
# asyncio issue #205: sock_connect() must unregister the socket on
354
344
# timeout error
@@ -360,29 +350,34 @@ def test_sock_connect_timeout(self):
360
350
sock .connect .side_effect = BlockingIOError
361
351
362
352
# first call to sock_connect() registers the socket
363
- fut = self .loop .sock_connect (sock , ('127.0.0.1' , 80 ))
353
+ fut = self .loop .create_task (
354
+ self .loop .sock_connect (sock , ('127.0.0.1' , 80 )))
364
355
self .loop ._run_once ()
365
356
self .assertTrue (sock .connect .called )
366
357
self .assertTrue (self .loop .add_writer .called )
367
- self .assertEqual (len (fut ._callbacks ), 1 )
368
358
369
359
# on timeout, the socket must be unregistered
370
360
sock .connect .reset_mock ()
371
- fut .set_exception ( asyncio . TimeoutError )
372
- with self .assertRaises (asyncio .TimeoutError ):
361
+ fut .cancel ( )
362
+ with self .assertRaises (asyncio .CancelledError ):
373
363
self .loop .run_until_complete (fut )
374
364
self .assertTrue (self .loop .remove_writer .called )
375
365
376
- def test_sock_connect_resolve_using_socket_params (self ):
366
+ @mock .patch ('socket.getaddrinfo' )
367
+ def test_sock_connect_resolve_using_socket_params (self , m_gai ):
377
368
addr = ('need-resolution.com' , 8080 )
378
369
sock = test_utils .mock_nonblocking_socket ()
379
- self .loop .getaddrinfo = mock .Mock ()
380
- self .loop .sock_connect (sock , addr )
381
- while not self .loop .getaddrinfo .called :
370
+ m_gai .side_effect = (None , None , None , None , ('127.0.0.1' , 0 ))
371
+ m_gai ._is_coroutine = False
372
+ con = self .loop .create_task (self .loop .sock_connect (sock , addr ))
373
+ while not m_gai .called :
382
374
self .loop ._run_once ()
383
- self .loop .getaddrinfo .assert_called_with (
384
- * addr , type = sock .type , family = sock .family , proto = sock .proto ,
385
- flags = 0 )
375
+ m_gai .assert_called_with (
376
+ addr [0 ], addr [1 ], sock .family , sock .type , sock .proto , 0 )
377
+
378
+ con .cancel ()
379
+ with self .assertRaises (asyncio .CancelledError ):
380
+ self .loop .run_until_complete (con )
386
381
387
382
def test__sock_connect (self ):
388
383
f = asyncio .Future (loop = self .loop )
@@ -1792,5 +1787,88 @@ def test_fatal_error_connected(self, m_exc):
1792
1787
exc_info = (ConnectionRefusedError , MOCK_ANY , MOCK_ANY ))
1793
1788
1794
1789
1790
+ class SelectorLoopFunctionalTests (unittest .TestCase ):
1791
+
1792
+ def setUp (self ):
1793
+ self .loop = asyncio .new_event_loop ()
1794
+ asyncio .set_event_loop (None )
1795
+
1796
+ def tearDown (self ):
1797
+ self .loop .close ()
1798
+
1799
+ @asyncio .coroutine
1800
+ def recv_all (self , sock , nbytes ):
1801
+ buf = b''
1802
+ while len (buf ) < nbytes :
1803
+ buf += yield from self .loop .sock_recv (sock , nbytes - len (buf ))
1804
+ return buf
1805
+
1806
+ def test_sock_connect_sock_write_race (self ):
1807
+ TIMEOUT = 3.0
1808
+ PAYLOAD = b'DATA' * 1024 * 1024
1809
+
1810
+ class Server (threading .Thread ):
1811
+ def __init__ (self , * args , srv_sock , ** kwargs ):
1812
+ super ().__init__ (* args , ** kwargs )
1813
+ self .srv_sock = srv_sock
1814
+
1815
+ def run (self ):
1816
+ with self .srv_sock :
1817
+ srv_sock .listen (100 )
1818
+
1819
+ sock , addr = self .srv_sock .accept ()
1820
+ sock .settimeout (TIMEOUT )
1821
+
1822
+ with sock :
1823
+ sock .sendall (b'helo' )
1824
+
1825
+ buf = bytearray ()
1826
+ while len (buf ) < len (PAYLOAD ):
1827
+ pack = sock .recv (1024 * 65 )
1828
+ if not pack :
1829
+ break
1830
+ buf .extend (pack )
1831
+
1832
+ @asyncio .coroutine
1833
+ def client (addr ):
1834
+ sock = socket .socket ()
1835
+ with sock :
1836
+ sock .setblocking (False )
1837
+
1838
+ started = time .monotonic ()
1839
+ while True :
1840
+ if time .monotonic () - started > TIMEOUT :
1841
+ self .fail ('unable to connect to the socket' )
1842
+ return
1843
+ try :
1844
+ yield from self .loop .sock_connect (sock , addr )
1845
+ except OSError :
1846
+ yield from asyncio .sleep (0.05 , loop = self .loop )
1847
+ else :
1848
+ break
1849
+
1850
+ # Give 'Server' thread a chance to accept and send b'helo'
1851
+ time .sleep (0.1 )
1852
+
1853
+ data = yield from self .recv_all (sock , 4 )
1854
+ self .assertEqual (data , b'helo' )
1855
+ yield from self .loop .sock_sendall (sock , PAYLOAD )
1856
+
1857
+ srv_sock = socket .socket ()
1858
+ srv_sock .settimeout (TIMEOUT )
1859
+ srv_sock .bind (('127.0.0.1' , 0 ))
1860
+ srv_addr = srv_sock .getsockname ()
1861
+
1862
+ srv = Server (srv_sock = srv_sock , daemon = True )
1863
+ srv .start ()
1864
+
1865
+ try :
1866
+ self .loop .run_until_complete (
1867
+ asyncio .wait_for (client (srv_addr ), loop = self .loop ,
1868
+ timeout = TIMEOUT ))
1869
+ finally :
1870
+ srv .join ()
1871
+
1872
+
1795
1873
if __name__ == '__main__' :
1796
1874
unittest .main ()
0 commit comments