diff --git a/mitmproxy/net/udp.py b/mitmproxy/net/udp.py index 3590d4284..328fcba75 100644 --- a/mitmproxy/net/udp.py +++ b/mitmproxy/net/udp.py @@ -175,12 +175,8 @@ class DatagramWriter: """ self._transport = transport self._remote_addr = remote_addr - if reader is not None: - self._reader = reader - self._closed = asyncio.Event() - else: - self._reader = None - self._closed = None + self._reader = reader + self._closed = asyncio.Event() if reader is not None else None @property def _protocol(self) -> DrainableDatagramProtocol | udp_wireguard.WireGuardDatagramTransport: @@ -203,15 +199,9 @@ class DatagramWriter: self._transport.close() else: self._closed.set() - assert self._reader + if self._reader is not None: self._reader.feed_eof() - def is_closing(self) -> bool: - if self._closed is None: - return self._transport.is_closing() - else: - return self._closed.is_set() - async def wait_closed(self) -> None: if self._closed is None: await self._protocol.wait_closed() diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 1341564ad..06033e427 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -367,8 +367,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta): elif isinstance(command, commands.SendData): writer = self.transports[command.connection].writer assert writer - if not writer.is_closing(): - writer.write(command.data) + writer.write(command.data) elif isinstance(command, commands.CloseConnection): self.close_connection(command.connection, command.half_close) elif isinstance(command, commands.StartHook): @@ -394,8 +393,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta): try: writer = self.transports[connection].writer assert writer - if not writer.is_closing(): - writer.write_eof() + writer.write_eof() except OSError: # if we can't write to the socket anymore we presume it completely dead. connection.state = ConnectionState.CLOSED diff --git a/test/mitmproxy/net/test_udp.py b/test/mitmproxy/net/test_udp.py index f90538f55..6e5a9f623 100644 --- a/test/mitmproxy/net/test_udp.py +++ b/test/mitmproxy/net/test_udp.py @@ -45,16 +45,11 @@ async def test_client_server(): server.resume_writing() await server.drain() - assert not client_writer.is_closing() - assert not server_writer.is_closing() - assert await client_reader.read(MAX_DATAGRAM_SIZE) == b"msg4" client_writer.close() - assert client_writer.is_closing() await client_writer.wait_closed() server_writer.close() - assert server_writer.is_closing() await server_writer.wait_closed() server.close()