* don't write to closing streams, fix #5363 * tests++
This commit is contained in:
parent
d3fb9f4349
commit
a1ddbcad53
|
@ -175,8 +175,12 @@ class DatagramWriter:
|
|||
"""
|
||||
self._transport = transport
|
||||
self._remote_addr = remote_addr
|
||||
self._reader = reader
|
||||
self._closed = asyncio.Event() if reader is not None else None
|
||||
if reader is not None:
|
||||
self._reader = reader
|
||||
self._closed = asyncio.Event()
|
||||
else:
|
||||
self._reader = None
|
||||
self._closed = None
|
||||
|
||||
@property
|
||||
def _protocol(self) -> DrainableDatagramProtocol | udp_wireguard.WireGuardDatagramTransport:
|
||||
|
@ -199,9 +203,15 @@ class DatagramWriter:
|
|||
self._transport.close()
|
||||
else:
|
||||
self._closed.set()
|
||||
if self._reader is not None:
|
||||
assert self._reader
|
||||
self._reader.feed_eof()
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
if self._closed is None:
|
||||
return self._transport.is_closing()
|
||||
else:
|
||||
return self._closed.is_set()
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
if self._closed is None:
|
||||
await self._protocol.wait_closed()
|
||||
|
|
|
@ -367,7 +367,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
elif isinstance(command, commands.SendData):
|
||||
writer = self.transports[command.connection].writer
|
||||
assert writer
|
||||
writer.write(command.data)
|
||||
if not writer.is_closing():
|
||||
writer.write(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
self.close_connection(command.connection, command.half_close)
|
||||
elif isinstance(command, commands.StartHook):
|
||||
|
@ -393,7 +394,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
try:
|
||||
writer = self.transports[connection].writer
|
||||
assert writer
|
||||
writer.write_eof()
|
||||
if not writer.is_closing():
|
||||
writer.write_eof()
|
||||
except OSError:
|
||||
# if we can't write to the socket anymore we presume it completely dead.
|
||||
connection.state = ConnectionState.CLOSED
|
||||
|
|
|
@ -45,11 +45,16 @@ async def test_client_server():
|
|||
server.resume_writing()
|
||||
await server.drain()
|
||||
|
||||
assert not client_writer.is_closing()
|
||||
assert not server_writer.is_closing()
|
||||
|
||||
assert await client_reader.read(MAX_DATAGRAM_SIZE) == b"msg4"
|
||||
client_writer.close()
|
||||
assert client_writer.is_closing()
|
||||
await client_writer.wait_closed()
|
||||
|
||||
server_writer.close()
|
||||
assert server_writer.is_closing()
|
||||
await server_writer.wait_closed()
|
||||
|
||||
server.close()
|
||||
|
|
Loading…
Reference in New Issue