Merge pull request #4502 from mhils/inject

Add WebSocket/TCP Message Injection
This commit is contained in:
Maximilian Hils 2021-03-16 15:07:06 +01:00 committed by GitHub
commit 0650f132e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 449 additions and 60 deletions

View File

@ -0,0 +1,32 @@
"""
Inject a WebSocket message into a running connection.
This example shows how to inject a WebSocket message into a running connection.
"""
import asyncio
from mitmproxy import ctx, http
# Simple example: Inject a message as a response to an event
def websocket_message(flow):
last_message = flow.websocket.messages[-1]
if b"secret" in last_message.content:
last_message.kill()
ctx.master.commands.call("inject.websocket", flow, last_message.from_client, "ssssssh")
# Complex example: Schedule a periodic timer
async def inject_async(flow: http.HTTPFlow):
msg = "hello from mitmproxy! "
assert flow.websocket # make type checker happy
while flow.websocket.timestamp_end is None:
ctx.master.commands.call("inject.websocket", flow, True, msg)
await asyncio.sleep(1)
msg = msg[1:] + msg[:1]
def websocket_start(flow: http.HTTPFlow):
asyncio.create_task(inject_async(flow))

View File

@ -1,12 +1,15 @@
import asyncio
import warnings
from typing import Optional
from typing import Dict, Optional, Tuple
from mitmproxy import controller, ctx, flow, log, master, options, platform
from mitmproxy.flow import Error
from mitmproxy.proxy import commands
from mitmproxy import command, controller, ctx, flow, http, log, master, options, platform, tcp, websocket
from mitmproxy.flow import Error, Flow
from mitmproxy.proxy import commands, events
from mitmproxy.proxy import server
from mitmproxy.utils import asyncio_utils, human
from mitmproxy.proxy.layers.tcp import TcpMessageInjected
from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected
from mitmproxy.utils import asyncio_utils, human, strutils
from wsproto.frame_protocol import Opcode
class AsyncReply(controller.Reply):
@ -67,11 +70,16 @@ class Proxyserver:
master: master.Master
options: options.Options
is_running: bool
_connections: Dict[Tuple, ProxyConnectionHandler]
def __init__(self):
self._lock = asyncio.Lock()
self.server = None
self.is_running = False
self._connections = {}
def __repr__(self):
return f"ProxyServer({'running' if self.server else 'stopped'}, {len(self._connections)} active conns)"
def load(self, loader):
loader.add_option(
@ -121,10 +129,11 @@ class Proxyserver:
self.server = None
async def handle_connection(self, r, w):
peername = w.get_extra_info('peername')
asyncio_utils.set_task_debug_info(
asyncio.current_task(),
name=f"Proxyserver.handle_connection",
client=w.get_extra_info('peername'),
client=peername,
)
handler = ProxyConnectionHandler(
self.master,
@ -132,4 +141,42 @@ class Proxyserver:
w,
self.options
)
await handler.handle_client()
self._connections[peername] = handler
try:
await handler.handle_client()
finally:
del self._connections[peername]
def inject_event(self, event: events.MessageInjected):
if event.flow.client_conn.peername not in self._connections:
raise ValueError("Flow is not from a live connection.")
self._connections[event.flow.client_conn.peername].server_event(event)
@command.command("inject.websocket")
def inject_websocket(self, flow: Flow, to_client: bool, message: str, is_text: bool = True):
if not isinstance(flow, http.HTTPFlow) or not flow.websocket:
ctx.log.warn("Cannot inject WebSocket messages into non-WebSocket flows.")
message_bytes = strutils.escaped_str_to_bytes(message)
msg = websocket.WebSocketMessage(
Opcode.TEXT if is_text else Opcode.BINARY,
not to_client,
message_bytes
)
event = WebSocketMessageInjected(flow, msg)
try:
self.inject_event(event)
except ValueError as e:
ctx.log.warn(str(e))
@command.command("inject.tcp")
def inject_tcp(self, flow: Flow, to_client: bool, message: str):
if not isinstance(flow, tcp.TCPFlow):
ctx.log.warn("Cannot inject TCP messages into non-TCP flows.")
message_bytes = strutils.escaped_str_to_bytes(message)
event = TcpMessageInjected(flow, tcp.TCPMessage(not to_client, message_bytes))
try:
self.inject_event(event)
except ValueError as e:
ctx.log.warn(str(e))

View File

@ -291,6 +291,7 @@ def convert_11_12(data):
"closed_by_client": ws_flow["close_sender"] == "client",
"close_code": ws_flow["close_code"],
"close_reason": ws_flow["close_reason"],
"timestamp_end": data.get("server_conn", {}).get("timestamp_end", None),
}
else:

