diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 6fe1c1ac2..3b58d7537 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -321,7 +321,7 @@ def loopback_server_factory(socket, version=SSLv23_METHOD): return server -def loopback(server_factory=None, client_factory=None): +def loopback(server_factory=None, client_factory=None, blocking=True): """ Create a connected socket pair and force two connected SSL sockets to talk to each other via memory BIOs. @@ -337,8 +337,8 @@ def loopback(server_factory=None, client_factory=None): handshake(client, server) - server.setblocking(True) - client.setblocking(True) + server.setblocking(blocking) + client.setblocking(blocking) return server, client @@ -3292,11 +3292,134 @@ def test_memoryview_really_doesnt_overfill(self): self._doesnt_overfill_test(_make_memoryview) +@pytest.fixture +def nonblocking_tls_connections_pair(): + """Return a non-blocking TLS loopback connections pair.""" + return loopback(blocking=False) + + +@pytest.fixture +def nonblocking_tls_server_connection(nonblocking_tls_connections_pair): + """Return a non-blocking TLS server socket connected to loopback.""" + return nonblocking_tls_connections_pair[0] + + +@pytest.fixture +def nonblocking_tls_client_connection(nonblocking_tls_connections_pair): + """Return a non-blocking TLS client socket connected to loopback.""" + return nonblocking_tls_connections_pair[1] + + class TestConnectionSendall: """ Tests for `Connection.sendall`. """ + def test_want_write( + self, + monkeypatch, + nonblocking_tls_server_connection, + nonblocking_tls_client_connection, + ): + msg = b"x" + garbage_size = 1024 * 1024 * 64 + large_payload = b"p" * garbage_size * 2 + payload_size = len(large_payload) + + sent_garbage_size = 0 + try: + sent_garbage_size += nonblocking_tls_client_connection.send( + msg * garbage_size, + ) + except WantWriteError: + pass + for i in range(garbage_size): + try: + sent_garbage_size += nonblocking_tls_client_connection.send( + msg, + ) + except WantWriteError: + break + else: + pytest.fail( + "Failed to fill socket buffer, cannot test " + "'want write' in `sendall()`" + ) + garbage_payload = sent_garbage_size * msg + + def consume_garbage(conn): + assert patched_ssl_write.want_write_counter >= 1 + assert not consume_garbage.garbage_consumed + + while len(consume_garbage.consumed) < sent_garbage_size: + try: + consume_garbage.consumed += conn.recv( + sent_garbage_size - len(consume_garbage.consumed), + ) + except WantReadError: + pass + + assert consume_garbage.consumed == garbage_payload + + consume_garbage.garbage_consumed = True + + consume_garbage.garbage_consumed = False + consume_garbage.consumed = b"" + + def consume_payload(conn): + try: + consume_payload.consumed += conn.recv(payload_size) + except WantReadError: + pass + + consume_payload.consumed = b"" + + original_ssl_write = _lib.SSL_write + + def patched_ssl_write(ctx, data, size): + write_result = original_ssl_write(ctx, data, size) + try: + nonblocking_tls_client_connection._raise_ssl_error( + ctx, + write_result, + ) + except WantWriteError: + patched_ssl_write.want_write_counter += 1 + consume_data_on_server = ( + consume_payload + if consume_garbage.garbage_consumed + else consume_garbage + ) + + consume_data_on_server(nonblocking_tls_server_connection) + # NOTE: We don't re-raise it as the calling code will do + # NOTE: the same after the call. + return write_result + + patched_ssl_write.want_write_counter = 0 + + # NOTE: Make the client think it needs a handshake so that it'll + # NOTE: attempt to `do_handshake()` on the next `SSL_write()` + # NOTE: that originates from `sendall()`: + nonblocking_tls_client_connection.set_connect_state() + try: + nonblocking_tls_client_connection.do_handshake() + except WantWriteError: + assert True # Sanity check + except: + assert False # This should never happen (see the note above) + + with monkeypatch.context() as mp_ctx: + mp_ctx.setattr(_lib, "SSL_write", patched_ssl_write) + nonblocking_tls_client_connection.sendall(large_payload) + + assert consume_garbage.garbage_consumed + + # NOTE: Read the leftover data from the very last `SSL_write()` + consume_payload(nonblocking_tls_server_connection) + + assert consume_payload.consumed == large_payload + def test_wrong_args(self): """ When called with arguments other than a string argument for its first