Guard server_event against reentrancy, fix #7027 (#7031)

* guard server_event against reentrancy, fix #7027

* [autofix.ci] apply automated fixes

* attribution for excellent repros

* simplify test for compatibility with older Python versions

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Maximilian Hils 2024-07-23 11:12:09 +02:00 committed by GitHub
parent 317f5b9dce
commit 74424df839
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 131 additions and 66 deletions

View File

@ -33,6 +33,8 @@
([#7021](https://github.com/mitmproxy/mitmproxy/pull/7021), @petsneakers)
* Support all query types in DNS mode
([#6975](https://github.com/mitmproxy/mitmproxy/pull/6975), @errorxyz)
* Fix a bug where mitmproxy would crash for pipelined HTTP flows.
([#7031](https://github.com/mitmproxy/mitmproxy/pull/7031), @gdiepen, @mhils)
## 12 June 2024: mitmproxy 10.3.1

View File

@ -110,7 +110,7 @@ class ReplayHandler(server.ConnectionHandler):
self.done = asyncio.Event()
async def replay(self) -> None:
self.server_event(events.Start())
await self.server_event(events.Start())
await self.done.wait()
def log(

View File

@ -121,11 +121,13 @@ class Proxyserver(ServerManager):
is_running: bool
_connect_addr: Address | None = None
_update_task: asyncio.Task | None = None
_inject_tasks: set[asyncio.Task]
def __init__(self):
self.connections = {}
self.servers = Servers(self)
self.is_running = False
self._inject_tasks = set()
def __repr__(self):
return f"Proxyserver({len(self.connections)} active conns)"
@ -308,7 +310,15 @@ class Proxyserver(ServerManager):
)
if connection_id not in self.connections:
raise ValueError("Flow is not from a live connection.")
self.connections[connection_id].server_event(event)
t = asyncio_utils.create_task(
self.connections[connection_id].server_event(event),
name=f"inject_event",
client=event.flow.client_conn.peername,
)
# Python 3.11: Use TaskGroup instead.
self._inject_tasks.add(t)
t.add_done_callback(self._inject_tasks.remove)
@command.command("inject.websocket")
def inject_websocket(

View File

@ -122,6 +122,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
timeout = UDP_TIMEOUT
self.timeout_watchdog = TimeoutWatchdog(timeout, self.on_timeout)
self._server_event_lock = asyncio.Lock()
# workaround for https://bugs.python.org/issue40124 / https://bugs.python.org/issue29930
self._drain_lock = asyncio.Lock()
@ -144,7 +146,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
assert writer
writer.close()
else:
self.server_event(events.Start())
await self.server_event(events.Start())
handler = asyncio_utils.create_task(
self.handle_connection(self.client),
name=f"client connection handler",
@ -181,7 +183,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
async def open_connection(self, command: commands.OpenConnection) -> None:
if not command.connection.address:
self.log(f"Cannot open connection, no hostname given.")
self.server_event(
await self.server_event(
events.OpenConnectionCompleted(
command, f"Cannot open connection, no hostname given."
)
@ -197,7 +199,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
f"server connection to {human.format_address(command.connection.address)} killed before connect: {err}"
)
await self.handle_hook(server_hooks.ServerConnectErrorHook(hook_data))
self.server_event(
await self.server_event(
events.OpenConnectionCompleted(command, f"Connection killed: {err}")
)
return
@ -226,7 +228,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.log(f"error establishing server connection: {err}")
command.connection.error = err
await self.handle_hook(server_hooks.ServerConnectErrorHook(hook_data))
self.server_event(events.OpenConnectionCompleted(command, err))
await self.server_event(events.OpenConnectionCompleted(command, err))
if isinstance(e, asyncio.CancelledError):
# From https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError:
# > In almost all situations the exception must be re-raised.
@ -252,7 +254,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
addr = human.format_address(command.connection.address)
self.log(f"server connect {addr}")
await self.handle_hook(server_hooks.ServerConnectedHook(hook_data))
self.server_event(events.OpenConnectionCompleted(command, None))
await self.server_event(events.OpenConnectionCompleted(command, None))
try:
await self.handle_connection(command.connection)
@ -268,7 +270,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
task = asyncio.current_task()
assert task is not None
self.wakeup_timer.discard(task)
self.server_event(events.Wakeup(request))
await self.server_event(events.Wakeup(request))
async def handle_connection(self, connection: Connection) -> None:
"""
@ -290,7 +292,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
cancelled = e
break
self.server_event(events.DataReceived(connection, data))
await self.server_event(events.DataReceived(connection, data))
try:
await self.drain_writers()
@ -304,7 +306,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
else:
connection.state = ConnectionState.CLOSED
self.server_event(events.ConnectionClosed(connection))
await self.server_event(events.ConnectionClosed(connection))
if connection.state is ConnectionState.CAN_WRITE:
# we may still use this connection to *send* stuff,
@ -355,7 +357,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
async def hook_task(self, hook: commands.StartHook) -> None:
await self.handle_hook(hook)
if hook.blocking:
self.server_event(events.HookCompleted(hook))
await self.server_event(events.HookCompleted(hook))
@abc.abstractmethod
async def handle_hook(self, hook: commands.StartHook) -> None:
@ -373,56 +375,67 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
level, message, extra={"client": self.client.peername}, exc_info=exc_info
)
def server_event(self, event: events.Event) -> None:
self.timeout_watchdog.register_activity()
try:
layer_commands = self.layer.handle_event(event)
for command in layer_commands:
if isinstance(command, commands.OpenConnection):
assert command.connection not in self.transports
handler = asyncio_utils.create_task(
self.open_connection(command),
name=f"server connection handler {command.connection.address}",
client=self.client.peername,
)
self.transports[command.connection] = ConnectionIO(handler=handler)
elif isinstance(command, commands.RequestWakeup):
task = asyncio_utils.create_task(
self.wakeup(command),
name=f"wakeup timer ({command.delay:.1f}s)",
client=self.client.peername,
)
assert task is not None
self.wakeup_timer.add(task)
elif (
isinstance(command, commands.ConnectionCommand)
and command.connection not in self.transports
):
pass # The connection has already been closed.
elif isinstance(command, commands.SendData):
writer = self.transports[command.connection].writer
assert writer
if not writer.is_closing():
writer.write(command.data)
elif isinstance(command, commands.CloseTcpConnection):
self.close_connection(command.connection, command.half_close)
elif isinstance(command, commands.CloseConnection):
self.close_connection(command.connection, False)
elif isinstance(command, commands.StartHook):
t = asyncio_utils.create_task(
self.hook_task(command),
name=f"handle_hook({command.name})",
client=self.client.peername,
)
# Python 3.11 Use TaskGroup instead.
self.hook_tasks.add(t)
t.add_done_callback(self.hook_tasks.remove)
elif isinstance(command, commands.Log):
self.log(command.message, command.level)
else:
raise RuntimeError(f"Unexpected command: {command}")
except Exception:
self.log(f"mitmproxy has crashed!", logging.ERROR, exc_info=True)
async def server_event(self, event: events.Event) -> None:
    """
    Pass `event` to the layer and carry out every command the layer yields in response.

    The whole operation runs while holding `self._server_event_lock`, and nothing
    is awaited once the lock has been acquired.
    """
    # Handling an event is meant to be fully synchronous, with no `await` that could
    # pause execution. The catch: with an [eager task factory], create_task starts
    # running the new task right away, and the tasks spawned below call server_event
    # themselves, which causes [reentrancy issues]. The lock serializes those calls.
    #
    # [eager task factory]: https://docs.python.org/3/library/asyncio-task.html#eager-task-factory
    # [reentrancy issues]: https://github.com/mitmproxy/mitmproxy/issues/7027.
    async with self._server_event_lock:
        # No `await` beyond this point.
        self.timeout_watchdog.register_activity()
        try:
            for cmd in self.layer.handle_event(event):
                if isinstance(cmd, commands.OpenConnection):
                    assert cmd.connection not in self.transports
                    connect_task = asyncio_utils.create_task(
                        self.open_connection(cmd),
                        name=f"server connection handler {cmd.connection.address}",
                        client=self.client.peername,
                    )
                    self.transports[cmd.connection] = ConnectionIO(
                        handler=connect_task
                    )
                    continue

                if isinstance(cmd, commands.RequestWakeup):
                    wakeup_task = asyncio_utils.create_task(
                        self.wakeup(cmd),
                        name=f"wakeup timer ({cmd.delay:.1f}s)",
                        client=self.client.peername,
                    )
                    assert wakeup_task is not None
                    self.wakeup_timer.add(wakeup_task)
                    continue

                if (
                    isinstance(cmd, commands.ConnectionCommand)
                    and cmd.connection not in self.transports
                ):
                    # The connection has already been closed.
                    continue

                if isinstance(cmd, commands.SendData):
                    writer = self.transports[cmd.connection].writer
                    assert writer
                    if not writer.is_closing():
                        writer.write(cmd.data)
                    continue

                if isinstance(cmd, commands.CloseTcpConnection):
                    self.close_connection(cmd.connection, cmd.half_close)
                    continue

                if isinstance(cmd, commands.CloseConnection):
                    self.close_connection(cmd.connection, False)
                    continue

                if isinstance(cmd, commands.StartHook):
                    hook_runner = asyncio_utils.create_task(
                        self.hook_task(cmd),
                        name=f"handle_hook({cmd.name})",
                        client=self.client.peername,
                    )
                    # Python 3.11: Use TaskGroup instead.
                    self.hook_tasks.add(hook_runner)
                    hook_runner.add_done_callback(self.hook_tasks.remove)
                    continue

                if isinstance(cmd, commands.Log):
                    self.log(cmd.message, cmd.level)
                    continue

                raise RuntimeError(f"Unexpected command: {cmd}")
        except Exception:
            # Any failure while handling the event is logged with its traceback
            # rather than propagated to the caller.
            self.log("mitmproxy has crashed!", logging.ERROR, exc_info=True)
def close_connection(
self, connection: Connection, half_close: bool = False
@ -478,9 +491,9 @@ class SimpleConnectionHandler(LiveConnectionHandler): # pragma: no cover
hook_handlers: dict[str, Callable]
def __init__(self, reader, writer, options, mode, hooks):
def __init__(self, reader, writer, options, mode, hook_handlers):
super().__init__(reader, writer, options, mode)
self.hook_handlers = hooks
self.hook_handlers = hook_handlers
async def handle_hook(self, hook: commands.StartHook) -> None:
if hook.name in self.hook_handlers:

View File

@ -35,7 +35,7 @@ class EagerTaskCreationEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
return loop
@pytest.fixture()
@pytest.fixture(scope="session")
def event_loop_policy(request):
return EagerTaskCreationEventLoopPolicy()

View File

@ -1,5 +1,8 @@
import asyncio
import collections
import textwrap
from dataclasses import dataclass
from typing import Callable
from unittest import mock
import pytest
@ -7,13 +10,17 @@ import pytest
from mitmproxy import options
from mitmproxy.connection import Server
from mitmproxy.proxy import commands
from mitmproxy.proxy import layer
from mitmproxy.proxy import server
from mitmproxy.proxy import server_hooks
from mitmproxy.proxy.events import Event
from mitmproxy.proxy.events import HookCompleted
from mitmproxy.proxy.events import Start
from mitmproxy.proxy.mode_specs import ProxyMode
class MockConnectionHandler(server.SimpleConnectionHandler):
hook_handlers: dict[str, mock.Mock]
hook_handlers: dict[str, mock.Mock | Callable]
def __init__(self):
super().__init__(
@ -21,7 +28,7 @@ class MockConnectionHandler(server.SimpleConnectionHandler):
writer=mock.Mock(),
options=options.Options(),
mode=ProxyMode.parse("regular"),
hooks=collections.defaultdict(lambda: mock.Mock()),
hook_handlers=collections.defaultdict(lambda: mock.Mock()),
)
@ -64,3 +71,36 @@ async def test_open_connection(result, monkeypatch):
assert server_connect_error.called == (result != "success")
assert server_disconnected.called == (result == "success")
async def test_no_reentrancy(capsys):
    """
    An event must be handled to completion before the next one is processed.

    The test layer emits a hook while handling Start. The HookCompleted event that
    follows may only be handled after the Start handler has finished, which is
    verified through the order of the printed lines.
    """

    @dataclass
    class FastHook(commands.StartHook):
        pass

    class ReentrancyTestLayer(layer.Layer):
        def handle_event(self, event: Event) -> layer.CommandGenerator[None]:
            if isinstance(event, Start):
                print("Starting...")
                yield FastHook()
                print("Start completed.")
            elif isinstance(event, HookCompleted):
                print("Hook completed (must not happen before start is completed).")

        def _handle_event(self, event: Event) -> layer.CommandGenerator[None]:
            raise NotImplementedError

    handler = MockConnectionHandler()
    handler.layer = ReentrancyTestLayer(handler.layer.context)

    await handler.server_event(Start())
    # Give any task that is still pending a chance to run before reading the output.
    await asyncio.sleep(0)

    expected_lines = (
        "Starting...",
        "Start completed.",
        "Hook completed (must not happen before start is completed).",
    )
    expected_output = "".join(f"{line}\n" for line in expected_lines)
    assert capsys.readouterr().out == expected_output