Revert "Revert "Don't write to closing streams, fix #5363 (#5589)""

This reverts commit ec5a74cd0e.
This commit is contained in:
Maximilian Hils 2022-09-22 18:29:11 +02:00
parent ec5a74cd0e
commit db9b3c21be
3 changed files with 22 additions and 5 deletions

View File

@ -175,8 +175,12 @@ class DatagramWriter:
"""
self._transport = transport
self._remote_addr = remote_addr
self._reader = reader
self._closed = asyncio.Event() if reader is not None else None
if reader is not None:
self._reader = reader
self._closed = asyncio.Event()
else:
self._reader = None
self._closed = None
@property
def _protocol(self) -> DrainableDatagramProtocol | udp_wireguard.WireGuardDatagramTransport:
@ -199,9 +203,15 @@ class DatagramWriter:
self._transport.close()
else:
self._closed.set()
if self._reader is not None:
assert self._reader
self._reader.feed_eof()
def is_closing(self) -> bool:
if self._closed is None:
return self._transport.is_closing()
else:
return self._closed.is_set()
async def wait_closed(self) -> None:
if self._closed is None:
await self._protocol.wait_closed()

View File

@ -367,7 +367,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
elif isinstance(command, commands.SendData):
writer = self.transports[command.connection].writer
assert writer
writer.write(command.data)
if not writer.is_closing():
writer.write(command.data)
elif isinstance(command, commands.CloseConnection):
self.close_connection(command.connection, command.half_close)
elif isinstance(command, commands.StartHook):
@ -393,7 +394,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
try:
writer = self.transports[connection].writer
assert writer
writer.write_eof()
if not writer.is_closing():
writer.write_eof()
except OSError:
# if we can't write to the socket anymore we presume it completely dead.
connection.state = ConnectionState.CLOSED

View File

@ -45,11 +45,16 @@ async def test_client_server():
server.resume_writing()
await server.drain()
assert not client_writer.is_closing()
assert not server_writer.is_closing()
assert await client_reader.read(MAX_DATAGRAM_SIZE) == b"msg4"
client_writer.close()
assert client_writer.is_closing()
await client_writer.wait_closed()
server_writer.close()
assert server_writer.is_closing()
await server_writer.wait_closed()
server.close()