fix missing data on EOF in flushing

* when EOF is received and data is still pending in incoming buffer,
  the data will be lost before this fix
* also removed sleep from a recent-written test
This commit is contained in:
Fantix King 2019-10-25 17:28:27 -05:00 committed by Yury Selivanov
parent 695a520195
commit 6476aad6fd
3 changed files with 47 additions and 71 deletions

View File

@ -2606,35 +2606,6 @@ class _TestSSL(tb.SSLTestCase):
self.assertEqual(len(data), CHUNK * SIZE)
sock.close()
def openssl_server(sock):
conn = openssl_ssl.Connection(sslctx_openssl, sock)
conn.set_accept_state()
while True:
try:
data = conn.recv(16384)
self.assertEqual(data, b'ping')
break
except openssl_ssl.WantReadError:
pass
# use renegotiation to queue data in peer _write_backlog
conn.renegotiate()
conn.send(b'pong')
data_size = 0
while True:
try:
chunk = conn.recv(16384)
if not chunk:
break
data_size += len(chunk)
except openssl_ssl.WantReadError:
pass
except openssl_ssl.ZeroReturnError:
break
self.assertEqual(data_size, CHUNK * SIZE)
def run(meth):
def wrapper(sock):
try:
@ -2652,12 +2623,18 @@ class _TestSSL(tb.SSLTestCase):
*addr,
ssl=client_sslctx,
server_hostname='')
sslprotocol = writer.get_extra_info('uvloop.sslproto')
writer.write(b'ping')
data = await reader.readexactly(4)
self.assertEqual(data, b'pong')
sslprotocol.pause_writing()
for _ in range(SIZE):
writer.write(b'x' * CHUNK)
writer.close()
sslprotocol.resume_writing()
await self.wait_closed(writer)
try:
data = await reader.read()
@ -2669,9 +2646,6 @@ class _TestSSL(tb.SSLTestCase):
with self.tcp_server(run(server)) as srv:
self.loop.run_until_complete(client(srv.addr))
with self.tcp_server(run(openssl_server)) as srv:
self.loop.run_until_complete(client(srv.addr))
def test_remote_shutdown_receives_trailing_data(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()
@ -2892,7 +2866,13 @@ class _TestSSL(tb.SSLTestCase):
self.assertIsNone(ctx())
def test_shutdown_timeout_handler_not_set(self):
if self.implementation == 'asyncio':
# asyncio cannot receive EOF after resume_reading()
raise unittest.SkipTest()
loop = self.loop
eof = asyncio.Event()
extra = None
def server(sock):
sslctx = self._create_server_ssl_context(self.ONLYCERT,
@ -2900,12 +2880,12 @@ class _TestSSL(tb.SSLTestCase):
sock = sslctx.wrap_socket(sock, server_side=True)
sock.send(b'hello')
assert sock.recv(1024) == b'world'
time.sleep(0.1)
sock.send(b'extra bytes' * 1)
sock.send(b'extra bytes')
# sending EOF here
sock.shutdown(socket.SHUT_WR)
loop.call_soon_threadsafe(eof.set)
# make sure we have enough time to reproduce the issue
time.sleep(0.1)
assert sock.recv(1024) == b''
sock.close()
class Protocol(asyncio.Protocol):
@ -2917,20 +2897,28 @@ class _TestSSL(tb.SSLTestCase):
self.transport = transport
def data_received(self, data):
self.transport.write(b'world')
# pause reading would make incoming data stay in the sslobj
self.transport.pause_reading()
# resume for AIO to pass
loop.call_later(0.2, self.transport.resume_reading)
if data == b'hello':
self.transport.write(b'world')
# pause reading would make incoming data stay in the sslobj
self.transport.pause_reading()
else:
nonlocal extra
extra = data
def connection_lost(self, exc):
self.fut.set_result(None)
if exc is None:
self.fut.set_result(None)
else:
self.fut.set_exception(exc)
async def client(addr):
ctx = self._create_client_ssl_context()
tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
await eof.wait()
tr.resume_reading()
await pr.fut
tr.close()
assert extra == b'extra bytes'
with self.tcp_server(server) as srv:
loop.run_until_complete(client(srv.addr))

View File

@ -65,6 +65,7 @@ cdef class SSLProtocol:
bint _ssl_writing_paused
bint _app_reading_paused
bint _eof_received
size_t _incoming_high_water
size_t _incoming_low_water

View File

@ -278,6 +278,7 @@ cdef class SSLProtocol:
self._incoming_high_water = 0
self._incoming_low_water = 0
self._set_read_buffer_limits()
self._eof_received = False
self._app_writing_paused = False
self._outgoing_high_water = 0
@ -391,6 +392,7 @@ cdef class SSLProtocol:
will close itself. If it returns a true value, closing the
transport is up to the protocol.
"""
self._eof_received = True
try:
if self._loop.get_debug():
aio_logger.debug("%r received EOF", self)
@ -400,9 +402,10 @@ cdef class SSLProtocol:
elif self._state == WRAPPED:
self._set_state(FLUSHING)
self._do_write()
self._set_state(SHUTDOWN)
self._do_shutdown()
if self._app_reading_paused:
return True
else:
self._do_flush()
elif self._state == FLUSHING:
self._do_write()
@ -412,11 +415,14 @@ cdef class SSLProtocol:
elif self._state == SHUTDOWN:
self._do_shutdown()
finally:
except Exception:
self._transport.close()
raise
cdef _get_extra_info(self, name, default=None):
if name in self._extra:
if name == 'uvloop.sslproto':
return self
elif name in self._extra:
return self._extra[name]
elif self._transport is not None:
return self._transport.get_extra_info(name, default)
@ -555,33 +561,14 @@ cdef class SSLProtocol:
aio_TimeoutError('SSL shutdown timed out'))
cdef _do_flush(self):
if self._write_backlog:
try:
while True:
# data is discarded when FLUSHING
chunk_size = len(self._sslobj_read(SSL_READ_MAX_SIZE))
if not chunk_size:
# close_notify
break
except ssl_SSLAgainErrors as exc:
pass
except ssl_SSLError as exc:
self._on_shutdown_complete(exc)
return
try:
self._do_write()
except Exception as exc:
self._on_shutdown_complete(exc)
return
if not self._write_backlog:
self._set_state(SHUTDOWN)
self._do_shutdown()
self._do_read()
self._set_state(SHUTDOWN)
self._do_shutdown()
cdef _do_shutdown(self):
try:
self._sslobj.unwrap()
if not self._eof_received:
self._sslobj.unwrap()
except ssl_SSLAgainErrors as exc:
self._process_outgoing()
except ssl_SSLError as exc:
@ -655,7 +642,7 @@ cdef class SSLProtocol:
# Incoming flow
cdef _do_read(self):
if self._state != WRAPPED:
if self._state != WRAPPED and self._state != FLUSHING:
return
try:
if not self._app_reading_paused: