fix a bug where connections would not be fully closed (#6543)

This commit is contained in:
Maximilian Hils 2023-12-12 19:15:19 +01:00 committed by GitHub
parent 1fcd0335d5
commit 0a3e016d39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 144 additions and 78 deletions

View File

@ -11,6 +11,8 @@
([#6548](https://github.com/mitmproxy/mitmproxy/pull/6548), @zanieb)
* Improved handling for `--allow-hosts`/`--ignore-hosts` options in WireGuard mode (#5930).
([#6513](https://github.com/mitmproxy/mitmproxy/pull/6513), @dsphper)
* Fix a bug where TCP connections were not closed properly.
([#6543](https://github.com/mitmproxy/mitmproxy/pull/6543), @mhils)
* DNS resolution is now exempted from `--ignore-hosts` in WireGuard Mode.
([#6513](https://github.com/mitmproxy/mitmproxy/pull/6513), @dsphper)
* Fix a bug where logging was stopped prematurely during shutdown.

View File

@ -161,9 +161,13 @@ class ClientPlayback:
)
self.options = ctx.options
def done(self):
async def done(self):
if self.playback_task:
self.playback_task.cancel()
try:
await self.playback_task
except asyncio.CancelledError:
pass
async def playback(self):
while True:

View File

@ -523,7 +523,7 @@ def parse(text):
if not text:
return {}
try:
yaml = ruamel.yaml.YAML(typ="unsafe", pure=True)
yaml = ruamel.yaml.YAML(typ="safe", pure=True)
data = yaml.load(text)
except ruamel.yaml.error.YAMLError as v:
if hasattr(v, "problem_mark"):

View File

@ -199,6 +199,7 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
original_dst = platform.original_addr(s)
except Exception as e:
logger.error(f"Transparent mode failure: {e!r}")
writer.close()
return
else:
handler.layer.context.client.sockname = original_dst

View File

@ -308,7 +308,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
# we may still use this connection to *send* stuff,
# even though the remote has closed their side of the connection.
# to make this work we keep this task running and wait for cancellation.
await asyncio.Event().wait()
try:
await asyncio.Event().wait()
except asyncio.CancelledError as e:
cancelled = e
try:
writer = self.transports[connection].writer
@ -336,10 +339,15 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
transport.handler.cancel(f"Error sending data: {e}")
async def on_timeout(self) -> None:
self.log(f"Closing connection due to inactivity: {self.client}")
handler = self.transports[self.client].handler
assert handler
handler.cancel("timeout")
try:
handler = self.transports[self.client].handler
except KeyError: # pragma: no cover
# there is a super short window between connection close and watchdog cancellation
pass
else:
self.log(f"Closing connection due to inactivity: {self.client}")
assert handler
handler.cancel("timeout")
async def hook_task(self, hook: commands.StartHook) -> None:
await self.handle_hook(hook)

View File

@ -130,6 +130,8 @@ testpaths = "test"
addopts = "--capture=no --color=yes"
filterwarnings = [
"ignore::DeprecationWarning:tornado.*:",
"error::RuntimeWarning",
"error::pytest.PytestUnraisableExceptionWarning",
]
[tool.mypy]

View File

@ -57,57 +57,70 @@ async def test_asgi_full(caplog):
assert await ps.setup_servers()
proxy_addr = ("127.0.0.1", ps.listen_addrs()[0][1])
reader, writer = await asyncio.open_connection(*proxy_addr)
# We parallelize connection establishment/closure because those operations tend to be slow.
[
(r1, w1),
(r2, w2),
(r3, w3),
(r4, w4),
(r5, w5),
] = await asyncio.gather(
asyncio.open_connection(*proxy_addr),
asyncio.open_connection(*proxy_addr),
asyncio.open_connection(*proxy_addr),
asyncio.open_connection(*proxy_addr),
asyncio.open_connection(*proxy_addr),
)
req = f"GET http://testapp:80/ HTTP/1.1\r\n\r\n"
writer.write(req.encode())
header = await reader.readuntil(b"\r\n\r\n")
w1.write(req.encode())
header = await r1.readuntil(b"\r\n\r\n")
assert header.startswith(b"HTTP/1.1 200 OK")
body = await reader.readuntil(b"testapp")
body = await r1.readuntil(b"testapp")
assert body == b"testapp"
writer.close()
await writer.wait_closed()
reader, writer = await asyncio.open_connection(*proxy_addr)
req = f"GET http://testapp:80/parameters?param1=1&param2=2 HTTP/1.1\r\n\r\n"
writer.write(req.encode())
header = await reader.readuntil(b"\r\n\r\n")
w2.write(req.encode())
header = await r2.readuntil(b"\r\n\r\n")
assert header.startswith(b"HTTP/1.1 200 OK")
body = await reader.readuntil(b"}")
body = await r2.readuntil(b"}")
assert body == b'{"param1": "1", "param2": "2"}'
writer.close()
await writer.wait_closed()
reader, writer = await asyncio.open_connection(*proxy_addr)
req = f"POST http://testapp:80/requestbody HTTP/1.1\r\nContent-Length: 6\r\n\r\nHello!"
writer.write(req.encode())
header = await reader.readuntil(b"\r\n\r\n")
w3.write(req.encode())
header = await r3.readuntil(b"\r\n\r\n")
assert header.startswith(b"HTTP/1.1 200 OK")
body = await reader.readuntil(b"}")
body = await r3.readuntil(b"}")
assert body == b'{"body": "Hello!"}'
writer.close()
await writer.wait_closed()
reader, writer = await asyncio.open_connection(*proxy_addr)
req = f"GET http://errapp:80/?foo=bar HTTP/1.1\r\n\r\n"
writer.write(req.encode())
header = await reader.readuntil(b"\r\n\r\n")
w4.write(req.encode())
header = await r4.readuntil(b"\r\n\r\n")
assert header.startswith(b"HTTP/1.1 500")
body = await reader.readuntil(b"ASGI Error")
body = await r4.readuntil(b"ASGI Error")
assert body == b"ASGI Error"
writer.close()
await writer.wait_closed()
assert "ValueError" in caplog.text
reader, writer = await asyncio.open_connection(*proxy_addr)
req = f"GET http://noresponseapp:80/ HTTP/1.1\r\n\r\n"
writer.write(req.encode())
header = await reader.readuntil(b"\r\n\r\n")
w5.write(req.encode())
header = await r5.readuntil(b"\r\n\r\n")
assert header.startswith(b"HTTP/1.1 500")
body = await reader.readuntil(b"ASGI Error")
body = await r5.readuntil(b"ASGI Error")
assert body == b"ASGI Error"
writer.close()
await writer.wait_closed()
assert "no response sent" in caplog.text
w1.close()
w2.close()
w3.close()
w4.close()
w5.close()
await asyncio.gather(
w1.wait_closed(),
w2.wait_closed(),
w3.wait_closed(),
w4.wait_closed(),
w5.wait_closed(),
)
tctx.configure(ps, server=False)
assert await ps.setup_servers()

View File

@ -17,23 +17,48 @@ from mitmproxy.test import tflow
@asynccontextmanager
async def tcp_server(handle_conn, **server_args) -> Address:
server = await asyncio.start_server(handle_conn, "127.0.0.1", 0, **server_args)
"""TCP server context manager that...
1. Exits only after all handlers have returned.
2. Ensures that all handlers are closed properly. If we don't do that,
we get ghost errors in other tests from StreamWriter.__del__.
Spawning a TCP server is relatively slow. Consider using in-memory networking for faster tests.
"""
if not hasattr(asyncio, "TaskGroup"):
pytest.skip("Skipped because asyncio.TaskGroup is unavailable.")
tasks = asyncio.TaskGroup()
async def handle_conn_wrapper(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
try:
await handle_conn(reader, writer)
except Exception as e:
print(f"!!! TCP handler failed: {e}")
raise
finally:
if not writer.is_closing():
writer.close()
await writer.wait_closed()
async def _handle(r, w):
tasks.create_task(handle_conn_wrapper(r, w))
server = await asyncio.start_server(_handle, "127.0.0.1", 0, **server_args)
await server.start_serving()
try:
yield server.sockets[0].getsockname()
finally:
server.close()
async with server:
async with tasks:
yield server.sockets[0].getsockname()
@pytest.mark.parametrize("mode", ["http", "https", "upstream", "err"])
@pytest.mark.parametrize("concurrency", [-1, 1])
async def test_playback(tdata, mode, concurrency):
handler_ok = asyncio.Event()
async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
if mode == "err":
writer.close()
handler_ok.set()
return
req = await reader.readline()
if mode == "upstream":
@ -49,7 +74,6 @@ async def test_playback(tdata, mode, concurrency):
writer.write(b"HTTP/1.1 204 No Content\r\n\r\n")
await writer.drain()
assert not await reader.read()
handler_ok.set()
cp = ClientPlayback()
ps = Proxyserver()
@ -92,22 +116,20 @@ async def test_playback(tdata, mode, concurrency):
cp.start_replay([flow])
assert cp.count() == 1
await asyncio.wait_for(cp.queue.join(), 5)
await asyncio.wait_for(handler_ok.wait(), 5)
cp.done()
if mode != "err":
assert flow.response.status_code == 204
while cp.replay_tasks:
await asyncio.sleep(0.001)
if mode != "err":
assert flow.response.status_code == 204
await cp.done()
async def test_playback_https_upstream():
handler_ok = asyncio.Event()
async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
conn_req = await reader.readuntil(b"\r\n\r\n")
assert conn_req == b"CONNECT address:22 HTTP/1.1\r\n\r\n"
writer.write(b"HTTP/1.1 502 Bad Gateway\r\n\r\n")
await writer.drain()
assert not await reader.read()
handler_ok.set()
cp = ClientPlayback()
ps = Proxyserver()
@ -122,17 +144,17 @@ async def test_playback_https_upstream():
cp.start_replay([flow])
assert cp.count() == 1
await asyncio.wait_for(cp.queue.join(), 5)
await asyncio.wait_for(handler_ok.wait(), 5)
cp.done()
assert flow.response is None
assert (
str(flow.error)
== f"Upstream proxy {addr[0]}:{addr[1]} refused HTTP CONNECT request: 502 Bad Gateway"
)
assert flow.response is None
assert (
str(flow.error)
== f"Upstream proxy {addr[0]}:{addr[1]} refused HTTP CONNECT request: 502 Bad Gateway"
)
await cp.done()
async def test_playback_crash(monkeypatch, caplog_async):
async def raise_err():
async def raise_err(*_, **__):
raise ValueError("oops")
monkeypatch.setattr(ReplayHandler, "replay", raise_err)
@ -141,8 +163,9 @@ async def test_playback_crash(monkeypatch, caplog_async):
cp.running()
cp.start_replay([tflow.tflow(live=False)])
await caplog_async.await_log("Client replay has crashed!")
assert "oops" in caplog_async.caplog.text
assert cp.count() == 0
cp.done()
await cp.done()
def test_check():

View File

@ -23,6 +23,7 @@ from aioquic.quic.configuration import QuicConfiguration
from aioquic.quic.connection import QuicConnection
from aioquic.quic.connection import QuicConnectionError
from .test_clientplayback import tcp_server
import mitmproxy.platform
from mitmproxy import dns
from mitmproxy import exceptions
@ -55,16 +56,6 @@ class HelperAddon:
self.flows.append(f)
@asynccontextmanager
async def tcp_server(handle_conn) -> Address:
server = await asyncio.start_server(handle_conn, "127.0.0.1", 0)
await server.start_serving()
try:
yield server.sockets[0].getsockname()
finally:
server.close()
async def test_start_stop(caplog_async):
caplog_async.set_level("INFO")
@ -74,7 +65,6 @@ async def test_start_stop(caplog_async):
assert await reader.readuntil(b"\r\n\r\n") == b"GET /hello HTTP/1.1\r\n\r\n"
writer.write(b"HTTP/1.1 204 No Content\r\n\r\n")
await writer.drain()
writer.close()
ps = Proxyserver()
nl = NextLayer()
@ -160,6 +150,9 @@ async def test_inject() -> None:
ps.inject_tcp(state.flows[0], True, b"c")
assert await reader.read(1) == b"c"
writer.close()
await writer.wait_closed()
async def test_inject_fail(caplog) -> None:
ps = Proxyserver()
@ -311,6 +304,9 @@ async def test_dns(caplog_async) -> None:
tctx.configure(ps, server=False)
await caplog_async.await_log("stopped")
w.close()
await w.wait_closed()
def test_validation_no_transparent(monkeypatch):
monkeypatch.setattr(mitmproxy.platform, "original_addr", None)
@ -373,6 +369,9 @@ async def test_udp(caplog_async) -> None:
tctx.configure(ps, server=False)
await caplog_async.await_log("stopped")
w.close()
await w.wait_closed()
class H3EchoServer(QuicConnectionProtocol):
def __init__(self, *args, **kwargs) -> None:
@ -779,6 +778,11 @@ async def test_reverse_http3_and_quic_stream(
await _test_echo(client, strict=scheme == "http3")
assert len(ps.connections) == 1
# dirty hack: forcibly close all connections so that there are no unexpected asyncio tasks
# that may cause test failures because they have not been run.
for conn in ps.servers[mode].manager.connections.values():
await conn.on_timeout()
tctx.configure(ps, server=False)
await caplog_async.await_log(f"stopped")

View File

@ -328,6 +328,7 @@ def test_order(tdata, capsys):
]
)
time = r"\[[\d:.]+\] "
out = capsys.readouterr().out
assert re.match(
rf"{time}Loading script.+recorder.py\n"
rf"{time}\('recorder', 'load', .+\n"
@ -335,5 +336,5 @@ def test_order(tdata, capsys):
rf"{time}Loading script.+shutdown.py\n"
rf"{time}\('recorder', 'running', .+\n"
rf"{time}\('recorder', 'done', .+\n$",
capsys.readouterr().out,
out,
)

View File

@ -269,6 +269,7 @@ async def test_udp_start_stop(caplog_async):
assert await caplog_async.await_log("sent an invalid message")
writer.close()
await writer.wait_closed()
await inst.stop()
assert await caplog_async.await_log("stopped")
@ -324,6 +325,7 @@ async def test_udp_dual_stack(caplog_async):
writer.write(b"\x00\x00\x01")
assert await caplog_async.await_log("sent an invalid message")
writer.close()
await writer.wait_closed()
if "listening on IPv4 only" not in caplog_async.caplog.text:
caplog_async.clear()
@ -331,6 +333,7 @@ async def test_udp_dual_stack(caplog_async):
writer.write(b"\x00\x00\x01")
assert await caplog_async.await_log("sent an invalid message")
writer.close()
await writer.wait_closed()
await inst.stop()
assert await caplog_async.await_log("stopped")
@ -338,7 +341,7 @@ async def test_udp_dual_stack(caplog_async):
@pytest.fixture()
def patched_local_redirector(monkeypatch):
start_local_redirector = AsyncMock()
start_local_redirector = AsyncMock(return_value=Mock())
monkeypatch.setattr(mitmproxy_rs, "start_local_redirector", start_local_redirector)
# make sure _server and _instance are restored after this test
monkeypatch.setattr(LocalRedirectorInstance, "_server", None)

View File

@ -1,5 +1,4 @@
import asyncio
from unittest.mock import MagicMock
import pytest
@ -8,8 +7,11 @@ from mitmproxy.tools.web.master import WebMaster
async def test_reuse():
async def handler(r, w):
pass
server = await asyncio.start_server(
MagicMock(), host="127.0.0.1", port=0, reuse_address=False
handler, host="127.0.0.1", port=0, reuse_address=False
)
port = server.sockets[0].getsockname()[1]
master = WebMaster(Options(), with_termlog=False)
@ -18,3 +20,6 @@ async def test_reuse():
with pytest.raises(OSError, match=f"--set web_port={port + 2}"):
await master.running()
server.close()
# tornado registers some callbacks;
# we want to run them to avoid fatal warnings.
await asyncio.sleep(0)