diff --git a/CHANGELOG.md b/CHANGELOG.md index 1df19bb26..1f50fde7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ * Add support for editing non text files in a hex editor ([#6768](https://github.com/mitmproxy/mitmproxy/pull/6768), @wnyyyy) +* Add `server_connect_error` hook that is triggered when connection establishment fails. + ([#6806](https://github.com/mitmproxy/mitmproxy/pull/6806), @haanhvu, @spacewasp, @mhils) * Add section in mitmweb for rendering, adding and removing a comment ([#6709](https://github.com/mitmproxy/mitmproxy/pull/6709), @lups2000) * Fix multipart form content view being unusable. diff --git a/docs/scripts/api-events.py b/docs/scripts/api-events.py index be3f1e0b0..f99166baa 100644 --- a/docs/scripts/api-events.py +++ b/docs/scripts/api-events.py @@ -97,6 +97,7 @@ with outfile.open("w") as f, contextlib.redirect_stdout(f): server_hooks.ServerConnectHook, server_hooks.ServerConnectedHook, server_hooks.ServerDisconnectedHook, + server_hooks.ServerConnectErrorHook, ], ) diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index ebe23ed55..2a631b179 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -196,6 +196,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta): self.log( f"server connection to {human.format_address(command.connection.address)} killed before connect: {err}" ) + await self.handle_hook(server_hooks.ServerConnectErrorHook(hook_data)) self.server_event( events.OpenConnectionCompleted(command, f"Connection killed: {err}") ) @@ -224,6 +225,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta): err = "connection cancelled" self.log(f"error establishing server connection: {err}") command.connection.error = err + await self.handle_hook(server_hooks.ServerConnectErrorHook(hook_data)) self.server_event(events.OpenConnectionCompleted(command, err)) if isinstance(e, asyncio.CancelledError): # From https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError: @@ -237,8 +239,11 @@ class ConnectionHandler(metaclass=abc.ABCMeta): command.connection.state = ConnectionState.OPEN command.connection.peername = writer.get_extra_info("peername") command.connection.sockname = writer.get_extra_info("sockname") - self.transports[command.connection].reader = reader - self.transports[command.connection].writer = writer + self.transports[command.connection] = ConnectionIO( + handler=asyncio.current_task(), + reader=reader, + writer=writer, + ) assert command.connection.peername if command.connection.address[0] != command.connection.peername[0]: diff --git a/mitmproxy/proxy/server_hooks.py b/mitmproxy/proxy/server_hooks.py index a00c6ca19..1da4b60f2 100644 --- a/mitmproxy/proxy/server_hooks.py +++ b/mitmproxy/proxy/server_hooks.py @@ -63,3 +63,14 @@ class ServerDisconnectedHook(commands.StartHook): """ data: ServerConnectionHookData + + +@dataclass +class ServerConnectErrorHook(commands.StartHook): + """ + Mitmproxy failed to connect to a server. + + Every server connection will receive either a server_connected or a server_connect_error event, but not both. + """ + + data: ServerConnectionHookData diff --git a/test/mitmproxy/proxy/test_server.py b/test/mitmproxy/proxy/test_server.py index e69de29bb..0de636fda 100644 --- a/test/mitmproxy/proxy/test_server.py +++ b/test/mitmproxy/proxy/test_server.py @@ -0,0 +1,66 @@ +import asyncio +import collections +from unittest import mock + +import pytest + +from mitmproxy import options +from mitmproxy.connection import Server +from mitmproxy.proxy import commands +from mitmproxy.proxy import server +from mitmproxy.proxy import server_hooks +from mitmproxy.proxy.mode_specs import ProxyMode + + +class MockConnectionHandler(server.SimpleConnectionHandler): + hook_handlers: dict[str, mock.Mock] + + def __init__(self): + super().__init__( + reader=mock.Mock(), + writer=mock.Mock(), + options=options.Options(), + mode=ProxyMode.parse("regular"), + hooks=collections.defaultdict(lambda: mock.Mock()), + ) + + +@pytest.mark.parametrize("result", ("success", "killed", "failed")) +async def test_open_connection(result, monkeypatch): + handler = MockConnectionHandler() + server_connect = handler.hook_handlers["server_connect"] + server_connected = handler.hook_handlers["server_connected"] + server_connect_error = handler.hook_handlers["server_connect_error"] + server_disconnected = handler.hook_handlers["server_disconnected"] + + match result: + case "success": + monkeypatch.setattr( + asyncio, + "open_connection", + mock.AsyncMock(return_value=(mock.MagicMock(), mock.MagicMock())), + ) + monkeypatch.setattr( + MockConnectionHandler, "handle_connection", mock.AsyncMock() + ) + case "failed": + monkeypatch.setattr( + asyncio, "open_connection", mock.AsyncMock(side_effect=OSError) + ) + case "killed": + + def _kill(d: server_hooks.ServerConnectionHookData) -> None: + d.server.error = "do not connect" + + server_connect.side_effect = _kill + + await handler.open_connection( + commands.OpenConnection(connection=Server(address=("server", 1234))) + ) + + assert server_connect.call_args[0][0].server.address == ("server", 1234) + + assert server_connected.called == (result == "success") + assert server_connect_error.called == (result != "success") + + assert server_disconnected.called == (result == "success")