diff --git a/mitmproxy/net/udp.py b/mitmproxy/net/udp.py index 328fcba75..3590d4284 100644 --- a/mitmproxy/net/udp.py +++ b/mitmproxy/net/udp.py @@ -175,8 +175,12 @@ class DatagramWriter: """ self._transport = transport self._remote_addr = remote_addr - self._reader = reader - self._closed = asyncio.Event() if reader is not None else None + if reader is not None: + self._reader = reader + self._closed = asyncio.Event() + else: + self._reader = None + self._closed = None @property def _protocol(self) -> DrainableDatagramProtocol | udp_wireguard.WireGuardDatagramTransport: @@ -199,9 +203,15 @@ class DatagramWriter: self._transport.close() else: self._closed.set() - if self._reader is not None: + assert self._reader self._reader.feed_eof() + def is_closing(self) -> bool: + if self._closed is None: + return self._transport.is_closing() + else: + return self._closed.is_set() + async def wait_closed(self) -> None: if self._closed is None: await self._protocol.wait_closed() diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 06033e427..1341564ad 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -367,7 +367,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta): elif isinstance(command, commands.SendData): writer = self.transports[command.connection].writer assert writer - writer.write(command.data) + if not writer.is_closing(): + writer.write(command.data) elif isinstance(command, commands.CloseConnection): self.close_connection(command.connection, command.half_close) elif isinstance(command, commands.StartHook): @@ -393,7 +394,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta): try: writer = self.transports[connection].writer assert writer - writer.write_eof() + if not writer.is_closing(): + writer.write_eof() except OSError: # if we can't write to the socket anymore we presume it completely dead. connection.state = ConnectionState.CLOSED diff --git a/test/mitmproxy/net/test_udp.py b/test/mitmproxy/net/test_udp.py index 6e5a9f623..f90538f55 100644 --- a/test/mitmproxy/net/test_udp.py +++ b/test/mitmproxy/net/test_udp.py @@ -45,11 +45,16 @@ async def test_client_server(): server.resume_writing() await server.drain() + assert not client_writer.is_closing() + assert not server_writer.is_closing() + assert await client_reader.read(MAX_DATAGRAM_SIZE) == b"msg4" client_writer.close() + assert client_writer.is_closing() await client_writer.wait_closed() server_writer.close() + assert server_writer.is_closing() await server_writer.wait_closed() server.close()