diff --git a/CHANGELOG.md b/CHANGELOG.md index 4aaab21f4..2ea648305 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ ([#6548](https://github.com/mitmproxy/mitmproxy/pull/6548), @zanieb) * Improved handling for `--allow-hosts`/`--ignore-hosts` options in WireGuard mode (#5930). ([#6513](https://github.com/mitmproxy/mitmproxy/pull/6513), @dsphper) +* Fix a bug where TCP connections were not closed properly. + ([#6543](https://github.com/mitmproxy/mitmproxy/pull/6543), @mhils) * DNS resolution is now exempted from `--ignore-hosts` in WireGuard Mode. ([#6513](https://github.com/mitmproxy/mitmproxy/pull/6513), @dsphper) * Fix a bug where logging was stopped prematurely during shutdown. diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py index 53b202505..d24548c39 100644 --- a/mitmproxy/addons/clientplayback.py +++ b/mitmproxy/addons/clientplayback.py @@ -161,9 +161,13 @@ class ClientPlayback: ) self.options = ctx.options - def done(self): + async def done(self): if self.playback_task: self.playback_task.cancel() + try: + await self.playback_task + except asyncio.CancelledError: + pass async def playback(self): while True: diff --git a/mitmproxy/optmanager.py b/mitmproxy/optmanager.py index b116f31d1..3dbc215c5 100644 --- a/mitmproxy/optmanager.py +++ b/mitmproxy/optmanager.py @@ -523,7 +523,7 @@ def parse(text): if not text: return {} try: - yaml = ruamel.yaml.YAML(typ="unsafe", pure=True) + yaml = ruamel.yaml.YAML(typ="safe", pure=True) data = yaml.load(text) except ruamel.yaml.error.YAMLError as v: if hasattr(v, "problem_mark"): diff --git a/mitmproxy/proxy/mode_servers.py b/mitmproxy/proxy/mode_servers.py index 5ada94ed1..f051a9f45 100644 --- a/mitmproxy/proxy/mode_servers.py +++ b/mitmproxy/proxy/mode_servers.py @@ -199,6 +199,7 @@ class ServerInstance(Generic[M], metaclass=ABCMeta): original_dst = platform.original_addr(s) except Exception as e: logger.error(f"Transparent mode failure: {e!r}") + writer.close() return else: handler.layer.context.client.sockname = original_dst diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index c99026192..58068fa80 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -308,7 +308,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta): # we may still use this connection to *send* stuff, # even though the remote has closed their side of the connection. # to make this work we keep this task running and wait for cancellation. - await asyncio.Event().wait() + try: + await asyncio.Event().wait() + except asyncio.CancelledError as e: + cancelled = e try: writer = self.transports[connection].writer @@ -336,10 +339,15 @@ class ConnectionHandler(metaclass=abc.ABCMeta): transport.handler.cancel(f"Error sending data: {e}") async def on_timeout(self) -> None: - self.log(f"Closing connection due to inactivity: {self.client}") - handler = self.transports[self.client].handler - assert handler - handler.cancel("timeout") + try: + handler = self.transports[self.client].handler + except KeyError: # pragma: no cover + # there is a super short window between connection close and watchdog cancellation + pass + else: + self.log(f"Closing connection due to inactivity: {self.client}") + assert handler + handler.cancel("timeout") async def hook_task(self, hook: commands.StartHook) -> None: await self.handle_hook(hook) diff --git a/pyproject.toml b/pyproject.toml index 19b952e3d..6844a316f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,8 @@ testpaths = "test" addopts = "--capture=no --color=yes" filterwarnings = [ "ignore::DeprecationWarning:tornado.*:", + "error::RuntimeWarning", + "error::pytest.PytestUnraisableExceptionWarning", ] [tool.mypy] diff --git a/test/mitmproxy/addons/test_asgiapp.py b/test/mitmproxy/addons/test_asgiapp.py index 0926f091e..233188b39 100644 --- a/test/mitmproxy/addons/test_asgiapp.py +++ b/test/mitmproxy/addons/test_asgiapp.py @@ -57,57 +57,70 @@ async def test_asgi_full(caplog): assert await ps.setup_servers() proxy_addr = ("127.0.0.1", ps.listen_addrs()[0][1]) - reader, writer = await asyncio.open_connection(*proxy_addr) + # We parallelize connection establishment/closure because those operations tend to be slow. + [ + (r1, w1), + (r2, w2), + (r3, w3), + (r4, w4), + (r5, w5), + ] = await asyncio.gather( + asyncio.open_connection(*proxy_addr), + asyncio.open_connection(*proxy_addr), + asyncio.open_connection(*proxy_addr), + asyncio.open_connection(*proxy_addr), + asyncio.open_connection(*proxy_addr), + ) + req = f"GET http://testapp:80/ HTTP/1.1\r\n\r\n" - writer.write(req.encode()) - header = await reader.readuntil(b"\r\n\r\n") + w1.write(req.encode()) + header = await r1.readuntil(b"\r\n\r\n") assert header.startswith(b"HTTP/1.1 200 OK") - body = await reader.readuntil(b"testapp") + body = await r1.readuntil(b"testapp") assert body == b"testapp" - writer.close() - await writer.wait_closed() - reader, writer = await asyncio.open_connection(*proxy_addr) req = f"GET http://testapp:80/parameters?param1=1¶m2=2 HTTP/1.1\r\n\r\n" - writer.write(req.encode()) - header = await reader.readuntil(b"\r\n\r\n") + w2.write(req.encode()) + header = await r2.readuntil(b"\r\n\r\n") assert header.startswith(b"HTTP/1.1 200 OK") - body = await reader.readuntil(b"}") + body = await r2.readuntil(b"}") assert body == b'{"param1": "1", "param2": "2"}' - writer.close() - await writer.wait_closed() - reader, writer = await asyncio.open_connection(*proxy_addr) req = f"POST http://testapp:80/requestbody HTTP/1.1\r\nContent-Length: 6\r\n\r\nHello!" - writer.write(req.encode()) - header = await reader.readuntil(b"\r\n\r\n") + w3.write(req.encode()) + header = await r3.readuntil(b"\r\n\r\n") assert header.startswith(b"HTTP/1.1 200 OK") - body = await reader.readuntil(b"}") + body = await r3.readuntil(b"}") assert body == b'{"body": "Hello!"}' - writer.close() - await writer.wait_closed() - reader, writer = await asyncio.open_connection(*proxy_addr) req = f"GET http://errapp:80/?foo=bar HTTP/1.1\r\n\r\n" - writer.write(req.encode()) - header = await reader.readuntil(b"\r\n\r\n") + w4.write(req.encode()) + header = await r4.readuntil(b"\r\n\r\n") assert header.startswith(b"HTTP/1.1 500") - body = await reader.readuntil(b"ASGI Error") + body = await r4.readuntil(b"ASGI Error") assert body == b"ASGI Error" - writer.close() - await writer.wait_closed() assert "ValueError" in caplog.text - reader, writer = await asyncio.open_connection(*proxy_addr) req = f"GET http://noresponseapp:80/ HTTP/1.1\r\n\r\n" - writer.write(req.encode()) - header = await reader.readuntil(b"\r\n\r\n") + w5.write(req.encode()) + header = await r5.readuntil(b"\r\n\r\n") assert header.startswith(b"HTTP/1.1 500") - body = await reader.readuntil(b"ASGI Error") + body = await r5.readuntil(b"ASGI Error") assert body == b"ASGI Error" - writer.close() - await writer.wait_closed() assert "no response sent" in caplog.text + w1.close() + w2.close() + w3.close() + w4.close() + w5.close() + await asyncio.gather( + w1.wait_closed(), + w2.wait_closed(), + w3.wait_closed(), + w4.wait_closed(), + w5.wait_closed(), + ) + tctx.configure(ps, server=False) assert await ps.setup_servers() diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index f9f889e42..f9fc32ae9 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -17,23 +17,48 @@ from mitmproxy.test import tflow @asynccontextmanager async def tcp_server(handle_conn, **server_args) -> Address: - server = await asyncio.start_server(handle_conn, "127.0.0.1", 0, **server_args) + """TCP server context manager that... + + 1. Exits only after all handlers have returned. + 2. Ensures that all handlers are closed properly. If we don't do that, + we get ghost errors in others tests from StreamWriter.__del__. + + Spawning a TCP server is relatively slow. Consider using in-memory networking for faster tests. + """ + if not hasattr(asyncio, "TaskGroup"): + pytest.skip("Skipped because asyncio.TaskGroup is unavailable.") + + tasks = asyncio.TaskGroup() + + async def handle_conn_wrapper( + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + try: + await handle_conn(reader, writer) + except Exception as e: + print(f"!!! TCP handler failed: {e}") + raise + finally: + if not writer.is_closing(): + writer.close() + await writer.wait_closed() + + async def _handle(r, w): + tasks.create_task(handle_conn_wrapper(r, w)) + + server = await asyncio.start_server(_handle, "127.0.0.1", 0, **server_args) await server.start_serving() - try: - yield server.sockets[0].getsockname() - finally: - server.close() + async with server: + async with tasks: + yield server.sockets[0].getsockname() @pytest.mark.parametrize("mode", ["http", "https", "upstream", "err"]) @pytest.mark.parametrize("concurrency", [-1, 1]) async def test_playback(tdata, mode, concurrency): - handler_ok = asyncio.Event() - async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): if mode == "err": - writer.close() - handler_ok.set() return req = await reader.readline() if mode == "upstream": @@ -49,7 +74,6 @@ async def test_playback(tdata, mode, concurrency): writer.write(b"HTTP/1.1 204 No Content\r\n\r\n") await writer.drain() assert not await reader.read() - handler_ok.set() cp = ClientPlayback() ps = Proxyserver() @@ -92,22 +116,20 @@ async def test_playback(tdata, mode, concurrency): cp.start_replay([flow]) assert cp.count() == 1 await asyncio.wait_for(cp.queue.join(), 5) - await asyncio.wait_for(handler_ok.wait(), 5) - cp.done() - if mode != "err": - assert flow.response.status_code == 204 + while cp.replay_tasks: + await asyncio.sleep(0.001) + if mode != "err": + assert flow.response.status_code == 204 + await cp.done() async def test_playback_https_upstream(): - handler_ok = asyncio.Event() - async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): conn_req = await reader.readuntil(b"\r\n\r\n") assert conn_req == b"CONNECT address:22 HTTP/1.1\r\n\r\n" writer.write(b"HTTP/1.1 502 Bad Gateway\r\n\r\n") await writer.drain() assert not await reader.read() - handler_ok.set() cp = ClientPlayback() ps = Proxyserver() @@ -122,17 +144,17 @@ async def test_playback_https_upstream(): cp.start_replay([flow]) assert cp.count() == 1 await asyncio.wait_for(cp.queue.join(), 5) - await asyncio.wait_for(handler_ok.wait(), 5) - cp.done() - assert flow.response is None - assert ( - str(flow.error) - == f"Upstream proxy {addr[0]}:{addr[1]} refused HTTP CONNECT request: 502 Bad Gateway" - ) + + assert flow.response is None + assert ( + str(flow.error) + == f"Upstream proxy {addr[0]}:{addr[1]} refused HTTP CONNECT request: 502 Bad Gateway" + ) + await cp.done() async def test_playback_crash(monkeypatch, caplog_async): - async def raise_err(): + async def raise_err(*_, **__): raise ValueError("oops") monkeypatch.setattr(ReplayHandler, "replay", raise_err) @@ -141,8 +163,9 @@ async def test_playback_crash(monkeypatch, caplog_async): cp.running() cp.start_replay([tflow.tflow(live=False)]) await caplog_async.await_log("Client replay has crashed!") + assert "oops" in caplog_async.caplog.text assert cp.count() == 0 - cp.done() + await cp.done() def test_check(): diff --git a/test/mitmproxy/addons/test_proxyserver.py b/test/mitmproxy/addons/test_proxyserver.py index 341bfbe09..00f959fb2 100644 --- a/test/mitmproxy/addons/test_proxyserver.py +++ b/test/mitmproxy/addons/test_proxyserver.py @@ -23,6 +23,7 @@ from aioquic.quic.configuration import QuicConfiguration from aioquic.quic.connection import QuicConnection from aioquic.quic.connection import QuicConnectionError +from .test_clientplayback import tcp_server import mitmproxy.platform from mitmproxy import dns from mitmproxy import exceptions @@ -55,16 +56,6 @@ class HelperAddon: self.flows.append(f) -@asynccontextmanager -async def tcp_server(handle_conn) -> Address: - server = await asyncio.start_server(handle_conn, "127.0.0.1", 0) - await server.start_serving() - try: - yield server.sockets[0].getsockname() - finally: - server.close() - - async def test_start_stop(caplog_async): caplog_async.set_level("INFO") @@ -74,7 +65,6 @@ async def test_start_stop(caplog_async): assert await reader.readuntil(b"\r\n\r\n") == b"GET /hello HTTP/1.1\r\n\r\n" writer.write(b"HTTP/1.1 204 No Content\r\n\r\n") await writer.drain() - writer.close() ps = Proxyserver() nl = NextLayer() @@ -160,6 +150,9 @@ async def test_inject() -> None: ps.inject_tcp(state.flows[0], True, b"c") assert await reader.read(1) == b"c" + writer.close() + await writer.wait_closed() + async def test_inject_fail(caplog) -> None: ps = Proxyserver() @@ -311,6 +304,9 @@ async def test_dns(caplog_async) -> None: tctx.configure(ps, server=False) await caplog_async.await_log("stopped") + w.close() + await w.wait_closed() + def test_validation_no_transparent(monkeypatch): monkeypatch.setattr(mitmproxy.platform, "original_addr", None) @@ -373,6 +369,9 @@ async def test_udp(caplog_async) -> None: tctx.configure(ps, server=False) await caplog_async.await_log("stopped") + w.close() + await w.wait_closed() + class H3EchoServer(QuicConnectionProtocol): def __init__(self, *args, **kwargs) -> None: @@ -779,6 +778,11 @@ async def test_reverse_http3_and_quic_stream( await _test_echo(client, strict=scheme == "http3") assert len(ps.connections) == 1 + # dirty hack: forcibly close all connections so that there are no unexpected asyncio tasks + # that may cause test failures because they have not been run. + for conn in ps.servers[mode].manager.connections.values(): + await conn.on_timeout() + tctx.configure(ps, server=False) await caplog_async.await_log(f"stopped") diff --git a/test/mitmproxy/addons/test_script.py b/test/mitmproxy/addons/test_script.py index add0801fb..c36279cb4 100644 --- a/test/mitmproxy/addons/test_script.py +++ b/test/mitmproxy/addons/test_script.py @@ -328,6 +328,7 @@ def test_order(tdata, capsys): ] ) time = r"\[[\d:.]+\] " + out = capsys.readouterr().out assert re.match( rf"{time}Loading script.+recorder.py\n" rf"{time}\('recorder', 'load', .+\n" @@ -335,5 +336,5 @@ def test_order(tdata, capsys): rf"{time}Loading script.+shutdown.py\n" rf"{time}\('recorder', 'running', .+\n" rf"{time}\('recorder', 'done', .+\n$", - capsys.readouterr().out, + out, ) diff --git a/test/mitmproxy/proxy/test_mode_servers.py b/test/mitmproxy/proxy/test_mode_servers.py index 274eabe42..9071e28c7 100644 --- a/test/mitmproxy/proxy/test_mode_servers.py +++ b/test/mitmproxy/proxy/test_mode_servers.py @@ -269,6 +269,7 @@ async def test_udp_start_stop(caplog_async): assert await caplog_async.await_log("sent an invalid message") writer.close() + await writer.wait_closed() await inst.stop() assert await caplog_async.await_log("stopped") @@ -324,6 +325,7 @@ async def test_udp_dual_stack(caplog_async): writer.write(b"\x00\x00\x01") assert await caplog_async.await_log("sent an invalid message") writer.close() + await writer.wait_closed() if "listening on IPv4 only" not in caplog_async.caplog.text: caplog_async.clear() @@ -331,6 +333,7 @@ async def test_udp_dual_stack(caplog_async): writer.write(b"\x00\x00\x01") assert await caplog_async.await_log("sent an invalid message") writer.close() + await writer.wait_closed() await inst.stop() assert await caplog_async.await_log("stopped") @@ -338,7 +341,7 @@ async def test_udp_dual_stack(caplog_async): @pytest.fixture() def patched_local_redirector(monkeypatch): - start_local_redirector = AsyncMock() + start_local_redirector = AsyncMock(return_value=Mock()) monkeypatch.setattr(mitmproxy_rs, "start_local_redirector", start_local_redirector) # make sure _server and _instance are restored after this test monkeypatch.setattr(LocalRedirectorInstance, "_server", None) diff --git a/test/mitmproxy/tools/web/test_master.py b/test/mitmproxy/tools/web/test_master.py index ee0225754..e956b73d4 100644 --- a/test/mitmproxy/tools/web/test_master.py +++ b/test/mitmproxy/tools/web/test_master.py @@ -1,5 +1,4 @@ import asyncio -from unittest.mock import MagicMock import pytest @@ -8,8 +7,11 @@ from mitmproxy.tools.web.master import WebMaster async def test_reuse(): + async def handler(r, w): + pass + server = await asyncio.start_server( - MagicMock(), host="127.0.0.1", port=0, reuse_address=False + handler, host="127.0.0.1", port=0, reuse_address=False ) port = server.sockets[0].getsockname()[1] master = WebMaster(Options(), with_termlog=False) @@ -18,3 +20,6 @@ async def test_reuse(): with pytest.raises(OSError, match=f"--set web_port={port + 2}"): await master.running() server.close() + # tornado registers some callbacks, + # we want to run them to avoid fatal warnings. + await asyncio.sleep(0)