Skip to content
This repository was archived by the owner on Nov 23, 2017. It is now read-only.

Commit d6dcf25

Browse files
committed
Fix callbacks race in SelectorLoop.sock_connect.
1 parent 6f8f833 commit d6dcf25

File tree

2 files changed

+111
-40
lines changed

2 files changed

+111
-40
lines changed

asyncio/selector_events.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ def _sock_sendall(self, fut, registered, sock, data):
400400
data = data[n:]
401401
self.add_writer(fd, self._sock_sendall, fut, True, sock, data)
402402

403+
@coroutine
403404
def sock_connect(self, sock, address):
404405
"""Connect to a remote socket at address.
405406
@@ -408,24 +409,16 @@ def sock_connect(self, sock, address):
408409
if self._debug and sock.gettimeout() != 0:
409410
raise ValueError("the socket must be non-blocking")
410411

411-
fut = self.create_future()
412-
if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
413-
self._sock_connect(fut, sock, address)
414-
else:
412+
if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
415413
resolved = base_events._ensure_resolved(
416414
address, family=sock.family, proto=sock.proto, loop=self)
417-
resolved.add_done_callback(
418-
lambda resolved: self._on_resolved(fut, sock, resolved))
419-
420-
return fut
421-
422-
def _on_resolved(self, fut, sock, resolved):
423-
try:
415+
if not resolved.done():
416+
yield from resolved
424417
_, _, _, _, address = resolved.result()[0]
425-
except Exception as exc:
426-
fut.set_exception(exc)
427-
else:
428-
self._sock_connect(fut, sock, address)
418+
419+
fut = self.create_future()
420+
self._sock_connect(fut, sock, address)
421+
return (yield from fut)
429422

430423
def _sock_connect(self, fut, sock, address):
431424
fd = sock.fileno()
@@ -436,8 +429,8 @@ def _sock_connect(self, fut, sock, address):
436429
# connection runs in background. We have to wait until the socket
437430
# becomes writable to be notified when the connection succeed or
438431
# fails.
439-
fut.add_done_callback(functools.partial(self._sock_connect_done,
440-
fd))
432+
fut.add_done_callback(
433+
functools.partial(self._sock_connect_done, fd))
441434
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
442435
except Exception as exc:
443436
fut.set_exception(exc)

tests/test_selector_events.py

Lines changed: 101 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import errno
44
import socket
5+
import threading
6+
import time
57
import unittest
68
from unittest import mock
79
try:
@@ -337,18 +339,6 @@ def test__sock_sendall_none(self):
337339
(10, self.loop._sock_sendall, f, True, sock, b'data'),
338340
self.loop.add_writer.call_args[0])
339341

340-
def test_sock_connect(self):
341-
sock = test_utils.mock_nonblocking_socket()
342-
self.loop._sock_connect = mock.Mock()
343-
344-
f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
345-
self.assertIsInstance(f, asyncio.Future)
346-
self.loop._run_once()
347-
future_in, sock_in, address_in = self.loop._sock_connect.call_args[0]
348-
self.assertEqual(future_in, f)
349-
self.assertEqual(sock_in, sock)
350-
self.assertEqual(address_in, ('127.0.0.1', 8080))
351-
352342
def test_sock_connect_timeout(self):
353343
# asyncio issue #205: sock_connect() must unregister the socket on
354344
# timeout error
@@ -360,29 +350,34 @@ def test_sock_connect_timeout(self):
360350
sock.connect.side_effect = BlockingIOError
361351

362352
# first call to sock_connect() registers the socket
363-
fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
353+
fut = self.loop.create_task(
354+
self.loop.sock_connect(sock, ('127.0.0.1', 80)))
364355
self.loop._run_once()
365356
self.assertTrue(sock.connect.called)
366357
self.assertTrue(self.loop.add_writer.called)
367-
self.assertEqual(len(fut._callbacks), 1)
368358

369359
# on timeout, the socket must be unregistered
370360
sock.connect.reset_mock()
371-
fut.set_exception(asyncio.TimeoutError)
372-
with self.assertRaises(asyncio.TimeoutError):
361+
fut.cancel()
362+
with self.assertRaises(asyncio.CancelledError):
373363
self.loop.run_until_complete(fut)
374364
self.assertTrue(self.loop.remove_writer.called)
375365

376-
def test_sock_connect_resolve_using_socket_params(self):
366+
@mock.patch('socket.getaddrinfo')
367+
def test_sock_connect_resolve_using_socket_params(self, m_gai):
377368
addr = ('need-resolution.com', 8080)
378369
sock = test_utils.mock_nonblocking_socket()
379-
self.loop.getaddrinfo = mock.Mock()
380-
self.loop.sock_connect(sock, addr)
381-
while not self.loop.getaddrinfo.called:
370+
m_gai.side_effect = (None, None, None, None, ('127.0.0.1', 0))
371+
m_gai._is_coroutine = False
372+
con = self.loop.create_task(self.loop.sock_connect(sock, addr))
373+
while not m_gai.called:
382374
self.loop._run_once()
383-
self.loop.getaddrinfo.assert_called_with(
384-
*addr, type=sock.type, family=sock.family, proto=sock.proto,
385-
flags=0)
375+
m_gai.assert_called_with(
376+
addr[0], addr[1], sock.family, sock.type, sock.proto, 0)
377+
378+
con.cancel()
379+
with self.assertRaises(asyncio.CancelledError):
380+
self.loop.run_until_complete(con)
386381

387382
def test__sock_connect(self):
388383
f = asyncio.Future(loop=self.loop)
@@ -1792,5 +1787,88 @@ def test_fatal_error_connected(self, m_exc):
17921787
exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))
17931788

17941789

1790+
class SelectorLoopFunctionalTests(unittest.TestCase):
1791+
1792+
def setUp(self):
1793+
self.loop = asyncio.new_event_loop()
1794+
asyncio.set_event_loop(None)
1795+
1796+
def tearDown(self):
1797+
self.loop.close()
1798+
1799+
@asyncio.coroutine
1800+
def recv_all(self, sock, nbytes):
1801+
buf = b''
1802+
while len(buf) < nbytes:
1803+
buf += yield from self.loop.sock_recv(sock, nbytes - len(buf))
1804+
return buf
1805+
1806+
def test_sock_connect_sock_write_race(self):
1807+
TIMEOUT = 3.0
1808+
PAYLOAD = b'DATA' * 1024 * 1024
1809+
1810+
class Server(threading.Thread):
1811+
def __init__(self, *args, srv_sock, **kwargs):
1812+
super().__init__(*args, **kwargs)
1813+
self.srv_sock = srv_sock
1814+
1815+
def run(self):
1816+
with self.srv_sock:
1817+
srv_sock.listen(100)
1818+
1819+
sock, addr = self.srv_sock.accept()
1820+
sock.settimeout(TIMEOUT)
1821+
1822+
with sock:
1823+
sock.sendall(b'helo')
1824+
1825+
buf = bytearray()
1826+
while len(buf) < len(PAYLOAD):
1827+
pack = sock.recv(1024 * 65)
1828+
if not pack:
1829+
break
1830+
buf.extend(pack)
1831+
1832+
@asyncio.coroutine
1833+
def client(addr):
1834+
sock = socket.socket()
1835+
with sock:
1836+
sock.setblocking(False)
1837+
1838+
started = time.monotonic()
1839+
while True:
1840+
if time.monotonic() - started > TIMEOUT:
1841+
self.fail('unable to connect to the socket')
1842+
return
1843+
try:
1844+
yield from self.loop.sock_connect(sock, addr)
1845+
except OSError:
1846+
yield from asyncio.sleep(0.05, loop=self.loop)
1847+
else:
1848+
break
1849+
1850+
# Give 'Server' thread a chance to accept and send b'helo'
1851+
time.sleep(0.1)
1852+
1853+
data = yield from self.recv_all(sock, 4)
1854+
self.assertEqual(data, b'helo')
1855+
yield from self.loop.sock_sendall(sock, PAYLOAD)
1856+
1857+
srv_sock = socket.socket()
1858+
srv_sock.settimeout(TIMEOUT)
1859+
srv_sock.bind(('127.0.0.1', 0))
1860+
srv_addr = srv_sock.getsockname()
1861+
1862+
srv = Server(srv_sock=srv_sock, daemon=True)
1863+
srv.start()
1864+
1865+
try:
1866+
self.loop.run_until_complete(
1867+
asyncio.wait_for(client(srv_addr), loop=self.loop,
1868+
timeout=TIMEOUT))
1869+
finally:
1870+
srv.join()
1871+
1872+
17951873
if __name__ == '__main__':
17961874
unittest.main()

0 commit comments

Comments
 (0)