diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py index 91762cde1..73e45da5d 100644 --- a/mitmproxy/addons/dumper.py +++ b/mitmproxy/addons/dumper.py @@ -314,7 +314,7 @@ class Dumper: def format_websocket_error(self, websocket: WebSocketData) -> str: try: - ret = CloseReason(websocket.close_code).name + ret = CloseReason(websocket.close_code).name # type: ignore except ValueError: ret = f"UNKNOWN_ERROR={websocket.close_code}" if websocket.close_reason: @@ -362,8 +362,8 @@ class Dumper: desc = f"DNS {opcode} ({type})" desc_color = { - "DNS QUERY (A)": "green", - "DNS QUERY (AAAA)": "magenta", + "A": "green", + "AAAA": "magenta", }.get(type, "red") desc = self.style(desc, fg=desc_color) diff --git a/mitmproxy/command.py b/mitmproxy/command.py index 8051323e7..950fa44ef 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -5,6 +5,7 @@ import functools import inspect import logging +import pyparsing import sys import textwrap import types @@ -194,7 +195,7 @@ class CommandManager: Parse a possibly partial command. Return a sequence of ParseResults and a sequence of remainder type help items. """ - parts: list[str] = command_lexer.expr.parseString(cmdstr, parseAll=True) + parts: pyparsing.ParseResults = command_lexer.expr.parseString(cmdstr, parseAll=True) parsed: list[ParseResult] = [] next_params: list[CommandParameter] = [ diff --git a/mitmproxy/net/udp.py b/mitmproxy/net/udp.py index 328fcba75..3590d4284 100644 --- a/mitmproxy/net/udp.py +++ b/mitmproxy/net/udp.py @@ -175,8 +175,12 @@ class DatagramWriter: """ self._transport = transport self._remote_addr = remote_addr - self._reader = reader - self._closed = asyncio.Event() if reader is not None else None + if reader is not None: + self._reader = reader + self._closed = asyncio.Event() + else: + self._reader = None + self._closed = None @property def _protocol(self) -> DrainableDatagramProtocol | udp_wireguard.WireGuardDatagramTransport: @@ -199,9 +203,15 @@ class DatagramWriter: self._transport.close() else: self._closed.set() - if self._reader is not None: + assert self._reader self._reader.feed_eof() + def is_closing(self) -> bool: + if self._closed is None: + return self._transport.is_closing() + else: + return self._closed.is_set() + async def wait_closed(self) -> None: if self._closed is None: await self._protocol.wait_closed() diff --git a/mitmproxy/proxy/layers/http/_http1.py b/mitmproxy/proxy/layers/http/_http1.py index 4cab5bd9b..cae151310 100644 --- a/mitmproxy/proxy/layers/http/_http1.py +++ b/mitmproxy/proxy/layers/http/_http1.py @@ -141,7 +141,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta): def make_pipe(self) -> layer.CommandGenerator[None]: self.state = self.passthrough if self.buf: - already_received = self.buf.maybe_extract_at_most(len(self.buf)) + already_received = self.buf.maybe_extract_at_most(len(self.buf)) or b"" # Some clients send superfluous newlines after CONNECT, we want to eat those. already_received = already_received.lstrip(b"\r\n") if already_received: @@ -264,11 +264,10 @@ class Http1Server(Http1Connection): if isinstance(event, events.DataReceived): request_head = self.buf.maybe_extract_lines() if request_head: - request_head = [ - bytes(x) for x in request_head - ] # TODO: Make url.parse compatible with bytearrays try: - self.request = http1.read_request_head(request_head) + self.request = http1.read_request_head( + [bytes(x) for x in request_head] + ) if self.context.options.validate_inbound_headers: http1.validate_headers(self.request.headers) expected_body_size = http1.expected_http_body_size(self.request) @@ -388,11 +387,10 @@ class Http1Client(Http1Connection): response_head = self.buf.maybe_extract_lines() if response_head: - response_head = [ - bytes(x) for x in response_head - ] # TODO: Make url.parse compatible with bytearrays try: - self.response = http1.read_response_head(response_head) + self.response = http1.read_response_head( + [bytes(x) for x in response_head] + ) if self.context.options.validate_inbound_headers: http1.validate_headers(self.response.headers) expected_size = http1.expected_http_body_size( diff --git a/mitmproxy/proxy/layers/http/_upstream_proxy.py b/mitmproxy/proxy/layers/http/_upstream_proxy.py index a9227918d..62909bead 100644 --- a/mitmproxy/proxy/layers/http/_upstream_proxy.py +++ b/mitmproxy/proxy/layers/http/_upstream_proxy.py @@ -75,11 +75,10 @@ class HttpUpstreamProxy(tunnel.TunnelLayer): self.buf += data response_head = self.buf.maybe_extract_lines() if response_head: - response_head = [ - bytes(x) for x in response_head - ] # TODO: Make url.parse compatible with bytearrays try: - response = http1.read_response_head(response_head) + response = http1.read_response_head([ + bytes(x) for x in response_head + ]) except ValueError as e: proxyaddr = human.format_address(self.tunnel_connection.address) yield commands.Log(f"{proxyaddr}: {e}") diff --git a/mitmproxy/proxy/mode_servers.py b/mitmproxy/proxy/mode_servers.py index 1baebea34..000644022 100644 --- a/mitmproxy/proxy/mode_servers.py +++ b/mitmproxy/proxy/mode_servers.py @@ -144,10 +144,10 @@ class ServerInstance(Generic[M], metaclass=ABCMeta): ) handler.layer = self.make_top_layer(handler.layer.context) if isinstance(self.mode, mode_specs.TransparentMode): - socket = writer.get_extra_info("socket") + s = cast(socket.socket, writer.get_extra_info("socket")) try: assert platform.original_addr - original_dst = platform.original_addr(socket) + original_dst = platform.original_addr(s) except Exception as e: logger.error(f"Transparent mode failure: {e!r}") return @@ -390,6 +390,7 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]): await self.handle_tcp_connection(stream, stream) def wg_handle_udp_datagram(self, data: bytes, remote_addr: Address, local_addr: Address) -> None: + assert self._server is not None transport = WireGuardDatagramTransport(self._server, local_addr, remote_addr) self.handle_udp_datagram( transport, diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 06033e427..1341564ad 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -367,7 +367,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta): elif isinstance(command, commands.SendData): writer = self.transports[command.connection].writer assert writer - writer.write(command.data) + if not writer.is_closing(): + writer.write(command.data) elif isinstance(command, commands.CloseConnection): self.close_connection(command.connection, command.half_close) elif isinstance(command, commands.StartHook): @@ -393,7 +394,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta): try: writer = self.transports[connection].writer assert writer - writer.write_eof() + if not writer.is_closing(): + writer.write_eof() except OSError: # if we can't write to the socket anymore we presume it completely dead. connection.state = ConnectionState.CLOSED diff --git a/test/mitmproxy/net/test_udp.py b/test/mitmproxy/net/test_udp.py index 6e5a9f623..f90538f55 100644 --- a/test/mitmproxy/net/test_udp.py +++ b/test/mitmproxy/net/test_udp.py @@ -45,11 +45,16 @@ async def test_client_server(): server.resume_writing() await server.drain() + assert not client_writer.is_closing() + assert not server_writer.is_closing() + assert await client_reader.read(MAX_DATAGRAM_SIZE) == b"msg4" client_writer.close() + assert client_writer.is_closing() await client_writer.wait_closed() server_writer.close() + assert server_writer.is_closing() await server_writer.wait_closed() server.close() diff --git a/tox.ini b/tox.ini index 99f608fff..029e480e3 100644 --- a/tox.ini +++ b/tox.ini @@ -36,6 +36,7 @@ deps = types-requests==2.28.10 types-cryptography==3.3.23 types-pyOpenSSL==22.0.10 + -e .[dev] commands = mypy {posargs}