* guard server_event against reentrancy, fix #7027 * [autofix.ci] apply automated fixes * attribution for excellent repros * simplify test for compatibility with older Python versions --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
317f5b9dce
commit
74424df839
|
@ -33,6 +33,8 @@
|
|||
([#7021](https://github.com/mitmproxy/mitmproxy/pull/7021), @petsneakers)
|
||||
* Support all query types in DNS mode
|
||||
([#6975](https://github.com/mitmproxy/mitmproxy/pull/6975), @errorxyz)
|
||||
* Fix a bug where mitmproxy would crash for pipelined HTTP flows.
|
||||
([#7031](https://github.com/mitmproxy/mitmproxy/pull/7031), @gdiepen, @mhils)
|
||||
|
||||
## 12 June 2024: mitmproxy 10.3.1
|
||||
|
||||
|
|
|
@ -110,7 +110,7 @@ class ReplayHandler(server.ConnectionHandler):
|
|||
self.done = asyncio.Event()
|
||||
|
||||
async def replay(self) -> None:
|
||||
self.server_event(events.Start())
|
||||
await self.server_event(events.Start())
|
||||
await self.done.wait()
|
||||
|
||||
def log(
|
||||
|
|
|
@ -121,11 +121,13 @@ class Proxyserver(ServerManager):
|
|||
is_running: bool
|
||||
_connect_addr: Address | None = None
|
||||
_update_task: asyncio.Task | None = None
|
||||
_inject_tasks: set[asyncio.Task]
|
||||
|
||||
def __init__(self):
|
||||
self.connections = {}
|
||||
self.servers = Servers(self)
|
||||
self.is_running = False
|
||||
self._inject_tasks = set()
|
||||
|
||||
def __repr__(self):
|
||||
return f"Proxyserver({len(self.connections)} active conns)"
|
||||
|
@ -308,7 +310,15 @@ class Proxyserver(ServerManager):
|
|||
)
|
||||
if connection_id not in self.connections:
|
||||
raise ValueError("Flow is not from a live connection.")
|
||||
self.connections[connection_id].server_event(event)
|
||||
|
||||
t = asyncio_utils.create_task(
|
||||
self.connections[connection_id].server_event(event),
|
||||
name=f"inject_event",
|
||||
client=event.flow.client_conn.peername,
|
||||
)
|
||||
# Python 3.11 Use TaskGroup instead.
|
||||
self._inject_tasks.add(t)
|
||||
t.add_done_callback(self._inject_tasks.remove)
|
||||
|
||||
@command.command("inject.websocket")
|
||||
def inject_websocket(
|
||||
|
|
|
@ -122,6 +122,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
timeout = UDP_TIMEOUT
|
||||
self.timeout_watchdog = TimeoutWatchdog(timeout, self.on_timeout)
|
||||
|
||||
self._server_event_lock = asyncio.Lock()
|
||||
|
||||
# workaround for https://bugs.python.org/issue40124 / https://bugs.python.org/issue29930
|
||||
self._drain_lock = asyncio.Lock()
|
||||
|
||||
|
@ -144,7 +146,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
assert writer
|
||||
writer.close()
|
||||
else:
|
||||
self.server_event(events.Start())
|
||||
await self.server_event(events.Start())
|
||||
handler = asyncio_utils.create_task(
|
||||
self.handle_connection(self.client),
|
||||
name=f"client connection handler",
|
||||
|
@ -181,7 +183,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
async def open_connection(self, command: commands.OpenConnection) -> None:
|
||||
if not command.connection.address:
|
||||
self.log(f"Cannot open connection, no hostname given.")
|
||||
self.server_event(
|
||||
await self.server_event(
|
||||
events.OpenConnectionCompleted(
|
||||
command, f"Cannot open connection, no hostname given."
|
||||
)
|
||||
|
@ -197,7 +199,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
f"server connection to {human.format_address(command.connection.address)} killed before connect: {err}"
|
||||
)
|
||||
await self.handle_hook(server_hooks.ServerConnectErrorHook(hook_data))
|
||||
self.server_event(
|
||||
await self.server_event(
|
||||
events.OpenConnectionCompleted(command, f"Connection killed: {err}")
|
||||
)
|
||||
return
|
||||
|
@ -226,7 +228,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
self.log(f"error establishing server connection: {err}")
|
||||
command.connection.error = err
|
||||
await self.handle_hook(server_hooks.ServerConnectErrorHook(hook_data))
|
||||
self.server_event(events.OpenConnectionCompleted(command, err))
|
||||
await self.server_event(events.OpenConnectionCompleted(command, err))
|
||||
if isinstance(e, asyncio.CancelledError):
|
||||
# From https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError:
|
||||
# > In almost all situations the exception must be re-raised.
|
||||
|
@ -252,7 +254,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
addr = human.format_address(command.connection.address)
|
||||
self.log(f"server connect {addr}")
|
||||
await self.handle_hook(server_hooks.ServerConnectedHook(hook_data))
|
||||
self.server_event(events.OpenConnectionCompleted(command, None))
|
||||
await self.server_event(events.OpenConnectionCompleted(command, None))
|
||||
|
||||
try:
|
||||
await self.handle_connection(command.connection)
|
||||
|
@ -268,7 +270,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
task = asyncio.current_task()
|
||||
assert task is not None
|
||||
self.wakeup_timer.discard(task)
|
||||
self.server_event(events.Wakeup(request))
|
||||
await self.server_event(events.Wakeup(request))
|
||||
|
||||
async def handle_connection(self, connection: Connection) -> None:
|
||||
"""
|
||||
|
@ -290,7 +292,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
cancelled = e
|
||||
break
|
||||
|
||||
self.server_event(events.DataReceived(connection, data))
|
||||
await self.server_event(events.DataReceived(connection, data))
|
||||
|
||||
try:
|
||||
await self.drain_writers()
|
||||
|
@ -304,7 +306,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
else:
|
||||
connection.state = ConnectionState.CLOSED
|
||||
|
||||
self.server_event(events.ConnectionClosed(connection))
|
||||
await self.server_event(events.ConnectionClosed(connection))
|
||||
|
||||
if connection.state is ConnectionState.CAN_WRITE:
|
||||
# we may still use this connection to *send* stuff,
|
||||
|
@ -355,7 +357,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
async def hook_task(self, hook: commands.StartHook) -> None:
|
||||
await self.handle_hook(hook)
|
||||
if hook.blocking:
|
||||
self.server_event(events.HookCompleted(hook))
|
||||
await self.server_event(events.HookCompleted(hook))
|
||||
|
||||
@abc.abstractmethod
|
||||
async def handle_hook(self, hook: commands.StartHook) -> None:
|
||||
|
@ -373,56 +375,67 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
level, message, extra={"client": self.client.peername}, exc_info=exc_info
|
||||
)
|
||||
|
||||
def server_event(self, event: events.Event) -> None:
|
||||
self.timeout_watchdog.register_activity()
|
||||
try:
|
||||
layer_commands = self.layer.handle_event(event)
|
||||
for command in layer_commands:
|
||||
if isinstance(command, commands.OpenConnection):
|
||||
assert command.connection not in self.transports
|
||||
handler = asyncio_utils.create_task(
|
||||
self.open_connection(command),
|
||||
name=f"server connection handler {command.connection.address}",
|
||||
client=self.client.peername,
|
||||
)
|
||||
self.transports[command.connection] = ConnectionIO(handler=handler)
|
||||
elif isinstance(command, commands.RequestWakeup):
|
||||
task = asyncio_utils.create_task(
|
||||
self.wakeup(command),
|
||||
name=f"wakeup timer ({command.delay:.1f}s)",
|
||||
client=self.client.peername,
|
||||
)
|
||||
assert task is not None
|
||||
self.wakeup_timer.add(task)
|
||||
elif (
|
||||
isinstance(command, commands.ConnectionCommand)
|
||||
and command.connection not in self.transports
|
||||
):
|
||||
pass # The connection has already been closed.
|
||||
elif isinstance(command, commands.SendData):
|
||||
writer = self.transports[command.connection].writer
|
||||
assert writer
|
||||
if not writer.is_closing():
|
||||
writer.write(command.data)
|
||||
elif isinstance(command, commands.CloseTcpConnection):
|
||||
self.close_connection(command.connection, command.half_close)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
self.close_connection(command.connection, False)
|
||||
elif isinstance(command, commands.StartHook):
|
||||
t = asyncio_utils.create_task(
|
||||
self.hook_task(command),
|
||||
name=f"handle_hook({command.name})",
|
||||
client=self.client.peername,
|
||||
)
|
||||
# Python 3.11 Use TaskGroup instead.
|
||||
self.hook_tasks.add(t)
|
||||
t.add_done_callback(self.hook_tasks.remove)
|
||||
elif isinstance(command, commands.Log):
|
||||
self.log(command.message, command.level)
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected command: {command}")
|
||||
except Exception:
|
||||
self.log(f"mitmproxy has crashed!", logging.ERROR, exc_info=True)
|
||||
async def server_event(self, event: events.Event) -> None:
|
||||
# server_event is supposed to be completely sync without any `await` that could pause execution.
|
||||
# However, create_task with an [eager task factory] will schedule tasks immediately,
|
||||
# which causes [reentrancy issues]. So we put the entire thing behind a lock.
|
||||
#
|
||||
# [eager task factory]: https://docs.python.org/3/library/asyncio-task.html#eager-task-factory
|
||||
# [reentrancy issues]: https://github.com/mitmproxy/mitmproxy/issues/7027.
|
||||
async with self._server_event_lock:
|
||||
# No `await` beyond this point.
|
||||
|
||||
self.timeout_watchdog.register_activity()
|
||||
try:
|
||||
layer_commands = self.layer.handle_event(event)
|
||||
for command in layer_commands:
|
||||
if isinstance(command, commands.OpenConnection):
|
||||
assert command.connection not in self.transports
|
||||
handler = asyncio_utils.create_task(
|
||||
self.open_connection(command),
|
||||
name=f"server connection handler {command.connection.address}",
|
||||
client=self.client.peername,
|
||||
)
|
||||
self.transports[command.connection] = ConnectionIO(
|
||||
handler=handler
|
||||
)
|
||||
elif isinstance(command, commands.RequestWakeup):
|
||||
task = asyncio_utils.create_task(
|
||||
self.wakeup(command),
|
||||
name=f"wakeup timer ({command.delay:.1f}s)",
|
||||
client=self.client.peername,
|
||||
)
|
||||
assert task is not None
|
||||
self.wakeup_timer.add(task)
|
||||
elif (
|
||||
isinstance(command, commands.ConnectionCommand)
|
||||
and command.connection not in self.transports
|
||||
):
|
||||
pass # The connection has already been closed.
|
||||
elif isinstance(command, commands.SendData):
|
||||
writer = self.transports[command.connection].writer
|
||||
assert writer
|
||||
if not writer.is_closing():
|
||||
writer.write(command.data)
|
||||
elif isinstance(command, commands.CloseTcpConnection):
|
||||
self.close_connection(command.connection, command.half_close)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
self.close_connection(command.connection, False)
|
||||
elif isinstance(command, commands.StartHook):
|
||||
t = asyncio_utils.create_task(
|
||||
self.hook_task(command),
|
||||
name=f"handle_hook({command.name})",
|
||||
client=self.client.peername,
|
||||
)
|
||||
# Python 3.11 Use TaskGroup instead.
|
||||
self.hook_tasks.add(t)
|
||||
t.add_done_callback(self.hook_tasks.remove)
|
||||
elif isinstance(command, commands.Log):
|
||||
self.log(command.message, command.level)
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected command: {command}")
|
||||
except Exception:
|
||||
self.log(f"mitmproxy has crashed!", logging.ERROR, exc_info=True)
|
||||
|
||||
def close_connection(
|
||||
self, connection: Connection, half_close: bool = False
|
||||
|
@ -478,9 +491,9 @@ class SimpleConnectionHandler(LiveConnectionHandler): # pragma: no cover
|
|||
|
||||
hook_handlers: dict[str, Callable]
|
||||
|
||||
def __init__(self, reader, writer, options, mode, hooks):
|
||||
def __init__(self, reader, writer, options, mode, hook_handlers):
|
||||
super().__init__(reader, writer, options, mode)
|
||||
self.hook_handlers = hooks
|
||||
self.hook_handlers = hook_handlers
|
||||
|
||||
async def handle_hook(self, hook: commands.StartHook) -> None:
|
||||
if hook.name in self.hook_handlers:
|
||||
|
|
|
@ -35,7 +35,7 @@ class EagerTaskCreationEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
|
|||
return loop
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop_policy(request):
|
||||
return EagerTaskCreationEventLoopPolicy()
|
||||
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
@ -7,13 +10,17 @@ import pytest
|
|||
from mitmproxy import options
|
||||
from mitmproxy.connection import Server
|
||||
from mitmproxy.proxy import commands
|
||||
from mitmproxy.proxy import layer
|
||||
from mitmproxy.proxy import server
|
||||
from mitmproxy.proxy import server_hooks
|
||||
from mitmproxy.proxy.events import Event
|
||||
from mitmproxy.proxy.events import HookCompleted
|
||||
from mitmproxy.proxy.events import Start
|
||||
from mitmproxy.proxy.mode_specs import ProxyMode
|
||||
|
||||
|
||||
class MockConnectionHandler(server.SimpleConnectionHandler):
|
||||
hook_handlers: dict[str, mock.Mock]
|
||||
hook_handlers: dict[str, mock.Mock | Callable]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
@ -21,7 +28,7 @@ class MockConnectionHandler(server.SimpleConnectionHandler):
|
|||
writer=mock.Mock(),
|
||||
options=options.Options(),
|
||||
mode=ProxyMode.parse("regular"),
|
||||
hooks=collections.defaultdict(lambda: mock.Mock()),
|
||||
hook_handlers=collections.defaultdict(lambda: mock.Mock()),
|
||||
)
|
||||
|
||||
|
||||
|
@ -64,3 +71,36 @@ async def test_open_connection(result, monkeypatch):
|
|||
assert server_connect_error.called == (result != "success")
|
||||
|
||||
assert server_disconnected.called == (result == "success")
|
||||
|
||||
|
||||
async def test_no_reentrancy(capsys):
|
||||
class ReentrancyTestLayer(layer.Layer):
|
||||
def handle_event(self, event: Event) -> layer.CommandGenerator[None]:
|
||||
if isinstance(event, Start):
|
||||
print("Starting...")
|
||||
yield FastHook()
|
||||
print("Start completed.")
|
||||
elif isinstance(event, HookCompleted):
|
||||
print(f"Hook completed (must not happen before start is completed).")
|
||||
|
||||
def _handle_event(self, event: Event) -> layer.CommandGenerator[None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@dataclass
|
||||
class FastHook(commands.StartHook):
|
||||
pass
|
||||
|
||||
handler = MockConnectionHandler()
|
||||
handler.layer = ReentrancyTestLayer(handler.layer.context)
|
||||
|
||||
# This instead would fail: handler._server_event(Start())
|
||||
await handler.server_event(Start())
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert capsys.readouterr().out == textwrap.dedent(
|
||||
"""\
|
||||
Starting...
|
||||
Start completed.
|
||||
Hook completed (must not happen before start is completed).
|
||||
"""
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue