mirror of https://github.com/MagicStack/uvloop.git
fix missing data on EOF in flushing
* when EOF is received and data is still pending in incoming buffer, the data will be lost before this fix * also removed sleep from a recent-written test
This commit is contained in:
parent
695a520195
commit
6476aad6fd
|
@ -2606,35 +2606,6 @@ class _TestSSL(tb.SSLTestCase):
|
|||
self.assertEqual(len(data), CHUNK * SIZE)
|
||||
sock.close()
|
||||
|
||||
def openssl_server(sock):
|
||||
conn = openssl_ssl.Connection(sslctx_openssl, sock)
|
||||
conn.set_accept_state()
|
||||
|
||||
while True:
|
||||
try:
|
||||
data = conn.recv(16384)
|
||||
self.assertEqual(data, b'ping')
|
||||
break
|
||||
except openssl_ssl.WantReadError:
|
||||
pass
|
||||
|
||||
# use renegotiation to queue data in peer _write_backlog
|
||||
conn.renegotiate()
|
||||
conn.send(b'pong')
|
||||
|
||||
data_size = 0
|
||||
while True:
|
||||
try:
|
||||
chunk = conn.recv(16384)
|
||||
if not chunk:
|
||||
break
|
||||
data_size += len(chunk)
|
||||
except openssl_ssl.WantReadError:
|
||||
pass
|
||||
except openssl_ssl.ZeroReturnError:
|
||||
break
|
||||
self.assertEqual(data_size, CHUNK * SIZE)
|
||||
|
||||
def run(meth):
|
||||
def wrapper(sock):
|
||||
try:
|
||||
|
@ -2652,12 +2623,18 @@ class _TestSSL(tb.SSLTestCase):
|
|||
*addr,
|
||||
ssl=client_sslctx,
|
||||
server_hostname='')
|
||||
sslprotocol = writer.get_extra_info('uvloop.sslproto')
|
||||
writer.write(b'ping')
|
||||
data = await reader.readexactly(4)
|
||||
self.assertEqual(data, b'pong')
|
||||
|
||||
sslprotocol.pause_writing()
|
||||
for _ in range(SIZE):
|
||||
writer.write(b'x' * CHUNK)
|
||||
|
||||
writer.close()
|
||||
sslprotocol.resume_writing()
|
||||
|
||||
await self.wait_closed(writer)
|
||||
try:
|
||||
data = await reader.read()
|
||||
|
@ -2669,9 +2646,6 @@ class _TestSSL(tb.SSLTestCase):
|
|||
with self.tcp_server(run(server)) as srv:
|
||||
self.loop.run_until_complete(client(srv.addr))
|
||||
|
||||
with self.tcp_server(run(openssl_server)) as srv:
|
||||
self.loop.run_until_complete(client(srv.addr))
|
||||
|
||||
def test_remote_shutdown_receives_trailing_data(self):
|
||||
if self.implementation == 'asyncio':
|
||||
raise unittest.SkipTest()
|
||||
|
@ -2892,7 +2866,13 @@ class _TestSSL(tb.SSLTestCase):
|
|||
self.assertIsNone(ctx())
|
||||
|
||||
def test_shutdown_timeout_handler_not_set(self):
|
||||
if self.implementation == 'asyncio':
|
||||
# asyncio cannot receive EOF after resume_reading()
|
||||
raise unittest.SkipTest()
|
||||
|
||||
loop = self.loop
|
||||
eof = asyncio.Event()
|
||||
extra = None
|
||||
|
||||
def server(sock):
|
||||
sslctx = self._create_server_ssl_context(self.ONLYCERT,
|
||||
|
@ -2900,12 +2880,12 @@ class _TestSSL(tb.SSLTestCase):
|
|||
sock = sslctx.wrap_socket(sock, server_side=True)
|
||||
sock.send(b'hello')
|
||||
assert sock.recv(1024) == b'world'
|
||||
time.sleep(0.1)
|
||||
sock.send(b'extra bytes' * 1)
|
||||
sock.send(b'extra bytes')
|
||||
# sending EOF here
|
||||
sock.shutdown(socket.SHUT_WR)
|
||||
loop.call_soon_threadsafe(eof.set)
|
||||
# make sure we have enough time to reproduce the issue
|
||||
time.sleep(0.1)
|
||||
assert sock.recv(1024) == b''
|
||||
sock.close()
|
||||
|
||||
class Protocol(asyncio.Protocol):
|
||||
|
@ -2917,20 +2897,28 @@ class _TestSSL(tb.SSLTestCase):
|
|||
self.transport = transport
|
||||
|
||||
def data_received(self, data):
|
||||
self.transport.write(b'world')
|
||||
# pause reading would make incoming data stay in the sslobj
|
||||
self.transport.pause_reading()
|
||||
# resume for AIO to pass
|
||||
loop.call_later(0.2, self.transport.resume_reading)
|
||||
if data == b'hello':
|
||||
self.transport.write(b'world')
|
||||
# pause reading would make incoming data stay in the sslobj
|
||||
self.transport.pause_reading()
|
||||
else:
|
||||
nonlocal extra
|
||||
extra = data
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.fut.set_result(None)
|
||||
if exc is None:
|
||||
self.fut.set_result(None)
|
||||
else:
|
||||
self.fut.set_exception(exc)
|
||||
|
||||
async def client(addr):
|
||||
ctx = self._create_client_ssl_context()
|
||||
tr, pr = await loop.create_connection(Protocol, *addr, ssl=ctx)
|
||||
await eof.wait()
|
||||
tr.resume_reading()
|
||||
await pr.fut
|
||||
tr.close()
|
||||
assert extra == b'extra bytes'
|
||||
|
||||
with self.tcp_server(server) as srv:
|
||||
loop.run_until_complete(client(srv.addr))
|
||||
|
|
|
@ -65,6 +65,7 @@ cdef class SSLProtocol:
|
|||
|
||||
bint _ssl_writing_paused
|
||||
bint _app_reading_paused
|
||||
bint _eof_received
|
||||
|
||||
size_t _incoming_high_water
|
||||
size_t _incoming_low_water
|
||||
|
|
|
@ -278,6 +278,7 @@ cdef class SSLProtocol:
|
|||
self._incoming_high_water = 0
|
||||
self._incoming_low_water = 0
|
||||
self._set_read_buffer_limits()
|
||||
self._eof_received = False
|
||||
|
||||
self._app_writing_paused = False
|
||||
self._outgoing_high_water = 0
|
||||
|
@ -391,6 +392,7 @@ cdef class SSLProtocol:
|
|||
will close itself. If it returns a true value, closing the
|
||||
transport is up to the protocol.
|
||||
"""
|
||||
self._eof_received = True
|
||||
try:
|
||||
if self._loop.get_debug():
|
||||
aio_logger.debug("%r received EOF", self)
|
||||
|
@ -400,9 +402,10 @@ cdef class SSLProtocol:
|
|||
|
||||
elif self._state == WRAPPED:
|
||||
self._set_state(FLUSHING)
|
||||
self._do_write()
|
||||
self._set_state(SHUTDOWN)
|
||||
self._do_shutdown()
|
||||
if self._app_reading_paused:
|
||||
return True
|
||||
else:
|
||||
self._do_flush()
|
||||
|
||||
elif self._state == FLUSHING:
|
||||
self._do_write()
|
||||
|
@ -412,11 +415,14 @@ cdef class SSLProtocol:
|
|||
elif self._state == SHUTDOWN:
|
||||
self._do_shutdown()
|
||||
|
||||
finally:
|
||||
except Exception:
|
||||
self._transport.close()
|
||||
raise
|
||||
|
||||
cdef _get_extra_info(self, name, default=None):
|
||||
if name in self._extra:
|
||||
if name == 'uvloop.sslproto':
|
||||
return self
|
||||
elif name in self._extra:
|
||||
return self._extra[name]
|
||||
elif self._transport is not None:
|
||||
return self._transport.get_extra_info(name, default)
|
||||
|
@ -555,33 +561,14 @@ cdef class SSLProtocol:
|
|||
aio_TimeoutError('SSL shutdown timed out'))
|
||||
|
||||
cdef _do_flush(self):
|
||||
if self._write_backlog:
|
||||
try:
|
||||
while True:
|
||||
# data is discarded when FLUSHING
|
||||
chunk_size = len(self._sslobj_read(SSL_READ_MAX_SIZE))
|
||||
if not chunk_size:
|
||||
# close_notify
|
||||
break
|
||||
except ssl_SSLAgainErrors as exc:
|
||||
pass
|
||||
except ssl_SSLError as exc:
|
||||
self._on_shutdown_complete(exc)
|
||||
return
|
||||
|
||||
try:
|
||||
self._do_write()
|
||||
except Exception as exc:
|
||||
self._on_shutdown_complete(exc)
|
||||
return
|
||||
|
||||
if not self._write_backlog:
|
||||
self._set_state(SHUTDOWN)
|
||||
self._do_shutdown()
|
||||
self._do_read()
|
||||
self._set_state(SHUTDOWN)
|
||||
self._do_shutdown()
|
||||
|
||||
cdef _do_shutdown(self):
|
||||
try:
|
||||
self._sslobj.unwrap()
|
||||
if not self._eof_received:
|
||||
self._sslobj.unwrap()
|
||||
except ssl_SSLAgainErrors as exc:
|
||||
self._process_outgoing()
|
||||
except ssl_SSLError as exc:
|
||||
|
@ -655,7 +642,7 @@ cdef class SSLProtocol:
|
|||
# Incoming flow
|
||||
|
||||
cdef _do_read(self):
|
||||
if self._state != WRAPPED:
|
||||
if self._state != WRAPPED and self._state != FLUSHING:
|
||||
return
|
||||
try:
|
||||
if not self._app_reading_paused:
|
||||
|
|
Loading…
Reference in New Issue