diff --git a/examples/addons/websocket-inject-message.py b/examples/addons/websocket-inject-message.py new file mode 100644 index 000000000..7bc60764e --- /dev/null +++ b/examples/addons/websocket-inject-message.py @@ -0,0 +1,32 @@ +""" +Inject a WebSocket message into a running connection. + +This example shows how to inject a WebSocket message into a running connection. +""" +import asyncio + +from mitmproxy import ctx, http + + +# Simple example: Inject a message as a response to an event + +def websocket_message(flow): + last_message = flow.websocket.messages[-1] + if b"secret" in last_message.content: + last_message.kill() + ctx.master.commands.call("inject.websocket", flow, last_message.from_client, "ssssssh") + + +# Complex example: Schedule a periodic timer + +async def inject_async(flow: http.HTTPFlow): + msg = "hello from mitmproxy! " + assert flow.websocket # make type checker happy + while flow.websocket.timestamp_end is None: + ctx.master.commands.call("inject.websocket", flow, True, msg) + await asyncio.sleep(1) + msg = msg[1:] + msg[:1] + + +def websocket_start(flow: http.HTTPFlow): + asyncio.create_task(inject_async(flow)) diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index 912d60f56..549d05739 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -1,12 +1,15 @@ import asyncio import warnings -from typing import Optional +from typing import Dict, Optional, Tuple -from mitmproxy import controller, ctx, flow, log, master, options, platform -from mitmproxy.flow import Error -from mitmproxy.proxy import commands +from mitmproxy import command, controller, ctx, flow, http, log, master, options, platform, tcp, websocket +from mitmproxy.flow import Error, Flow +from mitmproxy.proxy import commands, events from mitmproxy.proxy import server -from mitmproxy.utils import asyncio_utils, human +from mitmproxy.proxy.layers.tcp import TcpMessageInjected +from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected +from mitmproxy.utils import asyncio_utils, human, strutils +from wsproto.frame_protocol import Opcode class AsyncReply(controller.Reply): @@ -67,11 +70,16 @@ class Proxyserver: master: master.Master options: options.Options is_running: bool + _connections: Dict[Tuple, ProxyConnectionHandler] def __init__(self): self._lock = asyncio.Lock() self.server = None self.is_running = False + self._connections = {} + + def __repr__(self): + return f"ProxyServer({'running' if self.server else 'stopped'}, {len(self._connections)} active conns)" def load(self, loader): loader.add_option( @@ -121,10 +129,11 @@ class Proxyserver: self.server = None async def handle_connection(self, r, w): + peername = w.get_extra_info('peername') asyncio_utils.set_task_debug_info( asyncio.current_task(), name=f"Proxyserver.handle_connection", - client=w.get_extra_info('peername'), + client=peername, ) handler = ProxyConnectionHandler( self.master, @@ -132,4 +141,42 @@ class Proxyserver: w, self.options ) - await handler.handle_client() + self._connections[peername] = handler + try: + await handler.handle_client() + finally: + del self._connections[peername] + + def inject_event(self, event: events.MessageInjected): + if event.flow.client_conn.peername not in self._connections: + raise ValueError("Flow is not from a live connection.") + self._connections[event.flow.client_conn.peername].server_event(event) + + @command.command("inject.websocket") + def inject_websocket(self, flow: Flow, to_client: bool, message: str, is_text: bool = True): + if not isinstance(flow, http.HTTPFlow) or not flow.websocket: + ctx.log.warn("Cannot inject WebSocket messages into non-WebSocket flows.") + + message_bytes = strutils.escaped_str_to_bytes(message) + msg = websocket.WebSocketMessage( + Opcode.TEXT if is_text else Opcode.BINARY, + not to_client, + message_bytes + ) + event = WebSocketMessageInjected(flow, msg) + try: + self.inject_event(event) + except ValueError as e: + ctx.log.warn(str(e)) + + @command.command("inject.tcp") + def inject_tcp(self, flow: Flow, to_client: bool, message: str): + if not isinstance(flow, tcp.TCPFlow): + ctx.log.warn("Cannot inject TCP messages into non-TCP flows.") + + message_bytes = strutils.escaped_str_to_bytes(message) + event = TcpMessageInjected(flow, tcp.TCPMessage(not to_client, message_bytes)) + try: + self.inject_event(event) + except ValueError as e: + ctx.log.warn(str(e)) diff --git a/mitmproxy/io/compat.py b/mitmproxy/io/compat.py index 16887dd1a..795dda8af 100644 --- a/mitmproxy/io/compat.py +++ b/mitmproxy/io/compat.py @@ -291,6 +291,7 @@ def convert_11_12(data): "closed_by_client": ws_flow["close_sender"] == "client", "close_code": ws_flow["close_code"], "close_reason": ws_flow["close_reason"], + "timestamp_end": data.get("server_conn", {}).get("timestamp_end", None), } else: diff --git a/mitmproxy/proxy/commands.py b/mitmproxy/proxy/commands.py index 04429800e..8f0799005 100644 --- a/mitmproxy/proxy/commands.py +++ b/mitmproxy/proxy/commands.py @@ -23,6 +23,9 @@ class Command: blocking: Union[bool, "mitmproxy.proxy.layer.Layer"] = False """ Determines if the command blocks until it has been completed. + For practical purposes, this attribute should be thought of as a boolean value, + layers may swap out `True` with a reference to themselves to signal to outer layers + that they do not need to block as well. Example: diff --git a/mitmproxy/proxy/events.py b/mitmproxy/proxy/events.py index fc9bee681..1b6b6ea02 100644 --- a/mitmproxy/proxy/events.py +++ b/mitmproxy/proxy/events.py @@ -8,6 +8,7 @@ import typing import warnings from dataclasses import dataclass, is_dataclass +from mitmproxy import flow from mitmproxy.proxy import commands from mitmproxy.connection import Connection @@ -106,3 +107,15 @@ class HookCompleted(CommandCompleted): class GetSocketCompleted(CommandCompleted): command: commands.GetSocket reply: socket.socket + + +T = typing.TypeVar('T') + + +@dataclass +class MessageInjected(Event, typing.Generic[T]): + """ + The user has injected a custom WebSocket/TCP/... message. + """ + flow: flow.Flow + message: T diff --git a/mitmproxy/proxy/layer.py b/mitmproxy/proxy/layer.py index d4f7820f6..ef0881d62 100644 --- a/mitmproxy/proxy/layer.py +++ b/mitmproxy/proxy/layer.py @@ -33,21 +33,32 @@ class Layer: Layers interface with their child layer(s) by calling .handle_event(event), which returns a list (more precisely: a generator) of commands. - Most layers only implement ._handle_event, which is called by the default implementation of .handle_event. - The default implementation allows layers to emulate blocking code: + Most layers do not implement .directly, but instead implement ._handle_event, which + is called by the default implementation of .handle_event. + The default implementation of .handle_event allows layers to emulate blocking code: When ._handle_event yields a command that has its blocking attribute set to True, .handle_event pauses - the execution of ._handle_event and waits until it is called with the corresponding CommandCompleted event. All - events encountered in the meantime are buffered and replayed after execution is resumed. + the execution of ._handle_event and waits until it is called with the corresponding CommandCompleted event. + All events encountered in the meantime are buffered and replayed after execution is resumed. The result is code that looks like blocking code, but is not blocking: def _handle_event(self, event): err = yield OpenConnection(server) # execution continues here after a connection has been established. + + Technically this is very similar to how coroutines are implemented. """ __last_debug_message: ClassVar[str] = "" context: Context _paused: Optional[Paused] + """ + If execution is currently paused, this attribute stores the paused coroutine + and the command for which we are expecting a reply. + """ _paused_event_queue: Deque[events.Event] + """ + All events that have occurred since execution was paused. + These will be replayed to ._child_layer once we resume. + """ debug: Optional[str] = None """ Enable debug logging by assigning a prefix string for log messages. @@ -75,6 +86,7 @@ class Layer: return f"{type(self).__name__}({state})" def __debug(self, message): + """yield a Log command indicating what message is passing through this layer.""" if len(message) > 512: message = message[:512] + "…" if Layer.__last_debug_message == message: @@ -88,6 +100,16 @@ class Layer: "debug" ) + @property + def stack_pos(self) -> str: + """repr() for this layer and all its parent layers, only useful for debugging.""" + try: + idx = self.context.layers.index(self) + except ValueError: + return repr(self) + else: + return " >> ".join(repr(x) for x in self.context.layers[:idx + 1]) + @abstractmethod def _handle_event(self, event: events.Event) -> CommandGenerator[None]: """Handle a proxy server event""" @@ -111,15 +133,53 @@ class Layer: if self.debug is not None: yield self.__debug(f">> {event}") command_generator = self._handle_event(event) - yield from self.__process(command_generator) + send = None + + # inlined copy of __process to reduce call stack. + # <✂✂✂> + try: + # Run ._handle_event to the next yield statement. + # If you are not familiar with generators and their .send() method, + # https://stackoverflow.com/a/12638313/934719 has a good explanation. + command = command_generator.send(send) + except StopIteration: + return + + while True: + if self.debug is not None: + if not isinstance(command, commands.Log): + yield self.__debug(f"<< {command}") + if command.blocking is True: + # We only want this layer to block, the outer layers should not block. + # For example, take an HTTP/2 connection: If we intercept one particular request, + # we don't want all other requests in the connection to be blocked a well. + # We signal to outer layers that this command is already handled by assigning our layer to + # `.blocking` here (upper layers explicitly check for `is True`). + command.blocking = self + self._paused = Paused( + command, + command_generator, + ) + yield command + return + else: + yield command + try: + command = next(command_generator) + except StopIteration: + return + # def __process(self, command_generator: CommandGenerator, send=None): """ - yield all commands from a generator. - if a command is blocking, the layer is paused and this function returns before - processing any other commands. + Yield commands from a generator. + If a command is blocking, execution is paused and this function returns without + processing any further commands. """ try: + # Run ._handle_event to the next yield statement. + # If you are not familiar with generators and their .send() method, + # https://stackoverflow.com/a/12638313/934719 has a good explanation. command = command_generator.send(send) except StopIteration: return @@ -129,7 +189,12 @@ class Layer: if not isinstance(command, commands.Log): yield self.__debug(f"<< {command}") if command.blocking is True: - command.blocking = self # assign to our layer so that higher layers don't block. + # We only want this layer to block, the outer layers should not block. + # For example, take an HTTP/2 connection: If we intercept one particular request, + # we don't want all other requests in the connection to be blocked a well. + # We signal to outer layers that this command is already handled by assigning our layer to + # `.blocking` here (upper layers explicitly check for `is True`). + command.blocking = self self._paused = Paused( command, command_generator, @@ -144,7 +209,11 @@ class Layer: return def __continue(self, event: events.CommandCompleted): - """continue processing events after being paused""" + """ + Continue processing events after being paused. + The tricky part here is that events in the event queue may trigger commands which again pause the execution, + so we may not be able to process the entire queue. + """ assert self._paused is not None command_generator = self._paused.generator self._paused = None diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index 77568d4c1..5d71e70bc 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -20,7 +20,7 @@ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders ResponseData, ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError, ResponseTrailers from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \ HttpResponseHook -from ._http1 import Http1Client, Http1Server +from ._http1 import Http1Client, Http1Connection, Http1Server from ._http2 import Http2Client, Http2Server from ...context import Context @@ -114,13 +114,16 @@ class HttpStream(layer.Layer): self.stream_id = stream_id def __repr__(self): - return ( - f"HttpStream(" - f"id={self.stream_id}, " - f"client_state={self.client_state.__name__}, " - f"server_state={self.server_state.__name__}" - f")" - ) + if self._handle_event == self.passthrough: + return f"HttpStream(id={self.stream_id}, passthrough)" + else: + return ( + f"HttpStream(" + f"id={self.stream_id}, " + f"client_state={self.client_state.__name__}, " + f"server_state={self.server_state.__name__}" + f")" + ) @expect(events.Start, HttpEvent) def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: @@ -608,6 +611,26 @@ class HttpLayer(layer.Layer): elif isinstance(event, events.CommandCompleted): stream = self.command_sources.pop(event.command) yield from self.event_to_child(stream, event) + elif isinstance(event, events.MessageInjected): + # For injected messages we pass the HTTP stacks entirely and directly address the stream. + try: + conn = self.connections[event.flow.server_conn] + except KeyError: + # We have a miss for the server connection, which means we're looking at a connection object + # that is tunneled over another connection (for example: over an upstream HTTP proxy). + # We now take the stream associated with the client connection. That won't work for HTTP/2, + # but it's good enough for HTTP/1. + conn = self.connections[event.flow.client_conn] + if isinstance(conn, HttpStream): + stream_id = conn.stream_id + else: + # We reach to the end of the connection's child stack to get the HTTP/1 client layer, + # which tells us which stream we are dealing with. + conn = conn.context.layers[-1] + assert isinstance(conn, Http1Connection) + assert conn.stream_id + stream_id = conn.stream_id + yield from self.event_to_child(self.streams[stream_id], event) elif isinstance(event, events.ConnectionEvent): if event.connection == self.context.server and self.context.server not in self.connections: # We didn't do anything with this connection yet, now the peer has closed it - let's close it too! @@ -745,6 +768,8 @@ class HttpLayer(layer.Layer): class HttpClient(layer.Layer): + child_layer: layer.Layer + @expect(events.Start) def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: err: Optional[str] @@ -753,11 +778,10 @@ class HttpClient(layer.Layer): else: err = yield commands.OpenConnection(self.context.server) if not err: - child_layer: layer.Layer if self.context.server.alpn == b"h2": - child_layer = Http2Client(self.context) + self.child_layer = Http2Client(self.context) else: - child_layer = Http1Client(self.context) - self._handle_event = child_layer.handle_event + self.child_layer = Http1Client(self.context) + self._handle_event = self.child_layer.handle_event yield from self._handle_event(event) yield RegisterHttpConnection(self.context.server, err) diff --git a/mitmproxy/proxy/layers/tcp.py b/mitmproxy/proxy/layers/tcp.py index 6b6f65f20..c6c75c583 100644 --- a/mitmproxy/proxy/layers/tcp.py +++ b/mitmproxy/proxy/layers/tcp.py @@ -6,6 +6,7 @@ from mitmproxy.proxy import commands, events, layer from mitmproxy.proxy.commands import StartHook from mitmproxy.connection import ConnectionState, Connection from mitmproxy.proxy.context import Context +from mitmproxy.proxy.events import MessageInjected from mitmproxy.proxy.utils import expect @@ -45,6 +46,12 @@ class TcpErrorHook(StartHook): flow: tcp.TCPFlow +class TcpMessageInjected(MessageInjected[tcp.TCPMessage]): + """ + The user has injected a custom TCP message. + """ + + class TCPLayer(layer.Layer): """ Simple TCP layer that just relays messages right now. @@ -76,8 +83,18 @@ class TCPLayer(layer.Layer): _handle_event = start - @expect(events.DataReceived, events.ConnectionClosed) - def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]: + @expect(events.DataReceived, events.ConnectionClosed, TcpMessageInjected) + def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]: + + if isinstance(event, TcpMessageInjected): + # we just spoof that we received data here and then process that regularly. + event = events.DataReceived( + self.context.client if event.message.from_client else self.context.server, + event.message.content, + ) + + assert isinstance(event, events.ConnectionEvent) + from_client = event.connection == self.context.client send_to: Connection if from_client: @@ -110,7 +127,9 @@ class TCPLayer(layer.Layer): yield TcpEndHook(self.flow) else: yield commands.CloseConnection(send_to, half_close=True) + else: + raise AssertionError(f"Unexpected event: {event}") - @expect(events.DataReceived, events.ConnectionClosed) + @expect(events.DataReceived, events.ConnectionClosed, TcpMessageInjected) def done(self, _) -> layer.CommandGenerator[None]: yield from () diff --git a/mitmproxy/proxy/layers/websocket.py b/mitmproxy/proxy/layers/websocket.py index f83ef14f0..cdf10ab1d 100644 --- a/mitmproxy/proxy/layers/websocket.py +++ b/mitmproxy/proxy/layers/websocket.py @@ -1,3 +1,4 @@ +import time from dataclasses import dataclass from typing import Iterator, List @@ -9,6 +10,7 @@ from mitmproxy import connection, flow, http, websocket from mitmproxy.proxy import commands, events, layer from mitmproxy.proxy.commands import StartHook from mitmproxy.proxy.context import Context +from mitmproxy.proxy.events import MessageInjected from mitmproxy.proxy.utils import expect from wsproto import ConnectionState from wsproto.frame_protocol import CloseReason, Opcode @@ -52,6 +54,12 @@ class WebsocketErrorHook(StartHook): flow: http.HTTPFlow +class WebSocketMessageInjected(MessageInjected[websocket.WebSocketMessage]): + """ + The user has injected a custom WebSocket message. + """ + + class WebsocketConnection(wsproto.Connection): """ A very thin wrapper around wsproto.Connection: @@ -120,11 +128,17 @@ class WebsocketLayer(layer.Layer): _handle_event = start - @expect(events.DataReceived, events.ConnectionClosed) - def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]: + @expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected) + def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]: assert self.flow.websocket # satisfy type checker - from_client = event.connection == self.context.client + if isinstance(event, events.ConnectionEvent): + from_client = event.connection == self.context.client + elif isinstance(event, WebSocketMessageInjected): + from_client = event.message.from_client + else: + raise AssertionError(f"Unexpected event: {event}") + from_str = 'client' if from_client else 'server' if from_client: src_ws = self.client_ws @@ -137,6 +151,11 @@ class WebsocketLayer(layer.Layer): src_ws.receive_data(event.data) elif isinstance(event, events.ConnectionClosed): src_ws.receive_data(None) + elif isinstance(event, WebSocketMessageInjected): + fragmentizer = Fragmentizer([], event.message.type == Opcode.TEXT) + src_ws._events.extend( + fragmentizer(event.message.content) + ) else: # pragma: no cover raise AssertionError(f"Unexpected event: {event}") @@ -171,6 +190,7 @@ class WebsocketLayer(layer.Layer): ) yield dst_ws.send2(ws_event) elif isinstance(ws_event, wsproto.events.CloseConnection): + self.flow.websocket.timestamp_end = time.time() self.flow.websocket.closed_by_client = from_client self.flow.websocket.close_code = ws_event.code self.flow.websocket.close_reason = ws_event.reason @@ -189,7 +209,7 @@ class WebsocketLayer(layer.Layer): else: # pragma: no cover raise AssertionError(f"Unexpected WebSocket event: {ws_event}") - @expect(events.DataReceived, events.ConnectionClosed) + @expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected) def done(self, _) -> layer.CommandGenerator[None]: yield from () @@ -219,7 +239,6 @@ class Fragmentizer: FRAGMENT_SIZE = 4000 def __init__(self, fragments: List[bytes], is_text: bool): - assert fragments self.fragment_lengths = [len(x) for x in fragments] self.is_text = is_text diff --git a/mitmproxy/tools/console/flowview.py b/mitmproxy/tools/console/flowview.py index 4c6a11176..2f876aa32 100644 --- a/mitmproxy/tools/console/flowview.py +++ b/mitmproxy/tools/console/flowview.py @@ -3,19 +3,18 @@ import sys from functools import lru_cache from typing import Optional, Union # noqa -import urwid - import mitmproxy.flow +import mitmproxy.tools.console.master # noqa +import urwid from mitmproxy import contentviews from mitmproxy import ctx from mitmproxy import http from mitmproxy import tcp from mitmproxy.tools.console import common -from mitmproxy.tools.console import layoutwidget from mitmproxy.tools.console import flowdetailview +from mitmproxy.tools.console import layoutwidget from mitmproxy.tools.console import searchable from mitmproxy.tools.console import tabs -import mitmproxy.tools.console.master # noqa from mitmproxy.utils import strutils @@ -26,8 +25,8 @@ class SearchError(Exception): class FlowViewHeader(urwid.WidgetWrap): def __init__( - self, - master: "mitmproxy.tools.console.master.ConsoleMaster", + self, + master: "mitmproxy.tools.console.master.ConsoleMaster", ) -> None: self.master = master self.focus_changed() @@ -140,6 +139,9 @@ class FlowDetails(tabs.Tabs): contentview_status_bar = urwid.AttrWrap(urwid.Columns(cols), "heading") return contentview_status_bar + FROM_CLIENT_MARKER = ("from_client", f"{common.SYMBOL_FROM_CLIENT} ") + TO_CLIENT_MARKER = ("to_client", f"{common.SYMBOL_TO_CLIENT} ") + def view_websocket_messages(self): flow = self.flow assert isinstance(flow, http.HTTPFlow) @@ -156,12 +158,19 @@ class FlowDetails(tabs.Tabs): for line in lines: if m.from_client: - line.insert(0, ("from_client", f"{common.SYMBOL_FROM_CLIENT} ")) + line.insert(0, self.FROM_CLIENT_MARKER) else: - line.insert(0, ("to_client", f"{common.SYMBOL_TO_CLIENT} ")) + line.insert(0, self.TO_CLIENT_MARKER) widget_lines.append(urwid.Text(line)) + if flow.websocket.closed_by_client is not None: + widget_lines.append(urwid.Text([ + (self.FROM_CLIENT_MARKER if flow.websocket.closed_by_client else self.TO_CLIENT_MARKER), + ("alert" if flow.websocket.close_code in (1000, 1001, 1005) else "error", + f"Connection closed: {flow.websocket.close_code} {flow.websocket.close_reason}") + ])) + if flow.intercepted: markup = widget_lines[-1].get_text()[0] widget_lines[-1].set_text(("intercept", markup)) @@ -198,9 +207,9 @@ class FlowDetails(tabs.Tabs): for line in lines: if from_client: - line.insert(0, ("from_client", f"{common.SYMBOL_FROM_CLIENT} ")) + line.insert(0, self.FROM_CLIENT_MARKER) else: - line.insert(0, ("to_client", f"{common.SYMBOL_TO_CLIENT} ")) + line.insert(0, self.TO_CLIENT_MARKER) widget_lines.append(urwid.Text(line)) diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index 841944058..7fd0fcb02 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -106,11 +106,15 @@ class WebSocketData(stateobject.StateObject): close_reason: Optional[str] = None """[Close Reason](https://tools.ietf.org/html/rfc6455#section-7.1.6)""" + timestamp_end: Optional[float] = None + """*Timestamp:* WebSocket connection closed.""" + _stateobject_attributes = dict( messages=List[WebSocketMessage], closed_by_client=bool, close_code=int, close_reason=str, + timestamp_end=float, ) def __init__(self): diff --git a/test/mitmproxy/addons/test_proxyserver.py b/test/mitmproxy/addons/test_proxyserver.py index 333405ec7..0e85054de 100644 --- a/test/mitmproxy/addons/test_proxyserver.py +++ b/test/mitmproxy/addons/test_proxyserver.py @@ -7,7 +7,7 @@ from mitmproxy.addons.proxyserver import Proxyserver from mitmproxy.proxy.layers.http import HTTPMode from mitmproxy.proxy import layers from mitmproxy.connection import Address -from mitmproxy.test import taddons +from mitmproxy.test import taddons, tflow class HelperAddon: @@ -15,12 +15,16 @@ class HelperAddon: self.flows = [] self.layers = [ lambda ctx: layers.modes.HttpProxy(ctx), - lambda ctx: layers.HttpLayer(ctx, HTTPMode.regular) + lambda ctx: layers.HttpLayer(ctx, HTTPMode.regular), + lambda ctx: layers.TCPLayer(ctx), ] def request(self, f): self.flows.append(f) + def tcp_start(self, f): + self.flows.append(f) + def next_layer(self, nl): nl.layer = self.layers.pop(0)(nl.context) @@ -59,6 +63,7 @@ async def test_start_stop(): req = f"GET http://{addr[0]}:{addr[1]}/hello HTTP/1.1\r\n\r\n" writer.write(req.encode()) assert await reader.readuntil(b"\r\n\r\n") == b"HTTP/1.1 204 No Content\r\n\r\n" + assert repr(ps) == "ProxyServer(running, 1 active conns)" tctx.configure(ps, server=False) await tctx.master.await_log("Stopping server", level="info") @@ -67,6 +72,80 @@ async def test_start_stop(): assert state.flows[0].request.path == "/hello" assert state.flows[0].response.status_code == 204 + # Waiting here until everything is really torn down... takes some effort. + conn_handler = list(ps._connections.values())[0] + client_handler = conn_handler.transports[conn_handler.client].handler + writer.close() + await writer.wait_closed() + try: + await client_handler + except asyncio.CancelledError: + pass + for _ in range(5): + # Get all other scheduled coroutines to run. + await asyncio.sleep(0) + assert repr(ps) == "ProxyServer(stopped, 0 active conns)" + + +@pytest.mark.asyncio +async def test_inject(): + async def server_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + while s := await reader.read(1): + writer.write(s.upper()) + + ps = Proxyserver() + with taddons.context(ps) as tctx: + state = HelperAddon() + tctx.master.addons.add(state) + async with tcp_server(server_handler) as addr: + tctx.configure(ps, listen_host="127.0.0.1", listen_port=0) + ps.running() + await tctx.master.await_log("Proxy server listening", level="info") + proxy_addr = ps.server.sockets[0].getsockname()[:2] + reader, writer = await asyncio.open_connection(*proxy_addr) + + req = f"CONNECT {addr[0]}:{addr[1]} HTTP/1.1\r\n\r\n" + writer.write(req.encode()) + assert await reader.readuntil(b"\r\n\r\n") == b"HTTP/1.1 200 Connection established\r\n\r\n" + + writer.write(b"a") + assert await reader.read(1) == b"A" + ps.inject_tcp(state.flows[0], False, "b") + assert await reader.read(1) == b"B" + ps.inject_tcp(state.flows[0], True, "c") + assert await reader.read(1) == b"c" + + +@pytest.mark.asyncio +async def test_inject_fail(): + ps = Proxyserver() + with taddons.context(ps) as tctx: + ps.inject_websocket( + tflow.tflow(), + True, + "test" + ) + await tctx.master.await_log("Cannot inject WebSocket messages into non-WebSocket flows.", level="warn") + ps.inject_tcp( + tflow.tflow(), + True, + "test" + ) + await tctx.master.await_log("Cannot inject TCP messages into non-TCP flows.", level="warn") + + ps.inject_websocket( + tflow.twebsocketflow(), + True, + "test" + ) + await tctx.master.await_log("Flow is not from a live connection.", level="warn") + ps.inject_websocket( + tflow.ttcpflow(), + True, + "test" + ) + await tctx.master.await_log("Flow is not from a live connection.", level="warn") + @pytest.mark.asyncio async def test_warn_no_nextlayer(): diff --git a/test/mitmproxy/proxy/layers/http/test_http.py b/test/mitmproxy/proxy/layers/http/test_http.py index 31f3b530b..7a25ea079 100644 --- a/test/mitmproxy/proxy/layers/http/test_http.py +++ b/test/mitmproxy/proxy/layers/http/test_http.py @@ -9,9 +9,9 @@ from mitmproxy.proxy.commands import CloseConnection, OpenConnection, SendData, from mitmproxy.connection import ConnectionState, Server from mitmproxy.proxy.events import ConnectionClosed, DataReceived from mitmproxy.proxy.layers import TCPLayer, http, tls -from mitmproxy.proxy.layers.tcp import TcpStartHook +from mitmproxy.proxy.layers.tcp import TcpMessageInjected, TcpStartHook from mitmproxy.proxy.layers.websocket import WebsocketStartHook -from mitmproxy.tcp import TCPFlow +from mitmproxy.tcp import TCPFlow, TCPMessage from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply, reply_next_layer @@ -545,6 +545,7 @@ def test_upstream_proxy(tctx, redirect, scheme): def test_http_proxy_tcp(tctx, mode, close_first): """Test TCP over HTTP CONNECT.""" server = Placeholder(Server) + f = Placeholder(TCPFlow) if mode == "upstream": tctx.options.mode = "upstream:http://proxy:8080" @@ -560,7 +561,9 @@ def test_http_proxy_tcp(tctx, mode, close_first): << SendData(tctx.client, b"HTTP/1.1 200 Connection established\r\n\r\n") >> DataReceived(tctx.client, b"this is not http") << layer.NextLayerHook(Placeholder()) - >> reply_next_layer(lambda ctx: TCPLayer(ctx, ignore=True)) + >> reply_next_layer(lambda ctx: TCPLayer(ctx, ignore=False)) + << TcpStartHook(f) + >> reply() << OpenConnection(server) ) @@ -581,6 +584,12 @@ def test_http_proxy_tcp(tctx, mode, close_first): else: assert server().address == ("proxy", 8080) + assert ( + playbook + >> TcpMessageInjected(f, TCPMessage(False, b"fake news from your friendly man-in-the-middle")) + << SendData(tctx.client, b"fake news from your friendly man-in-the-middle") + ) + if close_first == "client": a, b = tctx.client, server else: diff --git a/test/mitmproxy/proxy/layers/test_tcp.py b/test/mitmproxy/proxy/layers/test_tcp.py index f899c637f..0c46e65b2 100644 --- a/test/mitmproxy/proxy/layers/test_tcp.py +++ b/test/mitmproxy/proxy/layers/test_tcp.py @@ -4,7 +4,8 @@ from mitmproxy.proxy.commands import CloseConnection, OpenConnection, SendData from mitmproxy.connection import ConnectionState from mitmproxy.proxy.events import ConnectionClosed, DataReceived from mitmproxy.proxy.layers import tcp -from mitmproxy.tcp import TCPFlow +from mitmproxy.proxy.layers.tcp import TcpMessageInjected +from mitmproxy.tcp import TCPFlow, TCPMessage from ..tutils import Placeholder, Playbook, reply @@ -122,3 +123,27 @@ def test_ignore(tctx, ignore): else: with pytest.raises(AssertionError): no_flow_hooks() + + +def test_inject(tctx): + """inject data into an open connection.""" + f = Placeholder(TCPFlow) + + assert ( + Playbook(tcp.TCPLayer(tctx)) + << tcp.TcpStartHook(f) + >> TcpMessageInjected(f, TCPMessage(True, b"hello!")) + >> reply(to=-2) + << OpenConnection(tctx.server) + >> reply(None) + << tcp.TcpMessageHook(f) + >> reply() + << SendData(tctx.server, b"hello!") + # and the other way... + >> TcpMessageInjected(f, TCPMessage(False, b"I have already done the greeting for you.")) + << tcp.TcpMessageHook(f) + >> reply() + << SendData(tctx.client, b"I have already done the greeting for you.") + << None + ) + assert len(f().messages) == 2 diff --git a/test/mitmproxy/proxy/layers/test_websocket.py b/test/mitmproxy/proxy/layers/test_websocket.py index 5d372a2a3..7e3ef5cd8 100644 --- a/test/mitmproxy/proxy/layers/test_websocket.py +++ b/test/mitmproxy/proxy/layers/test_websocket.py @@ -11,7 +11,8 @@ from mitmproxy.proxy.commands import SendData, CloseConnection, Log from mitmproxy.connection import ConnectionState from mitmproxy.proxy.events import DataReceived, ConnectionClosed from mitmproxy.proxy.layers import http, websocket -from mitmproxy.websocket import WebSocketData +from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected +from mitmproxy.websocket import WebSocketData, WebSocketMessage from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply from wsproto.frame_protocol import Opcode @@ -319,3 +320,21 @@ class TestFragmentizer: wsproto.events.BytesMessage(b"foob", message_finished=False), wsproto.events.BytesMessage(b"ar", message_finished=True), ] + + +def test_inject_message(ws_testdata): + tctx, playbook, flow = ws_testdata + assert ( + playbook + << websocket.WebsocketStartHook(flow) + >> reply() + >> WebSocketMessageInjected(flow, WebSocketMessage(Opcode.TEXT, False, b"hello")) + << websocket.WebsocketMessageHook(flow) + ) + assert flow.websocket.messages[-1].content == b"hello" + assert flow.websocket.messages[-1].from_client is False + assert ( + playbook + >> reply() + << SendData(tctx.client, b"\x81\x05hello") + ) diff --git a/test/mitmproxy/proxy/test_layer.py b/test/mitmproxy/proxy/test_layer.py index a590ecfaf..c9dbbfe90 100644 --- a/test/mitmproxy/proxy/test_layer.py +++ b/test/mitmproxy/proxy/test_layer.py @@ -1,11 +1,26 @@ import pytest from mitmproxy.proxy import commands, events, layer +from mitmproxy.proxy.context import Context from test.mitmproxy.proxy import tutils class TestLayer: - def test_debug_messages(self, tctx): + def test_continue(self, tctx: Context): + class TLayer(layer.Layer): + def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]: + yield commands.OpenConnection(self.context.server) + yield commands.OpenConnection(self.context.server) + + assert ( + tutils.Playbook(TLayer(tctx)) + << commands.OpenConnection(tctx.server) + >> tutils.reply(None) + << commands.OpenConnection(tctx.server) + >> tutils.reply(None) + ) + + def test_debug_messages(self, tctx: Context): tctx.server.id = "serverid" class TLayer(layer.Layer): @@ -53,7 +68,7 @@ class TestLayer: class TestNextLayer: - def test_simple(self, tctx): + def test_simple(self, tctx: Context): nl = layer.NextLayer(tctx, ask_on_start=True) nl.debug = " " playbook = tutils.Playbook(nl, hooks=True) @@ -79,7 +94,7 @@ class TestNextLayer: << commands.SendData(tctx.client, b"bar") ) - def test_late_hook_reply(self, tctx): + def test_late_hook_reply(self, tctx: Context): """ Properly handle case where we receive an additional event while we are waiting for a reply from the proxy core. @@ -104,7 +119,7 @@ class TestNextLayer: ) @pytest.mark.parametrize("layer_found", [True, False]) - def test_receive_close(self, tctx, layer_found): + def test_receive_close(self, tctx: Context, layer_found: bool): """Test that we abort a client connection which has disconnected without any layer being found.""" nl = layer.NextLayer(tctx) playbook = tutils.Playbook(nl) @@ -128,7 +143,7 @@ class TestNextLayer: << commands.CloseConnection(tctx.client) ) - def test_func_references(self, tctx): + def test_func_references(self, tctx: Context): nl = layer.NextLayer(tctx) playbook = tutils.Playbook(nl) @@ -147,7 +162,9 @@ class TestNextLayer: sd, = handle(events.DataReceived(tctx.client, b"bar")) assert isinstance(sd, commands.SendData) - def test_repr(self, tctx): + def test_repr(self, tctx: Context): nl = layer.NextLayer(tctx) nl.layer = tutils.EchoLayer(tctx) assert repr(nl) + assert nl.stack_pos + assert nl.layer.stack_pos