fix a bug where connections would not be fully closed (#6543)
This commit is contained in:
parent
1fcd0335d5
commit
0a3e016d39
|
@ -11,6 +11,8 @@
|
|||
([#6548](https://github.com/mitmproxy/mitmproxy/pull/6548), @zanieb)
|
||||
* Improved handling for `--allow-hosts`/`--ignore-hosts` options in WireGuard mode (#5930).
|
||||
([#6513](https://github.com/mitmproxy/mitmproxy/pull/6513), @dsphper)
|
||||
* Fix a bug where TCP connections were not closed properly.
|
||||
([#6543](https://github.com/mitmproxy/mitmproxy/pull/6543), @mhils)
|
||||
* DNS resolution is now exempted from `--ignore-hosts` in WireGuard Mode.
|
||||
([#6513](https://github.com/mitmproxy/mitmproxy/pull/6513), @dsphper)
|
||||
* Fix a bug where logging was stopped prematurely during shutdown.
|
||||
|
|
|
@ -161,9 +161,13 @@ class ClientPlayback:
|
|||
)
|
||||
self.options = ctx.options
|
||||
|
||||
def done(self):
|
||||
async def done(self):
|
||||
if self.playback_task:
|
||||
self.playback_task.cancel()
|
||||
try:
|
||||
await self.playback_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def playback(self):
|
||||
while True:
|
||||
|
|
|
@ -523,7 +523,7 @@ def parse(text):
|
|||
if not text:
|
||||
return {}
|
||||
try:
|
||||
yaml = ruamel.yaml.YAML(typ="unsafe", pure=True)
|
||||
yaml = ruamel.yaml.YAML(typ="safe", pure=True)
|
||||
data = yaml.load(text)
|
||||
except ruamel.yaml.error.YAMLError as v:
|
||||
if hasattr(v, "problem_mark"):
|
||||
|
|
|
@ -199,6 +199,7 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
|
|||
original_dst = platform.original_addr(s)
|
||||
except Exception as e:
|
||||
logger.error(f"Transparent mode failure: {e!r}")
|
||||
writer.close()
|
||||
return
|
||||
else:
|
||||
handler.layer.context.client.sockname = original_dst
|
||||
|
|
|
@ -308,7 +308,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
# we may still use this connection to *send* stuff,
|
||||
# even though the remote has closed their side of the connection.
|
||||
# to make this work we keep this task running and wait for cancellation.
|
||||
await asyncio.Event().wait()
|
||||
try:
|
||||
await asyncio.Event().wait()
|
||||
except asyncio.CancelledError as e:
|
||||
cancelled = e
|
||||
|
||||
try:
|
||||
writer = self.transports[connection].writer
|
||||
|
@ -336,10 +339,15 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
transport.handler.cancel(f"Error sending data: {e}")
|
||||
|
||||
async def on_timeout(self) -> None:
|
||||
self.log(f"Closing connection due to inactivity: {self.client}")
|
||||
handler = self.transports[self.client].handler
|
||||
assert handler
|
||||
handler.cancel("timeout")
|
||||
try:
|
||||
handler = self.transports[self.client].handler
|
||||
except KeyError: # pragma: no cover
|
||||
# there is a super short window between connection close and watchdog cancellation
|
||||
pass
|
||||
else:
|
||||
self.log(f"Closing connection due to inactivity: {self.client}")
|
||||
assert handler
|
||||
handler.cancel("timeout")
|
||||
|
||||
async def hook_task(self, hook: commands.StartHook) -> None:
|
||||
await self.handle_hook(hook)
|
||||
|
|
|
@ -130,6 +130,8 @@ testpaths = "test"
|
|||
addopts = "--capture=no --color=yes"
|
||||
filterwarnings = [
|
||||
"ignore::DeprecationWarning:tornado.*:",
|
||||
"error::RuntimeWarning",
|
||||
"error::pytest.PytestUnraisableExceptionWarning",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
|
|
|
@ -57,57 +57,70 @@ async def test_asgi_full(caplog):
|
|||
assert await ps.setup_servers()
|
||||
proxy_addr = ("127.0.0.1", ps.listen_addrs()[0][1])
|
||||
|
||||
reader, writer = await asyncio.open_connection(*proxy_addr)
|
||||
# We parallelize connection establishment/closure because those operations tend to be slow.
|
||||
[
|
||||
(r1, w1),
|
||||
(r2, w2),
|
||||
(r3, w3),
|
||||
(r4, w4),
|
||||
(r5, w5),
|
||||
] = await asyncio.gather(
|
||||
asyncio.open_connection(*proxy_addr),
|
||||
asyncio.open_connection(*proxy_addr),
|
||||
asyncio.open_connection(*proxy_addr),
|
||||
asyncio.open_connection(*proxy_addr),
|
||||
asyncio.open_connection(*proxy_addr),
|
||||
)
|
||||
|
||||
req = f"GET http://testapp:80/ HTTP/1.1\r\n\r\n"
|
||||
writer.write(req.encode())
|
||||
header = await reader.readuntil(b"\r\n\r\n")
|
||||
w1.write(req.encode())
|
||||
header = await r1.readuntil(b"\r\n\r\n")
|
||||
assert header.startswith(b"HTTP/1.1 200 OK")
|
||||
body = await reader.readuntil(b"testapp")
|
||||
body = await r1.readuntil(b"testapp")
|
||||
assert body == b"testapp"
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
reader, writer = await asyncio.open_connection(*proxy_addr)
|
||||
req = f"GET http://testapp:80/parameters?param1=1¶m2=2 HTTP/1.1\r\n\r\n"
|
||||
writer.write(req.encode())
|
||||
header = await reader.readuntil(b"\r\n\r\n")
|
||||
w2.write(req.encode())
|
||||
header = await r2.readuntil(b"\r\n\r\n")
|
||||
assert header.startswith(b"HTTP/1.1 200 OK")
|
||||
body = await reader.readuntil(b"}")
|
||||
body = await r2.readuntil(b"}")
|
||||
assert body == b'{"param1": "1", "param2": "2"}'
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
reader, writer = await asyncio.open_connection(*proxy_addr)
|
||||
req = f"POST http://testapp:80/requestbody HTTP/1.1\r\nContent-Length: 6\r\n\r\nHello!"
|
||||
writer.write(req.encode())
|
||||
header = await reader.readuntil(b"\r\n\r\n")
|
||||
w3.write(req.encode())
|
||||
header = await r3.readuntil(b"\r\n\r\n")
|
||||
assert header.startswith(b"HTTP/1.1 200 OK")
|
||||
body = await reader.readuntil(b"}")
|
||||
body = await r3.readuntil(b"}")
|
||||
assert body == b'{"body": "Hello!"}'
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
reader, writer = await asyncio.open_connection(*proxy_addr)
|
||||
req = f"GET http://errapp:80/?foo=bar HTTP/1.1\r\n\r\n"
|
||||
writer.write(req.encode())
|
||||
header = await reader.readuntil(b"\r\n\r\n")
|
||||
w4.write(req.encode())
|
||||
header = await r4.readuntil(b"\r\n\r\n")
|
||||
assert header.startswith(b"HTTP/1.1 500")
|
||||
body = await reader.readuntil(b"ASGI Error")
|
||||
body = await r4.readuntil(b"ASGI Error")
|
||||
assert body == b"ASGI Error"
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
assert "ValueError" in caplog.text
|
||||
|
||||
reader, writer = await asyncio.open_connection(*proxy_addr)
|
||||
req = f"GET http://noresponseapp:80/ HTTP/1.1\r\n\r\n"
|
||||
writer.write(req.encode())
|
||||
header = await reader.readuntil(b"\r\n\r\n")
|
||||
w5.write(req.encode())
|
||||
header = await r5.readuntil(b"\r\n\r\n")
|
||||
assert header.startswith(b"HTTP/1.1 500")
|
||||
body = await reader.readuntil(b"ASGI Error")
|
||||
body = await r5.readuntil(b"ASGI Error")
|
||||
assert body == b"ASGI Error"
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
assert "no response sent" in caplog.text
|
||||
|
||||
w1.close()
|
||||
w2.close()
|
||||
w3.close()
|
||||
w4.close()
|
||||
w5.close()
|
||||
await asyncio.gather(
|
||||
w1.wait_closed(),
|
||||
w2.wait_closed(),
|
||||
w3.wait_closed(),
|
||||
w4.wait_closed(),
|
||||
w5.wait_closed(),
|
||||
)
|
||||
|
||||
tctx.configure(ps, server=False)
|
||||
assert await ps.setup_servers()
|
||||
|
|
|
@ -17,23 +17,48 @@ from mitmproxy.test import tflow
|
|||
|
||||
@asynccontextmanager
|
||||
async def tcp_server(handle_conn, **server_args) -> Address:
|
||||
server = await asyncio.start_server(handle_conn, "127.0.0.1", 0, **server_args)
|
||||
"""TCP server context manager that...
|
||||
|
||||
1. Exits only after all handlers have returned.
|
||||
2. Ensures that all handlers are closed properly. If we don't do that,
|
||||
we get ghost errors in others tests from StreamWriter.__del__.
|
||||
|
||||
Spawning a TCP server is relatively slow. Consider using in-memory networking for faster tests.
|
||||
"""
|
||||
if not hasattr(asyncio, "TaskGroup"):
|
||||
pytest.skip("Skipped because asyncio.TaskGroup is unavailable.")
|
||||
|
||||
tasks = asyncio.TaskGroup()
|
||||
|
||||
async def handle_conn_wrapper(
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
) -> None:
|
||||
try:
|
||||
await handle_conn(reader, writer)
|
||||
except Exception as e:
|
||||
print(f"!!! TCP handler failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
if not writer.is_closing():
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
async def _handle(r, w):
|
||||
tasks.create_task(handle_conn_wrapper(r, w))
|
||||
|
||||
server = await asyncio.start_server(_handle, "127.0.0.1", 0, **server_args)
|
||||
await server.start_serving()
|
||||
try:
|
||||
yield server.sockets[0].getsockname()
|
||||
finally:
|
||||
server.close()
|
||||
async with server:
|
||||
async with tasks:
|
||||
yield server.sockets[0].getsockname()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mode", ["http", "https", "upstream", "err"])
|
||||
@pytest.mark.parametrize("concurrency", [-1, 1])
|
||||
async def test_playback(tdata, mode, concurrency):
|
||||
handler_ok = asyncio.Event()
|
||||
|
||||
async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
if mode == "err":
|
||||
writer.close()
|
||||
handler_ok.set()
|
||||
return
|
||||
req = await reader.readline()
|
||||
if mode == "upstream":
|
||||
|
@ -49,7 +74,6 @@ async def test_playback(tdata, mode, concurrency):
|
|||
writer.write(b"HTTP/1.1 204 No Content\r\n\r\n")
|
||||
await writer.drain()
|
||||
assert not await reader.read()
|
||||
handler_ok.set()
|
||||
|
||||
cp = ClientPlayback()
|
||||
ps = Proxyserver()
|
||||
|
@ -92,22 +116,20 @@ async def test_playback(tdata, mode, concurrency):
|
|||
cp.start_replay([flow])
|
||||
assert cp.count() == 1
|
||||
await asyncio.wait_for(cp.queue.join(), 5)
|
||||
await asyncio.wait_for(handler_ok.wait(), 5)
|
||||
cp.done()
|
||||
if mode != "err":
|
||||
assert flow.response.status_code == 204
|
||||
while cp.replay_tasks:
|
||||
await asyncio.sleep(0.001)
|
||||
if mode != "err":
|
||||
assert flow.response.status_code == 204
|
||||
await cp.done()
|
||||
|
||||
|
||||
async def test_playback_https_upstream():
|
||||
handler_ok = asyncio.Event()
|
||||
|
||||
async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
conn_req = await reader.readuntil(b"\r\n\r\n")
|
||||
assert conn_req == b"CONNECT address:22 HTTP/1.1\r\n\r\n"
|
||||
writer.write(b"HTTP/1.1 502 Bad Gateway\r\n\r\n")
|
||||
await writer.drain()
|
||||
assert not await reader.read()
|
||||
handler_ok.set()
|
||||
|
||||
cp = ClientPlayback()
|
||||
ps = Proxyserver()
|
||||
|
@ -122,17 +144,17 @@ async def test_playback_https_upstream():
|
|||
cp.start_replay([flow])
|
||||
assert cp.count() == 1
|
||||
await asyncio.wait_for(cp.queue.join(), 5)
|
||||
await asyncio.wait_for(handler_ok.wait(), 5)
|
||||
cp.done()
|
||||
assert flow.response is None
|
||||
assert (
|
||||
str(flow.error)
|
||||
== f"Upstream proxy {addr[0]}:{addr[1]} refused HTTP CONNECT request: 502 Bad Gateway"
|
||||
)
|
||||
|
||||
assert flow.response is None
|
||||
assert (
|
||||
str(flow.error)
|
||||
== f"Upstream proxy {addr[0]}:{addr[1]} refused HTTP CONNECT request: 502 Bad Gateway"
|
||||
)
|
||||
await cp.done()
|
||||
|
||||
|
||||
async def test_playback_crash(monkeypatch, caplog_async):
|
||||
async def raise_err():
|
||||
async def raise_err(*_, **__):
|
||||
raise ValueError("oops")
|
||||
|
||||
monkeypatch.setattr(ReplayHandler, "replay", raise_err)
|
||||
|
@ -141,8 +163,9 @@ async def test_playback_crash(monkeypatch, caplog_async):
|
|||
cp.running()
|
||||
cp.start_replay([tflow.tflow(live=False)])
|
||||
await caplog_async.await_log("Client replay has crashed!")
|
||||
assert "oops" in caplog_async.caplog.text
|
||||
assert cp.count() == 0
|
||||
cp.done()
|
||||
await cp.done()
|
||||
|
||||
|
||||
def test_check():
|
||||
|
|
|
@ -23,6 +23,7 @@ from aioquic.quic.configuration import QuicConfiguration
|
|||
from aioquic.quic.connection import QuicConnection
|
||||
from aioquic.quic.connection import QuicConnectionError
|
||||
|
||||
from .test_clientplayback import tcp_server
|
||||
import mitmproxy.platform
|
||||
from mitmproxy import dns
|
||||
from mitmproxy import exceptions
|
||||
|
@ -55,16 +56,6 @@ class HelperAddon:
|
|||
self.flows.append(f)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def tcp_server(handle_conn) -> Address:
|
||||
server = await asyncio.start_server(handle_conn, "127.0.0.1", 0)
|
||||
await server.start_serving()
|
||||
try:
|
||||
yield server.sockets[0].getsockname()
|
||||
finally:
|
||||
server.close()
|
||||
|
||||
|
||||
async def test_start_stop(caplog_async):
|
||||
caplog_async.set_level("INFO")
|
||||
|
||||
|
@ -74,7 +65,6 @@ async def test_start_stop(caplog_async):
|
|||
assert await reader.readuntil(b"\r\n\r\n") == b"GET /hello HTTP/1.1\r\n\r\n"
|
||||
writer.write(b"HTTP/1.1 204 No Content\r\n\r\n")
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
|
||||
ps = Proxyserver()
|
||||
nl = NextLayer()
|
||||
|
@ -160,6 +150,9 @@ async def test_inject() -> None:
|
|||
ps.inject_tcp(state.flows[0], True, b"c")
|
||||
assert await reader.read(1) == b"c"
|
||||
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
|
||||
async def test_inject_fail(caplog) -> None:
|
||||
ps = Proxyserver()
|
||||
|
@ -311,6 +304,9 @@ async def test_dns(caplog_async) -> None:
|
|||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log("stopped")
|
||||
|
||||
w.close()
|
||||
await w.wait_closed()
|
||||
|
||||
|
||||
def test_validation_no_transparent(monkeypatch):
|
||||
monkeypatch.setattr(mitmproxy.platform, "original_addr", None)
|
||||
|
@ -373,6 +369,9 @@ async def test_udp(caplog_async) -> None:
|
|||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log("stopped")
|
||||
|
||||
w.close()
|
||||
await w.wait_closed()
|
||||
|
||||
|
||||
class H3EchoServer(QuicConnectionProtocol):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
|
@ -779,6 +778,11 @@ async def test_reverse_http3_and_quic_stream(
|
|||
await _test_echo(client, strict=scheme == "http3")
|
||||
assert len(ps.connections) == 1
|
||||
|
||||
# dirty hack: forcibly close all connections so that there are no unexpected asyncio tasks
|
||||
# that may cause test failures because they have not been run.
|
||||
for conn in ps.servers[mode].manager.connections.values():
|
||||
await conn.on_timeout()
|
||||
|
||||
tctx.configure(ps, server=False)
|
||||
await caplog_async.await_log(f"stopped")
|
||||
|
||||
|
|
|
@ -328,6 +328,7 @@ def test_order(tdata, capsys):
|
|||
]
|
||||
)
|
||||
time = r"\[[\d:.]+\] "
|
||||
out = capsys.readouterr().out
|
||||
assert re.match(
|
||||
rf"{time}Loading script.+recorder.py\n"
|
||||
rf"{time}\('recorder', 'load', .+\n"
|
||||
|
@ -335,5 +336,5 @@ def test_order(tdata, capsys):
|
|||
rf"{time}Loading script.+shutdown.py\n"
|
||||
rf"{time}\('recorder', 'running', .+\n"
|
||||
rf"{time}\('recorder', 'done', .+\n$",
|
||||
capsys.readouterr().out,
|
||||
out,
|
||||
)
|
||||
|
|
|
@ -269,6 +269,7 @@ async def test_udp_start_stop(caplog_async):
|
|||
assert await caplog_async.await_log("sent an invalid message")
|
||||
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
await inst.stop()
|
||||
assert await caplog_async.await_log("stopped")
|
||||
|
@ -324,6 +325,7 @@ async def test_udp_dual_stack(caplog_async):
|
|||
writer.write(b"\x00\x00\x01")
|
||||
assert await caplog_async.await_log("sent an invalid message")
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
if "listening on IPv4 only" not in caplog_async.caplog.text:
|
||||
caplog_async.clear()
|
||||
|
@ -331,6 +333,7 @@ async def test_udp_dual_stack(caplog_async):
|
|||
writer.write(b"\x00\x00\x01")
|
||||
assert await caplog_async.await_log("sent an invalid message")
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
await inst.stop()
|
||||
assert await caplog_async.await_log("stopped")
|
||||
|
@ -338,7 +341,7 @@ async def test_udp_dual_stack(caplog_async):
|
|||
|
||||
@pytest.fixture()
|
||||
def patched_local_redirector(monkeypatch):
|
||||
start_local_redirector = AsyncMock()
|
||||
start_local_redirector = AsyncMock(return_value=Mock())
|
||||
monkeypatch.setattr(mitmproxy_rs, "start_local_redirector", start_local_redirector)
|
||||
# make sure _server and _instance are restored after this test
|
||||
monkeypatch.setattr(LocalRedirectorInstance, "_server", None)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -8,8 +7,11 @@ from mitmproxy.tools.web.master import WebMaster
|
|||
|
||||
|
||||
async def test_reuse():
|
||||
async def handler(r, w):
|
||||
pass
|
||||
|
||||
server = await asyncio.start_server(
|
||||
MagicMock(), host="127.0.0.1", port=0, reuse_address=False
|
||||
handler, host="127.0.0.1", port=0, reuse_address=False
|
||||
)
|
||||
port = server.sockets[0].getsockname()[1]
|
||||
master = WebMaster(Options(), with_termlog=False)
|
||||
|
@ -18,3 +20,6 @@ async def test_reuse():
|
|||
with pytest.raises(OSError, match=f"--set web_port={port + 2}"):
|
||||
await master.running()
|
||||
server.close()
|
||||
# tornado registers some callbacks,
|
||||
# we want to run them to avoid fatal warnings.
|
||||
await asyncio.sleep(0)
|
||||
|
|
Loading…
Reference in New Issue