View File

@ -23,6 +23,9 @@ class Command:
blocking: Union[bool, "mitmproxy.proxy.layer.Layer"] = False
"""
Determines if the command blocks until it has been completed.
For practical purposes, this attribute should be thought of as a boolean value,
layers may swap out `True` with a reference to themselves to signal to outer layers
that they do not need to block as well.
Example:

View File

@ -8,6 +8,7 @@ import typing
import warnings
from dataclasses import dataclass, is_dataclass
from mitmproxy import flow
from mitmproxy.proxy import commands
from mitmproxy.connection import Connection
@ -106,3 +107,15 @@ class HookCompleted(CommandCompleted):
class GetSocketCompleted(CommandCompleted):
command: commands.GetSocket
reply: socket.socket
T = typing.TypeVar('T')
@dataclass
class MessageInjected(Event, typing.Generic[T]):
"""
The user has injected a custom WebSocket/TCP/... message.
"""
flow: flow.Flow
message: T

View File

@ -33,21 +33,32 @@ class Layer:
Layers interface with their child layer(s) by calling .handle_event(event),
which returns a list (more precisely: a generator) of commands.
Most layers only implement ._handle_event, which is called by the default implementation of .handle_event.
The default implementation allows layers to emulate blocking code:
Most layers do not implement .directly, but instead implement ._handle_event, which
is called by the default implementation of .handle_event.
The default implementation of .handle_event allows layers to emulate blocking code:
When ._handle_event yields a command that has its blocking attribute set to True, .handle_event pauses
the execution of ._handle_event and waits until it is called with the corresponding CommandCompleted event. All
events encountered in the meantime are buffered and replayed after execution is resumed.
the execution of ._handle_event and waits until it is called with the corresponding CommandCompleted event.
All events encountered in the meantime are buffered and replayed after execution is resumed.
The result is code that looks like blocking code, but is not blocking:
def _handle_event(self, event):
err = yield OpenConnection(server) # execution continues here after a connection has been established.
Technically this is very similar to how coroutines are implemented.
"""
__last_debug_message: ClassVar[str] = ""
context: Context
_paused: Optional[Paused]
"""
If execution is currently paused, this attribute stores the paused coroutine
and the command for which we are expecting a reply.
"""
_paused_event_queue: Deque[events.Event]
"""
All events that have occurred since execution was paused.
These will be replayed to ._child_layer once we resume.
"""
debug: Optional[str] = None
"""
Enable debug logging by assigning a prefix string for log messages.
@ -75,6 +86,7 @@ class Layer:
return f"{type(self).__name__}({state})"
def __debug(self, message):
"""yield a Log command indicating what message is passing through this layer."""
if len(message) > 512:
message = message[:512] + ""
if Layer.__last_debug_message == message:
@ -88,6 +100,16 @@ class Layer:
"debug"
)
@property
def stack_pos(self) -> str:
"""repr() for this layer and all its parent layers, only useful for debugging."""
try:
idx = self.context.layers.index(self)
except ValueError:
return repr(self)
else:
return " >> ".join(repr(x) for x in self.context.layers[:idx + 1])
@abstractmethod
def _handle_event(self, event: events.Event) -> CommandGenerator[None]:
"""Handle a proxy server event"""
@ -111,15 +133,53 @@ class Layer:
if self.debug is not None:
yield self.__debug(f">> {event}")
command_generator = self._handle_event(event)
yield from self.__process(command_generator)
send = None
# inlined copy of __process to reduce call stack.
# <✂✂✂>
try:
# Run ._handle_event to the next yield statement.
# If you are not familiar with generators and their .send() method,
# https://stackoverflow.com/a/12638313/934719 has a good explanation.
command = command_generator.send(send)
except StopIteration:
return
while True:
if self.debug is not None:
if not isinstance(command, commands.Log):
yield self.__debug(f"<< {command}")
if command.blocking is True:
# We only want this layer to block, the outer layers should not block.
# For example, take an HTTP/2 connection: If we intercept one particular request,
# we don't want all other requests in the connection to be blocked a well.
# We signal to outer layers that this command is already handled by assigning our layer to
# `.blocking` here (upper layers explicitly check for `is True`).
command.blocking = self
self._paused = Paused(
command,
command_generator,
)
yield command
return
else:
yield command
try:
command = next(command_generator)
except StopIteration:
return
# </✂✂✂>
def __process(self, command_generator: CommandGenerator, send=None):
"""
yield all commands from a generator.
if a command is blocking, the layer is paused and this function returns before
processing any other commands.
Yield commands from a generator.
If a command is blocking, execution is paused and this function returns without
processing any further commands.
"""
try:
# Run ._handle_event to the next yield statement.
# If you are not familiar with generators and their .send() method,
# https://stackoverflow.com/a/12638313/934719 has a good explanation.
command = command_generator.send(send)
except StopIteration:
return
@ -129,7 +189,12 @@ class Layer:
if not isinstance(command, commands.Log):
yield self.__debug(f"<< {command}")
if command.blocking is True:
command.blocking = self # assign to our layer so that higher layers don't block.
# We only want this layer to block, the outer layers should not block.
# For example, take an HTTP/2 connection: If we intercept one particular request,
# we don't want all other requests in the connection to be blocked a well.
# We signal to outer layers that this command is already handled by assigning our layer to
# `.blocking` here (upper layers explicitly check for `is True`).
command.blocking = self
self._paused = Paused(
command,
command_generator,
@ -144,7 +209,11 @@ class Layer:
return
def __continue(self, event: events.CommandCompleted):
"""continue processing events after being paused"""
"""
Continue processing events after being paused.
The tricky part here is that events in the event queue may trigger commands which again pause the execution,
so we may not be able to process the entire queue.
"""
assert self._paused is not None
command_generator = self._paused.generator
self._paused = None

View File

@ -20,7 +20,7 @@ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders
ResponseData, ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError, ResponseTrailers
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
HttpResponseHook
from ._http1 import Http1Client, Http1Server
from ._http1 import Http1Client, Http1Connection, Http1Server
from ._http2 import Http2Client, Http2Server
from ...context import Context
@ -114,13 +114,16 @@ class HttpStream(layer.Layer):
self.stream_id = stream_id
def __repr__(self):
return (
f"HttpStream("
f"id={self.stream_id}, "
f"client_state={self.client_state.__name__}, "
f"server_state={self.server_state.__name__}"
f")"
)
if self._handle_event == self.passthrough:
return f"HttpStream(id={self.stream_id}, passthrough)"
else:
return (
f"HttpStream("
f"id={self.stream_id}, "
f"client_state={self.client_state.__name__}, "
f"server_state={self.server_state.__name__}"
f")"
)
@expect(events.Start, HttpEvent)
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
@ -608,6 +611,26 @@ class HttpLayer(layer.Layer):
elif isinstance(event, events.CommandCompleted):
stream = self.command_sources.pop(event.command)
yield from self.event_to_child(stream, event)
elif isinstance(event, events.MessageInjected):
# For injected messages we pass the HTTP stacks entirely and directly address the stream.
try:
conn = self.connections[event.flow.server_conn]
except KeyError:
# We have a miss for the server connection, which means we're looking at a connection object
# that is tunneled over another connection (for example: over an upstream HTTP proxy).
# We now take the stream associated with the client connection. That won't work for HTTP/2,
# but it's good enough for HTTP/1.
conn = self.connections[event.flow.client_conn]
if isinstance(conn, HttpStream):
stream_id = conn.stream_id
else:
# We reach to the end of the connection's child stack to get the HTTP/1 client layer,
# which tells us which stream we are dealing with.
conn = conn.context.layers[-1]
assert isinstance(conn, Http1Connection)
assert conn.stream_id
stream_id = conn.stream_id
yield from self.event_to_child(self.streams[stream_id], event)
elif isinstance(event, events.ConnectionEvent):
if event.connection == self.context.server and self.context.server not in self.connections:
# We didn't do anything with this connection yet, now the peer has closed it - let's close it too!
@ -745,6 +768,8 @@ class HttpLayer(layer.Layer):
class HttpClient(layer.Layer):
child_layer: layer.Layer
@expect(events.Start)
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
err: Optional[str]
@ -753,11 +778,10 @@ class HttpClient(layer.Layer):
else:
err = yield commands.OpenConnection(self.context.server)
if not err:
child_layer: layer.Layer
if self.context.server.alpn == b"h2":
child_layer = Http2Client(self.context)
self.child_layer = Http2Client(self.context)
else:
child_layer = Http1Client(self.context)
self._handle_event = child_layer.handle_event
self.child_layer = Http1Client(self.context)
self._handle_event = self.child_layer.handle_event
yield from self._handle_event(event)
yield RegisterHttpConnection(self.context.server, err)

View File

@ -6,6 +6,7 @@ from mitmproxy.proxy import commands, events, layer
from mitmproxy.proxy.commands import StartHook
from mitmproxy.connection import ConnectionState, Connection
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.events import MessageInjected
from mitmproxy.proxy.utils import expect
@ -45,6 +46,12 @@ class TcpErrorHook(StartHook):
flow: tcp.TCPFlow
class TcpMessageInjected(MessageInjected[tcp.TCPMessage]):
"""
The user has injected a custom TCP message.
"""
class TCPLayer(layer.Layer):
"""
Simple TCP layer that just relays messages right now.
@ -76,8 +83,18 @@ class TCPLayer(layer.Layer):
_handle_event = start
@expect(events.DataReceived, events.ConnectionClosed)
def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
@expect(events.DataReceived, events.ConnectionClosed, TcpMessageInjected)
def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, TcpMessageInjected):
# we just spoof that we received data here and then process that regularly.
event = events.DataReceived(
self.context.client if event.message.from_client else self.context.server,
event.message.content,
)
assert isinstance(event, events.ConnectionEvent)
from_client = event.connection == self.context.client
send_to: Connection
if from_client:
@ -110,7 +127,9 @@ class TCPLayer(layer.Layer):
yield TcpEndHook(self.flow)
else:
yield commands.CloseConnection(send_to, half_close=True)
else:
raise AssertionError(f"Unexpected event: {event}")
@expect(events.DataReceived, events.ConnectionClosed)
@expect(events.DataReceived, events.ConnectionClosed, TcpMessageInjected)
def done(self, _) -> layer.CommandGenerator[None]:
yield from ()

View File

@ -1,3 +1,4 @@
import time
from dataclasses import dataclass
from typing import Iterator, List
@ -9,6 +10,7 @@ from mitmproxy import connection, flow, http, websocket
from mitmproxy.proxy import commands, events, layer
from mitmproxy.proxy.commands import StartHook
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.events import MessageInjected
from mitmproxy.proxy.utils import expect
from wsproto import ConnectionState
from wsproto.frame_protocol import CloseReason, Opcode
@ -52,6 +54,12 @@ class WebsocketErrorHook(StartHook):
flow: http.HTTPFlow
class WebSocketMessageInjected(MessageInjected[websocket.WebSocketMessage]):
"""
The user has injected a custom WebSocket message.
"""
class WebsocketConnection(wsproto.Connection):
"""
A very thin wrapper around wsproto.Connection:
@ -120,11 +128,17 @@ class WebsocketLayer(layer.Layer):
_handle_event = start
@expect(events.DataReceived, events.ConnectionClosed)
def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
@expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
assert self.flow.websocket # satisfy type checker
from_client = event.connection == self.context.client
if isinstance(event, events.ConnectionEvent):
from_client = event.connection == self.context.client
elif isinstance(event, WebSocketMessageInjected):
from_client = event.message.from_client
else:
raise AssertionError(f"Unexpected event: {event}")
from_str = 'client' if from_client else 'server'
if from_client:
src_ws = self.client_ws
@ -137,6 +151,11 @@ class WebsocketLayer(layer.Layer):
src_ws.receive_data(event.data)
elif isinstance(event, events.ConnectionClosed):
src_ws.receive_data(None)
elif isinstance(event, WebSocketMessageInjected):
fragmentizer = Fragmentizer([], event.message.type == Opcode.TEXT)
src_ws._events.extend(
fragmentizer(event.message.content)
)
else: # pragma: no cover
raise AssertionError(f"Unexpected event: {event}")
@ -171,6 +190,7 @@ class WebsocketLayer(layer.Layer):
)
yield dst_ws.send2(ws_event)
elif isinstance(ws_event, wsproto.events.CloseConnection):
self.flow.websocket.timestamp_end = time.time()
self.flow.websocket.closed_by_client = from_client
self.flow.websocket.close_code = ws_event.code
self.flow.websocket.close_reason = ws_event.reason
@ -189,7 +209,7 @@ class WebsocketLayer(layer.Layer):
else: # pragma: no cover
raise AssertionError(f"Unexpected WebSocket event: {ws_event}")
@expect(events.DataReceived, events.ConnectionClosed)
@expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
def done(self, _) -> layer.CommandGenerator[None]:
yield from ()
@ -219,7 +239,6 @@ class Fragmentizer:
FRAGMENT_SIZE = 4000
def __init__(self, fragments: List[bytes], is_text: bool):
assert fragments
self.fragment_lengths = [len(x) for x in fragments]
self.is_text = is_text

View File

@ -3,19 +3,18 @@ import sys
from functools import lru_cache
from typing import Optional, Union # noqa
import urwid
import mitmproxy.flow
import mitmproxy.tools.console.master # noqa
import urwid
from mitmproxy import contentviews
from mitmproxy import ctx
from mitmproxy import http
from mitmproxy import tcp
from mitmproxy.tools.console import common
from mitmproxy.tools.console import layoutwidget
from mitmproxy.tools.console import flowdetailview
from mitmproxy.tools.console import layoutwidget
from mitmproxy.tools.console import searchable
from mitmproxy.tools.console import tabs
import mitmproxy.tools.console.master # noqa
from mitmproxy.utils import strutils
@ -26,8 +25,8 @@ class SearchError(Exception):
class FlowViewHeader(urwid.WidgetWrap):
def __init__(
self,
master: "mitmproxy.tools.console.master.ConsoleMaster",
self,
master: "mitmproxy.tools.console.master.ConsoleMaster",
) -> None:
self.master = master
self.focus_changed()
@ -140,6 +139,9 @@ class FlowDetails(tabs.Tabs):
contentview_status_bar = urwid.AttrWrap(urwid.Columns(cols), "heading")
return contentview_status_bar
FROM_CLIENT_MARKER = ("from_client", f"{common.SYMBOL_FROM_CLIENT} ")
TO_CLIENT_MARKER = ("to_client", f"{common.SYMBOL_TO_CLIENT} ")
def view_websocket_messages(self):
flow = self.flow
assert isinstance(flow, http.HTTPFlow)
@ -156,12 +158,19 @@ class FlowDetails(tabs.Tabs):
for line in lines:
if m.from_client:
line.insert(0, ("from_client", f"{common.SYMBOL_FROM_CLIENT} "))
line.insert(0, self.FROM_CLIENT_MARKER)
else:
line.insert(0, ("to_client", f"{common.SYMBOL_TO_CLIENT} "))
line.insert(0, self.TO_CLIENT_MARKER)
widget_lines.append(urwid.Text(line))
if flow.websocket.closed_by_client is not None:
widget_lines.append(urwid.Text([
(self.FROM_CLIENT_MARKER if flow.websocket.closed_by_client else self.TO_CLIENT_MARKER),
("alert" if flow.websocket.close_code in (1000, 1001, 1005) else "error",
f"Connection closed: {flow.websocket.close_code} {flow.websocket.close_reason}")
]))
if flow.intercepted:
markup = widget_lines[-1].get_text()[0]
widget_lines[-1].set_text(("intercept", markup))
@ -198,9 +207,9 @@ class FlowDetails(tabs.Tabs):
for line in lines:
if from_client:
line.insert(0, ("from_client", f"{common.SYMBOL_FROM_CLIENT} "))
line.insert(0, self.FROM_CLIENT_MARKER)
else:
line.insert(0, ("to_client", f"{common.SYMBOL_TO_CLIENT} "))
line.insert(0, self.TO_CLIENT_MARKER)
widget_lines.append(urwid.Text(line))

View File

@ -106,11 +106,15 @@ class WebSocketData(stateobject.StateObject):
close_reason: Optional[str] = None
"""[Close Reason](https://tools.ietf.org/html/rfc6455#section-7.1.6)"""
timestamp_end: Optional[float] = None
"""*Timestamp:* WebSocket connection closed."""
_stateobject_attributes = dict(
messages=List[WebSocketMessage],
closed_by_client=bool,
close_code=int,
close_reason=str,
timestamp_end=float,
)
def __init__(self):

View File

@ -7,7 +7,7 @@ from mitmproxy.addons.proxyserver import Proxyserver
from mitmproxy.proxy.layers.http import HTTPMode
from mitmproxy.proxy import layers
from mitmproxy.connection import Address
from mitmproxy.test import taddons
from mitmproxy.test import taddons, tflow
class HelperAddon:
@ -15,12 +15,16 @@ class HelperAddon:
self.flows = []
self.layers = [
lambda ctx: layers.modes.HttpProxy(ctx),
lambda ctx: layers.HttpLayer(ctx, HTTPMode.regular)
lambda ctx: layers.HttpLayer(ctx, HTTPMode.regular),
lambda ctx: layers.TCPLayer(ctx),
]
def request(self, f):
self.flows.append(f)
def tcp_start(self, f):
self.flows.append(f)
def next_layer(self, nl):
nl.layer = self.layers.pop(0)(nl.context)
@ -59,6 +63,7 @@ async def test_start_stop():
req = f"GET http://{addr[0]}:{addr[1]}/hello HTTP/1.1\r\n\r\n"
writer.write(req.encode())
assert await reader.readuntil(b"\r\n\r\n") == b"HTTP/1.1 204 No Content\r\n\r\n"
assert repr(ps) == "ProxyServer(running, 1 active conns)"
tctx.configure(ps, server=False)
await tctx.master.await_log("Stopping server", level="info")
@ -67,6 +72,80 @@ async def test_start_stop():
assert state.flows[0].request.path == "/hello"
assert state.flows[0].response.status_code == 204
# Waiting here until everything is really torn down... takes some effort.
conn_handler = list(ps._connections.values())[0]
client_handler = conn_handler.transports[conn_handler.client].handler
writer.close()
await writer.wait_closed()
try:
await client_handler
except asyncio.CancelledError:
pass
for _ in range(5):
# Get all other scheduled coroutines to run.
await asyncio.sleep(0)
assert repr(ps) == "ProxyServer(stopped, 0 active conns)"
@pytest.mark.asyncio
async def test_inject():
async def server_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
while s := await reader.read(1):
writer.write(s.upper())
ps = Proxyserver()
with taddons.context(ps) as tctx:
state = HelperAddon()
tctx.master.addons.add(state)
async with tcp_server(server_handler) as addr:
tctx.configure(ps, listen_host="127.0.0.1", listen_port=0)
ps.running()
await tctx.master.await_log("Proxy server listening", level="info")
proxy_addr = ps.server.sockets[0].getsockname()[:2]
reader, writer = await asyncio.open_connection(*proxy_addr)
req = f"CONNECT {addr[0]}:{addr[1]} HTTP/1.1\r\n\r\n"
writer.write(req.encode())
assert await reader.readuntil(b"\r\n\r\n") == b"HTTP/1.1 200 Connection established\r\n\r\n"
writer.write(b"a")
assert await reader.read(1) == b"A"
ps.inject_tcp(state.flows[0], False, "b")
assert await reader.read(1) == b"B"
ps.inject_tcp(state.flows[0], True, "c")
assert await reader.read(1) == b"c"
@pytest.mark.asyncio
async def test_inject_fail():
ps = Proxyserver()
with taddons.context(ps) as tctx:
ps.inject_websocket(
tflow.tflow(),
True,
"test"
)
await tctx.master.await_log("Cannot inject WebSocket messages into non-WebSocket flows.", level="warn")
ps.inject_tcp(
tflow.tflow(),
True,
"test"
)
await tctx.master.await_log("Cannot inject TCP messages into non-TCP flows.", level="warn")
ps.inject_websocket(
tflow.twebsocketflow(),
True,
"test"
)
await tctx.master.await_log("Flow is not from a live connection.", level="warn")
ps.inject_websocket(
tflow.ttcpflow(),
True,
"test"
)
await tctx.master.await_log("Flow is not from a live connection.", level="warn")
@pytest.mark.asyncio
async def test_warn_no_nextlayer():

View File

@ -9,9 +9,9 @@ from mitmproxy.proxy.commands import CloseConnection, OpenConnection, SendData,
from mitmproxy.connection import ConnectionState, Server
from mitmproxy.proxy.events import ConnectionClosed, DataReceived
from mitmproxy.proxy.layers import TCPLayer, http, tls
from mitmproxy.proxy.layers.tcp import TcpStartHook
from mitmproxy.proxy.layers.tcp import TcpMessageInjected, TcpStartHook
from mitmproxy.proxy.layers.websocket import WebsocketStartHook
from mitmproxy.tcp import TCPFlow
from mitmproxy.tcp import TCPFlow, TCPMessage
from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply, reply_next_layer
@ -545,6 +545,7 @@ def test_upstream_proxy(tctx, redirect, scheme):
def test_http_proxy_tcp(tctx, mode, close_first):
"""Test TCP over HTTP CONNECT."""
server = Placeholder(Server)
f = Placeholder(TCPFlow)
if mode == "upstream":
tctx.options.mode = "upstream:http://proxy:8080"
@ -560,7 +561,9 @@ def test_http_proxy_tcp(tctx, mode, close_first):
<< SendData(tctx.client, b"HTTP/1.1 200 Connection established\r\n\r\n")
>> DataReceived(tctx.client, b"this is not http")
<< layer.NextLayerHook(Placeholder())
>> reply_next_layer(lambda ctx: TCPLayer(ctx, ignore=True))
>> reply_next_layer(lambda ctx: TCPLayer(ctx, ignore=False))
<< TcpStartHook(f)
>> reply()
<< OpenConnection(server)
)
@ -581,6 +584,12 @@ def test_http_proxy_tcp(tctx, mode, close_first):
else:
assert server().address == ("proxy", 8080)
assert (
playbook
>> TcpMessageInjected(f, TCPMessage(False, b"fake news from your friendly man-in-the-middle"))
<< SendData(tctx.client, b"fake news from your friendly man-in-the-middle")
)
if close_first == "client":
a, b = tctx.client, server
else:

View File

@ -4,7 +4,8 @@ from mitmproxy.proxy.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.connection import ConnectionState
from mitmproxy.proxy.events import ConnectionClosed, DataReceived
from mitmproxy.proxy.layers import tcp
from mitmproxy.tcp import TCPFlow
from mitmproxy.proxy.layers.tcp import TcpMessageInjected
from mitmproxy.tcp import TCPFlow, TCPMessage
from ..tutils import Placeholder, Playbook, reply
@ -122,3 +123,27 @@ def test_ignore(tctx, ignore):
else:
with pytest.raises(AssertionError):
no_flow_hooks()
def test_inject(tctx):
"""inject data into an open connection."""
f = Placeholder(TCPFlow)
assert (
Playbook(tcp.TCPLayer(tctx))
<< tcp.TcpStartHook(f)
>> TcpMessageInjected(f, TCPMessage(True, b"hello!"))
>> reply(to=-2)
<< OpenConnection(tctx.server)
>> reply(None)
<< tcp.TcpMessageHook(f)
>> reply()
<< SendData(tctx.server, b"hello!")
# and the other way...
>> TcpMessageInjected(f, TCPMessage(False, b"I have already done the greeting for you."))
<< tcp.TcpMessageHook(f)
>> reply()
<< SendData(tctx.client, b"I have already done the greeting for you.")
<< None
)
assert len(f().messages) == 2

View File

@ -11,7 +11,8 @@ from mitmproxy.proxy.commands import SendData, CloseConnection, Log
from mitmproxy.connection import ConnectionState
from mitmproxy.proxy.events import DataReceived, ConnectionClosed
from mitmproxy.proxy.layers import http, websocket
from mitmproxy.websocket import WebSocketData
from mitmproxy.proxy.layers.websocket import WebSocketMessageInjected
from mitmproxy.websocket import WebSocketData, WebSocketMessage
from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply
from wsproto.frame_protocol import Opcode
@ -319,3 +320,21 @@ class TestFragmentizer:
wsproto.events.BytesMessage(b"foob", message_finished=False),
wsproto.events.BytesMessage(b"ar", message_finished=True),
]
def test_inject_message(ws_testdata):
tctx, playbook, flow = ws_testdata
assert (
playbook
<< websocket.WebsocketStartHook(flow)
>> reply()
>> WebSocketMessageInjected(flow, WebSocketMessage(Opcode.TEXT, False, b"hello"))
<< websocket.WebsocketMessageHook(flow)
)
assert flow.websocket.messages[-1].content == b"hello"
assert flow.websocket.messages[-1].from_client is False
assert (
playbook
>> reply()
<< SendData(tctx.client, b"\x81\x05hello")
)

View File

@ -1,11 +1,26 @@
import pytest
from mitmproxy.proxy import commands, events, layer
from mitmproxy.proxy.context import Context
from test.mitmproxy.proxy import tutils
class TestLayer:
def test_debug_messages(self, tctx):
def test_continue(self, tctx: Context):
class TLayer(layer.Layer):
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
yield commands.OpenConnection(self.context.server)
yield commands.OpenConnection(self.context.server)
assert (
tutils.Playbook(TLayer(tctx))
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
<< commands.OpenConnection(tctx.server)
>> tutils.reply(None)
)
def test_debug_messages(self, tctx: Context):
tctx.server.id = "serverid"
class TLayer(layer.Layer):
@ -53,7 +68,7 @@ class TestLayer:
class TestNextLayer:
def test_simple(self, tctx):
def test_simple(self, tctx: Context):
nl = layer.NextLayer(tctx, ask_on_start=True)
nl.debug = " "
playbook = tutils.Playbook(nl, hooks=True)
@ -79,7 +94,7 @@ class TestNextLayer:
<< commands.SendData(tctx.client, b"bar")
)
def test_late_hook_reply(self, tctx):
def test_late_hook_reply(self, tctx: Context):
"""
Properly handle case where we receive an additional event while we are waiting for
a reply from the proxy core.
@ -104,7 +119,7 @@ class TestNextLayer:
)
@pytest.mark.parametrize("layer_found", [True, False])
def test_receive_close(self, tctx, layer_found):
def test_receive_close(self, tctx: Context, layer_found: bool):
"""Test that we abort a client connection which has disconnected without any layer being found."""
nl = layer.NextLayer(tctx)
playbook = tutils.Playbook(nl)
@ -128,7 +143,7 @@ class TestNextLayer:
<< commands.CloseConnection(tctx.client)
)
def test_func_references(self, tctx):
def test_func_references(self, tctx: Context):
nl = layer.NextLayer(tctx)
playbook = tutils.Playbook(nl)
@ -147,7 +162,9 @@ class TestNextLayer:
sd, = handle(events.DataReceived(tctx.client, b"bar"))
assert isinstance(sd, commands.SendData)
def test_repr(self, tctx):
def test_repr(self, tctx: Context):
nl = layer.NextLayer(tctx)
nl.layer = tutils.EchoLayer(tctx)
assert repr(nl)
assert nl.stack_pos
assert nl.layer.stack_pos