diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 8ea6e59..6ab170a 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -2606,35 +2606,6 @@ class _TestSSL(tb.SSLTestCase): self.assertEqual(len(data), CHUNK * SIZE) sock.close() - def openssl_server(sock): - conn = openssl_ssl.Connection(sslctx_openssl, sock) - conn.set_accept_state() - - while True: - try: - data = conn.recv(16384) - self.assertEqual(data, b'ping') - break - except openssl_ssl.WantReadError: - pass - - # use renegotiation to queue data in peer _write_backlog - conn.renegotiate() - conn.send(b'pong') - - data_size = 0 - while True: - try: - chunk = conn.recv(16384) - if not chunk: - break - data_size += len(chunk) - except openssl_ssl.WantReadError: - pass - except openssl_ssl.ZeroReturnError: - break - self.assertEqual(data_size, CHUNK * SIZE) - def run(meth): def wrapper(sock): try: @@ -2652,12 +2623,18 @@ class _TestSSL(tb.SSLTestCase): *addr, ssl=client_sslctx, server_hostname='') + sslprotocol = writer.get_extra_info('uvloop.sslproto') writer.write(b'ping') data = await reader.readexactly(4) self.assertEqual(data, b'pong') + + sslprotocol.pause_writing() for _ in range(SIZE): writer.write(b'x' * CHUNK) + writer.close() + sslprotocol.resume_writing() + await self.wait_closed(writer) try: data = await reader.read() @@ -2669,9 +2646,6 @@ class _TestSSL(tb.SSLTestCase): with self.tcp_server(run(server)) as srv: self.loop.run_until_complete(client(srv.addr)) - with self.tcp_server(run(openssl_server)) as srv: - self.loop.run_until_complete(client(srv.addr)) - def test_remote_shutdown_receives_trailing_data(self): if self.implementation == 'asyncio': raise unittest.SkipTest() @@ -2892,7 +2866,13 @@ class _TestSSL(tb.SSLTestCase): self.assertIsNone(ctx()) def test_shutdown_timeout_handler_not_set(self): + if self.implementation == 'asyncio': + # asyncio cannot receive EOF after resume_reading() + raise unittest.SkipTest() + loop = self.loop + eof = asyncio.Event() + extra = None def server(sock): sslctx = self._create_server_ssl_context(self.ONLYCERT, @@ -2900,12 +2880,12 @@ class _TestSSL(tb.SSLTestCase): sock = sslctx.wrap_socket(sock, server_side=True) sock.send(b'hello') assert sock.recv(1024) == b'world' - time.sleep(0.1) - sock.send(b'extra bytes' * 1) + sock.send(b'extra bytes') # sending EOF here sock.shutdown(socket.SHUT_WR) + loop.call_soon_threadsafe(eof.set) # make sure we have enough time to reproduce the issue - time.sleep(0.1) + assert sock.recv(1024) == b'' sock.close() class Protocol(asyncio.Protocol): @@ -2917,20 +2897,28 @@ class _TestSSL(tb.SSLTestCase): self.transport = transport def data_received(self, data): - self.transport.write(b'world') - # pause reading would make incoming data stay in the sslobj - self.transport.pause_reading() - # resume for AIO to pass - loop.call_later(0.2, self.transport.resume_reading) + if data == b'hello': + self.transport.write(b'world') + # pause reading would make incoming data stay in the sslobj + self.transport.pause_reading() + else: + nonlocal extra + extra = data def connection_lost(self, exc): - self.fut.set_result(None) + if exc is None: + self.fut.set_result(None) + else: + self.fut.set_exception(exc) async def client(addr): ctx = self._create_client_ssl_context() tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx) + await eof.wait() + tr.resume_reading() await pr.fut tr.close() + assert extra == b'extra bytes' with self.tcp_server(server) as srv: loop.run_until_complete(client(srv.addr)) diff --git a/uvloop/sslproto.pxd b/uvloop/sslproto.pxd index c29af7b..bc94bfd 100644 --- a/uvloop/sslproto.pxd +++ b/uvloop/sslproto.pxd @@ -65,6 +65,7 @@ cdef class SSLProtocol: bint _ssl_writing_paused bint _app_reading_paused + bint _eof_received size_t _incoming_high_water size_t _incoming_low_water diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index c5b9c3a..1a52e71 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -278,6 +278,7 @@ cdef class SSLProtocol: self._incoming_high_water = 0 self._incoming_low_water = 0 self._set_read_buffer_limits() + self._eof_received = False self._app_writing_paused = False self._outgoing_high_water = 0 @@ -391,6 +392,7 @@ cdef class SSLProtocol: will close itself. If it returns a true value, closing the transport is up to the protocol. """ + self._eof_received = True try: if self._loop.get_debug(): aio_logger.debug("%r received EOF", self) @@ -400,9 +402,10 @@ cdef class SSLProtocol: elif self._state == WRAPPED: self._set_state(FLUSHING) - self._do_write() - self._set_state(SHUTDOWN) - self._do_shutdown() + if self._app_reading_paused: + return True + else: + self._do_flush() elif self._state == FLUSHING: self._do_write() @@ -412,11 +415,14 @@ cdef class SSLProtocol: elif self._state == SHUTDOWN: self._do_shutdown() - finally: + except Exception: self._transport.close() + raise cdef _get_extra_info(self, name, default=None): - if name in self._extra: + if name == 'uvloop.sslproto': + return self + elif name in self._extra: return self._extra[name] elif self._transport is not None: return self._transport.get_extra_info(name, default) @@ -555,33 +561,14 @@ cdef class SSLProtocol: aio_TimeoutError('SSL shutdown timed out')) cdef _do_flush(self): - if self._write_backlog: - try: - while True: - # data is discarded when FLUSHING - chunk_size = len(self._sslobj_read(SSL_READ_MAX_SIZE)) - if not chunk_size: - # close_notify - break - except ssl_SSLAgainErrors as exc: - pass - except ssl_SSLError as exc: - self._on_shutdown_complete(exc) - return - - try: - self._do_write() - except Exception as exc: - self._on_shutdown_complete(exc) - return - - if not self._write_backlog: - self._set_state(SHUTDOWN) - self._do_shutdown() + self._do_read() + self._set_state(SHUTDOWN) + self._do_shutdown() cdef _do_shutdown(self): try: - self._sslobj.unwrap() + if not self._eof_received: + self._sslobj.unwrap() except ssl_SSLAgainErrors as exc: self._process_outgoing() except ssl_SSLError as exc: @@ -655,7 +642,7 @@ cdef class SSLProtocol: # Incoming flow cdef _do_read(self): - if self._state != WRAPPED: + if self._state != WRAPPED and self._state != FLUSHING: return try: if not self._app_reading_paused: