Merge pull request #4486 from mhils/websocket

Merge WebSocketFlow into HTTPFlow, add WebSocket UI
This commit is contained in:
Maximilian Hils 2021-03-11 11:02:40 +01:00 committed by GitHub
commit 70223163de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 494 additions and 610 deletions

View File

@ -5,17 +5,17 @@ from mitmproxy import ctx
def websocket_message(flow):
# get the latest message
message = flow.messages[-1]
message = flow.websocket.messages[-1]
# was the message sent from the client or server?
if message.from_client:
ctx.log.info(f"Client sent a message: {message.content}")
ctx.log.info(f"Client sent a message: {message.content!r}")
else:
ctx.log.info(f"Server sent a message: {message.content}")
ctx.log.info(f"Server sent a message: {message.content!r}")
# manipulate the message content
message.content = re.sub(r'^Hello', 'HAPPY', message.content)
message.content = re.sub(rb'^Hello', b'HAPPY', message.content)
if 'FOOBAR' in message.content:
if b'FOOBAR' in message.content:
# kill the message and not send it to the other endpoint
message.content = ""

View File

@ -145,6 +145,8 @@ class ClientPlayback:
return "Can't replay flow with missing request."
if f.request.raw_content is None:
return "Can't replay flow with missing content."
if f.websocket is not None:
return "Can't replay WebSocket flows."
else:
return "Can only replay HTTP flows."
return None

View File

@ -13,7 +13,7 @@ from mitmproxy import http
from mitmproxy.tcp import TCPFlow, TCPMessage
from mitmproxy.utils import human
from mitmproxy.utils import strutils
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
from mitmproxy.websocket import WebSocketMessage
def indent(n: int, text: str) -> str:
@ -98,7 +98,7 @@ class Dumper:
def _echo_message(
self,
message: Union[http.Message, TCPMessage, WebSocketMessage],
flow: Union[http.HTTPFlow, TCPFlow, WebSocketFlow]
flow: Union[http.HTTPFlow, TCPFlow]
):
_, lines, error = contentviews.get_message_content_view(
ctx.options.dumper_default_contentview,
@ -277,37 +277,36 @@ class Dumper:
if self.match(f):
self.echo_flow(f)
def websocket_error(self, f):
def websocket_error(self, f: http.HTTPFlow):
self.echo_error(
"Error in WebSocket connection to {}: {}".format(
human.format_address(f.server_conn.address), f.error
),
f"Error in WebSocket connection to {human.format_address(f.server_conn.address)}: {f.error}",
fg="red"
)
def websocket_message(self, f):
def websocket_message(self, f: http.HTTPFlow):
assert f.websocket is not None # satisfy type checker
if self.match(f):
message = f.messages[-1]
self.echo(f.message_info(message))
message = f.websocket.messages[-1]
direction = "->" if message.from_client else "<-"
self.echo(
f"{human.format_address(f.client_conn.peername)} "
f"{direction} WebSocket {message.type.name.lower()} message "
f"{direction} {human.format_address(f.server_conn.address)}{f.request.path}"
)
if ctx.options.flow_detail >= 3:
message = message.from_state(message.get_state())
message.content = message.content.encode() if isinstance(message.content, str) else message.content
self._echo_message(message, f)
def websocket_end(self, f):
def websocket_end(self, f: http.HTTPFlow):
assert f.websocket is not None # satisfy type checker
if self.match(f):
self.echo("WebSocket connection closed by {}: {} {}, {}".format(
f.close_sender,
f.close_code,
f.close_message,
f.close_reason))
c = 'client' if f.websocket.closed_by_client else 'server'
self.echo(f"WebSocket connection closed by {c}: {f.websocket.close_code} {f.websocket.close_reason}")
def tcp_error(self, f):
if self.match(f):
self.echo_error(
"Error in TCP connection to {}: {}".format(
human.format_address(f.server_conn.address), f.error
),
f"Error in TCP connection to {human.format_address(f.server_conn.address)}: {f.error}",
fg="red"
)

View File

@ -7,6 +7,7 @@ from mitmproxy import flowfilter
from mitmproxy import io
from mitmproxy import ctx
from mitmproxy import flow
from mitmproxy import http
import mitmproxy.types
@ -88,28 +89,26 @@ class Save:
def tcp_error(self, flow):
self.tcp_end(flow)
def websocket_start(self, flow):
if self.stream:
self.active_flows.add(flow)
def websocket_end(self, flow):
def websocket_end(self, flow: http.HTTPFlow):
if self.stream:
self.stream.add(flow)
self.active_flows.discard(flow)
def websocket_error(self, flow):
def websocket_error(self, flow: http.HTTPFlow):
self.websocket_end(flow)
def request(self, flow):
def request(self, flow: http.HTTPFlow):
if self.stream:
self.active_flows.add(flow)
def response(self, flow):
if self.stream:
def response(self, flow: http.HTTPFlow):
# websocket flows will receive either websocket_end or websocket_error,
# we don't want to persist them here already
if self.stream and flow.websocket is None:
self.stream.add(flow)
self.active_flows.discard(flow)
def error(self, flow):
def error(self, flow: http.HTTPFlow):
self.response(flow)
def done(self):

View File

@ -25,7 +25,7 @@ from . import (
from .base import View, KEY_MAX, format_text, format_dict, TViewResult
from ..http import HTTPFlow
from ..tcp import TCPMessage, TCPFlow
from ..websocket import WebSocketMessage, WebSocketFlow
from ..websocket import WebSocketMessage
views: List[View] = []
@ -67,7 +67,7 @@ def safe_to_print(lines, encoding="utf8"):
def get_message_content_view(
viewname: str,
message: Union[http.Message, TCPMessage, WebSocketMessage],
flow: Union[HTTPFlow, TCPFlow, WebSocketFlow],
flow: Union[HTTPFlow, TCPFlow],
):
"""
Like get_content_view, but also handles message encoding.
@ -79,7 +79,7 @@ def get_message_content_view(
content: Optional[bytes]
try:
content = message.content # type: ignore
content = message.content
except ValueError:
assert isinstance(message, http.Message)
content = message.raw_content

View File

@ -1,11 +1,10 @@
from typing import Iterator, Any, Dict, Type, Callable
from typing import Any, Callable, Dict, Iterator, Type
from mitmproxy import controller
from mitmproxy import hooks
from mitmproxy import flow
from mitmproxy import hooks
from mitmproxy import http
from mitmproxy import tcp
from mitmproxy import websocket
from mitmproxy.proxy import layers
TEventGenerator = Iterator[hooks.Hook]
@ -18,24 +17,21 @@ def _iterate_http(f: http.HTTPFlow) -> TEventGenerator:
if f.response:
yield layers.http.HttpResponseHeadersHook(f)
yield layers.http.HttpResponseHook(f)
if f.error:
if f.websocket:
message_queue = f.websocket.messages
f.websocket.messages = []
yield layers.websocket.WebsocketStartHook(f)
for m in message_queue:
f.websocket.messages.append(m)
yield layers.websocket.WebsocketMessageHook(f)
if f.error:
yield layers.websocket.WebsocketErrorHook(f)
else:
yield layers.websocket.WebsocketEndHook(f)
elif f.error:
yield layers.http.HttpErrorHook(f)
def _iterate_websocket(f: websocket.WebSocketFlow) -> TEventGenerator:
messages = f.messages
f.messages = []
f.reply = controller.DummyReply()
yield layers.websocket.WebsocketStartHook(f)
while messages:
f.messages.append(messages.pop(0))
yield layers.websocket.WebsocketMessageHook(f)
if f.error:
yield layers.websocket.WebsocketErrorHook(f)
else:
yield layers.websocket.WebsocketEndHook(f)
def _iterate_tcp(f: tcp.TCPFlow) -> TEventGenerator:
messages = f.messages
f.messages = []
@ -52,7 +48,6 @@ def _iterate_tcp(f: tcp.TCPFlow) -> TEventGenerator:
_iterate_map: Dict[Type[flow.Flow], Callable[[Any], TEventGenerator]] = {
http.HTTPFlow: _iterate_http,
websocket.WebSocketFlow: _iterate_websocket,
tcp.TCPFlow: _iterate_tcp,
}

View File

@ -39,8 +39,7 @@ from typing import Callable, ClassVar, Optional, Sequence, Type
import pyparsing as pp
from mitmproxy import flow, http, tcp, websocket
from mitmproxy.net.websocket import check_handshake
from mitmproxy import flow, http, tcp
def only(*types):
@ -102,15 +101,11 @@ class FHTTP(_Action):
class FWebSocket(_Action):
code = "websocket"
help = "Match WebSocket flows (and HTTP-WebSocket handshake flows)"
help = "Match WebSocket flows"
@only(http.HTTPFlow, websocket.WebSocketFlow)
def __call__(self, f):
m = (
(isinstance(f, http.HTTPFlow) and f.request and check_handshake(f.request.headers))
or isinstance(f, websocket.WebSocketFlow)
)
return m
@only(http.HTTPFlow)
def __call__(self, f: http.HTTPFlow):
return f.websocket is not None
class FTCP(_Action):
@ -258,7 +253,7 @@ class FBod(_Rex):
help = "Body"
flags = re.DOTALL
@only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow)
@only(http.HTTPFlow, tcp.TCPFlow)
def __call__(self, f):
if isinstance(f, http.HTTPFlow):
if f.request and f.request.raw_content:
@ -267,7 +262,11 @@ class FBod(_Rex):
if f.response and f.response.raw_content:
if self.re.search(f.response.get_content(strict=False)):
return True
elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow):
if f.websocket:
for msg in f.websocket.messages:
if self.re.search(msg.content):
return True
elif isinstance(f, tcp.TCPFlow):
for msg in f.messages:
if self.re.search(msg.content):
return True
@ -279,13 +278,17 @@ class FBodRequest(_Rex):
help = "Request body"
flags = re.DOTALL
@only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow)
@only(http.HTTPFlow, tcp.TCPFlow)
def __call__(self, f):
if isinstance(f, http.HTTPFlow):
if f.request and f.request.raw_content:
if self.re.search(f.request.get_content(strict=False)):
return True
elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow):
if f.websocket:
for msg in f.websocket.messages:
if msg.from_client and self.re.search(msg.content):
return True
elif isinstance(f, tcp.TCPFlow):
for msg in f.messages:
if msg.from_client and self.re.search(msg.content):
return True
@ -296,13 +299,17 @@ class FBodResponse(_Rex):
help = "Response body"
flags = re.DOTALL
@only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow)
@only(http.HTTPFlow, tcp.TCPFlow)
def __call__(self, f):
if isinstance(f, http.HTTPFlow):
if f.response and f.response.raw_content:
if self.re.search(f.response.get_content(strict=False)):
return True
elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow):
if f.websocket:
for msg in f.websocket.messages:
if not msg.from_client and self.re.search(msg.content):
return True
elif isinstance(f, tcp.TCPFlow):
for msg in f.messages:
if not msg.from_client and self.re.search(msg.content):
return True
@ -324,10 +331,8 @@ class FDomain(_Rex):
flags = re.IGNORECASE
is_binary = False
@only(http.HTTPFlow, websocket.WebSocketFlow)
@only(http.HTTPFlow)
def __call__(self, f):
if isinstance(f, websocket.WebSocketFlow):
f = f.handshake_flow
return bool(
self.re.search(f.request.host) or
self.re.search(f.request.pretty_host)
@ -347,10 +352,8 @@ class FUrl(_Rex):
toks = toks[1:]
return klass(*toks)
@only(http.HTTPFlow, websocket.WebSocketFlow)
@only(http.HTTPFlow)
def __call__(self, f):
if isinstance(f, websocket.WebSocketFlow):
f = f.handshake_flow
if not f or not f.request:
return False
return self.re.search(f.request.pretty_url)
@ -482,9 +485,9 @@ def _make():
unicode_words = pp.CharsNotIn("()~'\"" + pp.ParserElement.DEFAULT_WHITE_CHARS)
unicode_words.skipWhitespace = True
regex = (
unicode_words
| pp.QuotedString('"', escChar='\\')
| pp.QuotedString("'", escChar='\\')
unicode_words
| pp.QuotedString('"', escChar='\\')
| pp.QuotedString("'", escChar='\\')
)
for cls in filter_rex:
f = pp.Literal(f"~{cls.code}") + pp.WordEnd() + regex.copy()

View File

@ -18,6 +18,7 @@ from typing import Union
from typing import cast
from mitmproxy import flow
from mitmproxy.websocket import WebSocketData
from mitmproxy.coretypes import multidict
from mitmproxy.coretypes import serializable
from mitmproxy.net import encoding
@ -1169,6 +1170,11 @@ class HTTPFlow(flow.Flow):
from the server, but there was an error sending it back to the client.
"""
websocket: Optional[WebSocketData] = None
"""
If this HTTP flow initiated a WebSocket connection, this attribute contains all associated WebSocket data.
"""
def __init__(self, client_conn, server_conn, live=None, mode="regular"):
super().__init__("http", client_conn, server_conn, live)
self.mode = mode
@ -1178,12 +1184,13 @@ class HTTPFlow(flow.Flow):
_stateobject_attributes.update(dict(
request=Request,
response=Response,
websocket=WebSocketData,
mode=str
))
def __repr__(self):
s = "<HTTPFlow"
for a in ("request", "response", "error", "client_conn", "server_conn"):
for a in ("request", "response", "websocket", "error", "client_conn", "server_conn"):
if getattr(self, a, False):
s += f"\r\n {a} = {{flow.{a}}}"
s += ">"

View File

@ -250,6 +250,55 @@ def convert_10_11(data):
return data
_websocket_handshakes = {}
def convert_11_12(data):
data["version"] = 12
if "websocket" in data["metadata"]:
_websocket_handshakes[data["id"]] = data
if "websocket_handshake" in data["metadata"]:
ws_flow = data
try:
data = _websocket_handshakes.pop(data["metadata"]["websocket_handshake"])
except KeyError:
# The handshake flow is missing, which should never really happen. We make up a dummy.
data = {
'client_conn': data["client_conn"],
'error': data["error"],
'id': data["id"],
'intercepted': data["intercepted"],
'is_replay': data["is_replay"],
'marked': data["marked"],
'metadata': {},
'mode': 'transparent',
'request': {'authority': b'', 'content': None, 'headers': [], 'host': b'unknown',
'http_version': b'HTTP/1.1', 'method': b'GET', 'path': b'/', 'port': 80, 'scheme': b'http',
'timestamp_end': 0, 'timestamp_start': 0, 'trailers': None, },
'response': None,
'server_conn': data["server_conn"],
'type': 'http',
'version': 12
}
data["metadata"]["duplicated"] = (
"This WebSocket flow has been migrated from an old file format version "
"and may appear duplicated."
)
data["websocket"] = {
"messages": ws_flow["messages"],
"closed_by_client": ws_flow["close_sender"] == "client",
"close_code": ws_flow["close_code"],
"close_reason": ws_flow["close_reason"],
}
else:
data["websocket"] = None
return data
def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict):
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
@ -308,6 +357,7 @@ converters = {
8: convert_8_9,
9: convert_9_10,
10: convert_10_11,
11: convert_11_12,
}
@ -325,8 +375,8 @@ def migrate_flow(flow_data: Dict[Union[bytes, str], Any]) -> Dict[Union[bytes, s
flow_data = converters[flow_version](flow_data)
else:
should_upgrade = (
isinstance(flow_version, int)
and flow_version > version.FLOW_FORMAT_VERSION
isinstance(flow_version, int)
and flow_version > version.FLOW_FORMAT_VERSION
)
raise ValueError(
"{} cannot read files with flow format version {}{}.".format(

View File

@ -1,19 +1,16 @@
import os
from typing import Type, Iterable, Dict, Union, Any, cast # noqa
from typing import Any, Dict, Iterable, Type, Union, cast # noqa
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy import flowfilter
from mitmproxy import http
from mitmproxy import tcp
from mitmproxy import websocket
from mitmproxy.io import compat
from mitmproxy.io import tnetstring
FLOW_TYPES: Dict[str, Type[flow.Flow]] = dict(
http=http.HTTPFlow,
websocket=websocket.WebSocketFlow,
tcp=tcp.TCPFlow,
)

View File

@ -11,7 +11,6 @@ from mitmproxy import eventsequence
from mitmproxy import http
from mitmproxy import log
from mitmproxy import options
from mitmproxy import websocket
from mitmproxy.net import server_spec
from . import ctx as mitmproxy_ctx
@ -34,7 +33,6 @@ class Master:
self.commands = command.CommandManager(self)
self.addons = addonmanager.AddonManager(self)
self._server = None
self.waiting_flows = []
self.log = log.Log(self)
mitmproxy_ctx.master = self
@ -111,24 +109,11 @@ class Master:
async def load_flow(self, f):
"""
Loads a flow and links websocket & handshake flows
Loads a flow
"""
if isinstance(f, http.HTTPFlow):
self._change_reverse_host(f)
if 'websocket' in f.metadata:
self.waiting_flows.append(f)
if isinstance(f, websocket.WebSocketFlow):
hfs = [hf for hf in self.waiting_flows if hf.id == f.metadata['websocket_handshake']]
if hfs:
hf = hfs[0]
f.handshake_flow = hf
self.waiting_flows.remove(hf)
self._change_reverse_host(f.handshake_flow)
else:
# this will fail - but at least it will load the remaining flows
f.handshake_flow = http.HTTPFlow(None, None)
f.reply = controller.DummyReply()
for e in eventsequence.iterate(f):

View File

@ -1,28 +0,0 @@
"""
Collection of WebSocket protocol utility functions (RFC6455)
Spec: https://tools.ietf.org/html/rfc6455
"""
def check_handshake(headers):
return (
"upgrade" in headers.get("connection", "").lower() and
headers.get("upgrade", "").lower() == "websocket" and
(headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None)
)
def get_extensions(headers):
return headers.get("sec-websocket-extensions", None)
def get_protocol(headers):
return headers.get("sec-websocket-protocol", None)
def get_client_key(headers):
return headers.get("sec-websocket-key", None)
def get_server_accept(headers):
return headers.get("sec-websocket-accept", None)

View File

@ -4,6 +4,8 @@ import time
from dataclasses import dataclass
from typing import DefaultDict, Dict, List, Optional, Tuple, Union
import wsproto.handshake
from mitmproxy import flow, http
from mitmproxy.connection import Connection, Server
from mitmproxy.net import server_spec
@ -13,6 +15,7 @@ from mitmproxy.proxy.layers import tcp, tls, websocket
from mitmproxy.proxy.layers.http import _upstream_proxy
from mitmproxy.proxy.utils import expect
from mitmproxy.utils import human
from mitmproxy.websocket import WebSocketData
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
@ -308,6 +311,21 @@ class HttpStream(layer.Layer):
"""We have either consumed the entire response from the server or the response was set by an addon."""
assert self.flow.response
self.flow.response.timestamp_end = time.time()
is_websocket = (
self.flow.response.status_code == 101
and
self.flow.response.headers.get("upgrade", "").lower() == "websocket"
and
self.flow.request.headers.get("Sec-WebSocket-Version", "").encode() == wsproto.handshake.WEBSOCKET_VERSION
and
self.context.options.websocket
)
if is_websocket:
# We need to set this before calling the response hook
# so that addons can determine if a WebSocket connection is following up.
self.flow.websocket = WebSocketData()
yield HttpResponseHook(self.flow)
self.server_state = self.state_done
if (yield from self.check_killed(False)):
@ -322,12 +340,7 @@ class HttpStream(layer.Layer):
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
if self.flow.response.status_code == 101:
is_websocket = (
self.flow.response.headers.get("upgrade", "").lower() == "websocket"
and
self.flow.request.headers.get("Sec-WebSocket-Version", "") == "13"
)
if is_websocket and self.context.options.websocket:
if is_websocket:
self.child_layer = websocket.WebsocketLayer(self.context, self.flow)
elif self.context.options.rawtcp:
self.child_layer = tcp.TCPLayer(self.context)

View File

@ -1,11 +1,11 @@
from dataclasses import dataclass
from typing import Union, List, Iterator
from typing import Iterator, List
import wsproto
import wsproto.extensions
import wsproto.frame_protocol
import wsproto.utilities
from mitmproxy import flow, websocket, http, connection
from mitmproxy import connection, flow, http, websocket
from mitmproxy.proxy import commands, events, layer
from mitmproxy.proxy.commands import StartHook
from mitmproxy.proxy.context import Context
@ -19,7 +19,7 @@ class WebsocketStartHook(StartHook):
"""
A WebSocket connection has commenced.
"""
flow: websocket.WebSocketFlow
flow: http.HTTPFlow
@dataclass
@ -30,7 +30,7 @@ class WebsocketMessageHook(StartHook):
message is user-modifiable. Currently there are two types of
messages, corresponding to the BINARY and TEXT frame types.
"""
flow: websocket.WebSocketFlow
flow: http.HTTPFlow
@dataclass
@ -39,7 +39,7 @@ class WebsocketEndHook(StartHook):
A WebSocket connection has ended.
"""
flow: websocket.WebSocketFlow
flow: http.HTTPFlow
@dataclass
@ -49,7 +49,7 @@ class WebsocketErrorHook(StartHook):
Every WebSocket flow will receive either a websocket_error or a websocket_end event, but not both.
"""
flow: websocket.WebSocketFlow
flow: http.HTTPFlow
class WebsocketConnection(wsproto.Connection):
@ -61,7 +61,7 @@ class WebsocketConnection(wsproto.Connection):
- we wrap .send() so that we can directly yield it.
"""
conn: connection.Connection
frame_buf: List[Union[str, bytes]]
frame_buf: List[bytes]
def __init__(self, *args, conn: connection.Connection, **kwargs):
super(WebsocketConnection, self).__init__(*args, **kwargs)
@ -80,13 +80,13 @@ class WebsocketLayer(layer.Layer):
"""
WebSocket layer that intercepts and relays messages.
"""
flow: websocket.WebSocketFlow
flow: http.HTTPFlow
client_ws: WebsocketConnection
server_ws: WebsocketConnection
def __init__(self, context: Context, handshake_flow: http.HTTPFlow):
def __init__(self, context: Context, flow: http.HTTPFlow):
super().__init__(context)
self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow)
self.flow = flow
assert context.server.connected
@expect(events.Start)
@ -96,7 +96,8 @@ class WebsocketLayer(layer.Layer):
server_extensions = []
# Parse extension headers. We only support deflate at the moment and ignore everything else.
ext_header = self.flow.handshake_flow.response.headers.get("Sec-WebSocket-Extensions", "")
assert self.flow.response # satisfy type checker
ext_header = self.flow.response.headers.get("Sec-WebSocket-Extensions", "")
if ext_header:
for ext in wsproto.utilities.split_comma_header(ext_header.encode("ascii", "replace")):
ext_name = ext.split(";", 1)[0].strip()
@ -115,15 +116,14 @@ class WebsocketLayer(layer.Layer):
yield WebsocketStartHook(self.flow)
if self.flow.stream: # pragma: no cover
raise NotImplementedError("WebSocket streaming is not supported at the moment.")
self._handle_event = self.relay_messages
_handle_event = start
@expect(events.DataReceived, events.ConnectionClosed)
def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
assert self.flow.websocket # satisfy type checker
from_client = event.connection == self.context.client
from_str = 'client' if from_client else 'server'
if from_client:
@ -142,27 +142,27 @@ class WebsocketLayer(layer.Layer):
for ws_event in src_ws.events():
if isinstance(ws_event, wsproto.events.Message):
src_ws.frame_buf.append(ws_event.data)
is_text = isinstance(ws_event.data, str)
if is_text:
typ = Opcode.TEXT
src_ws.frame_buf.append(ws_event.data.encode())
else:
typ = Opcode.BINARY
src_ws.frame_buf.append(ws_event.data)
if ws_event.message_finished:
if isinstance(ws_event, wsproto.events.TextMessage):
frame_type = Opcode.TEXT
content = "".join(src_ws.frame_buf) # type: ignore
else:
frame_type = Opcode.BINARY
content = b"".join(src_ws.frame_buf) # type: ignore
content = b"".join(src_ws.frame_buf)
fragmentizer = Fragmentizer(src_ws.frame_buf)
fragmentizer = Fragmentizer(src_ws.frame_buf, is_text)
src_ws.frame_buf.clear()
message = websocket.WebSocketMessage(frame_type, from_client, content)
self.flow.messages.append(message)
message = websocket.WebSocketMessage(typ, from_client, content)
self.flow.websocket.messages.append(message)
yield WebsocketMessageHook(self.flow)
assert not message.killed # this is deprecated, instead we should have .content set to emptystr.
for msg in fragmentizer(message.content):
yield dst_ws.send2(msg)
if not message.killed:
for msg in fragmentizer(message.content):
yield dst_ws.send2(msg)
elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
yield commands.Log(
@ -171,9 +171,9 @@ class WebsocketLayer(layer.Layer):
)
yield dst_ws.send2(ws_event)
elif isinstance(ws_event, wsproto.events.CloseConnection):
self.flow.close_sender = from_str
self.flow.close_code = ws_event.code
self.flow.close_reason = ws_event.reason
self.flow.websocket.closed_by_client = from_client
self.flow.websocket.close_code = ws_event.code
self.flow.websocket.close_reason = ws_event.reason
for ws in [self.server_ws, self.client_ws]:
if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}:
@ -215,27 +215,35 @@ class Fragmentizer:
As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks.
"""
# A bit less than 4kb to accomodate for headers.
# A bit less than 4kb to accommodate for headers.
FRAGMENT_SIZE = 4000
def __init__(self, fragments: List[Union[str, bytes]]):
def __init__(self, fragments: List[bytes], is_text: bool):
assert fragments
self.fragment_lengths = [len(x) for x in fragments]
self.is_text = is_text
def __call__(self, content: Union[str, bytes]) -> Iterator[wsproto.events.Message]:
def msg(self, data: bytes, message_finished: bool):
if self.is_text:
data_str = data.decode(errors="replace")
return wsproto.events.TextMessage(data_str, message_finished=message_finished)
else:
return wsproto.events.BytesMessage(data, message_finished=message_finished)
def __call__(self, content: bytes) -> Iterator[wsproto.events.Message]:
if not content:
return
if len(content) == sum(self.fragment_lengths):
# message has the same length, we can reuse the same sizes
offset = 0
for fl in self.fragment_lengths[:-1]:
yield wsproto.events.Message(content[offset:offset + fl], message_finished=False)
yield self.msg(content[offset:offset + fl], False)
offset += fl
yield wsproto.events.Message(content[offset:], message_finished=True)
yield self.msg(content[offset:], True)
else:
offset = 0
total = len(content) - self.FRAGMENT_SIZE
while offset < total:
yield wsproto.events.Message(content[offset:offset + self.FRAGMENT_SIZE], message_finished=False)
yield self.msg(content[offset:offset + self.FRAGMENT_SIZE], False)
offset += self.FRAGMENT_SIZE
yield wsproto.events.Message(content[offset:], message_finished=True)
yield self.msg(content[offset:], True)

View File

@ -6,7 +6,6 @@ from mitmproxy import flow
from mitmproxy import http
from mitmproxy import tcp
from mitmproxy import websocket
from mitmproxy.net.http import status_codes
from mitmproxy.test import tutils
from wsproto.frame_protocol import Opcode
@ -31,68 +30,55 @@ def ttcpflow(client_conn=True, server_conn=True, messages=True, err=None):
return f
def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, handshake_flow=True):
if client_conn is True:
client_conn = tclient_conn()
if server_conn is True:
server_conn = tserver_conn()
if handshake_flow is True:
req = http.Request(
"example.com",
80,
b"GET",
b"http",
b"example.com",
b"/ws",
b"HTTP/1.1",
headers=http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_version="13",
sec_websocket_key="1234",
),
content=b'',
trailers=None,
timestamp_start=946681200,
timestamp_end=946681201,
def twebsocketflow(messages=True, err=None) -> http.HTTPFlow:
flow = http.HTTPFlow(tclient_conn(), tserver_conn())
flow.request = http.Request(
"example.com",
80,
b"GET",
b"http",
b"example.com",
b"/ws",
b"HTTP/1.1",
headers=http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_version="13",
sec_websocket_key="1234",
),
content=b'',
trailers=None,
timestamp_start=946681200,
timestamp_end=946681201,
)
resp = http.Response(
b"HTTP/1.1",
101,
reason=status_codes.RESPONSES.get(101),
headers=http.Headers(
connection='upgrade',
upgrade='websocket',
sec_websocket_accept=b'',
),
content=b'',
trailers=None,
timestamp_start=946681202,
timestamp_end=946681203,
)
handshake_flow = http.HTTPFlow(client_conn, server_conn)
handshake_flow.request = req
handshake_flow.response = resp
f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow)
f.metadata['websocket_handshake'] = handshake_flow.id
handshake_flow.metadata['websocket_flow'] = f.id
handshake_flow.metadata['websocket'] = True
)
flow.response = http.Response(
b"HTTP/1.1",
101,
reason=b"Switching Protocols",
headers=http.Headers(
connection='upgrade',
upgrade='websocket',
sec_websocket_accept=b'',
),
content=b'',
trailers=None,
timestamp_start=946681202,
timestamp_end=946681203,
)
flow.websocket = websocket.WebSocketData()
if messages is True:
messages = [
websocket.WebSocketMessage(Opcode.BINARY, True, b"hello binary"),
websocket.WebSocketMessage(Opcode.TEXT, True, b"hello text"),
websocket.WebSocketMessage(Opcode.TEXT, False, b"it's me"),
flow.websocket.messages = [
websocket.WebSocketMessage(Opcode.BINARY, True, b"hello binary", 946681203),
websocket.WebSocketMessage(Opcode.TEXT, True, b"hello text", 946681204),
websocket.WebSocketMessage(Opcode.TEXT, False, b"it's me", 946681205),
]
if err is True:
err = terr()
flow.error = terr()
f.messages = messages
f.error = err
f.reply = controller.DummyReply()
return f
flow.reply = controller.DummyReply()
return flow
def tflow(client_conn=True, server_conn=True, req=True, resp=None, err=None):

View File

@ -119,6 +119,8 @@ else:
SCHEME_STYLES = {
'http': 'scheme_http',
'https': 'scheme_https',
'ws': 'scheme_ws',
'wss': 'scheme_wss',
'tcp': 'scheme_tcp',
}
HTTP_REQUEST_METHOD_STYLES = {
@ -297,12 +299,8 @@ def colorize_url(url):
parts = url.split('/', 3)
if len(parts) < 4 or len(parts[1]) > 0 or parts[0][-1:] != ':':
return [('error', len(url))] # bad URL
schemes = {
'http:': 'scheme_http',
'https:': 'scheme_https',
}
return [
(schemes.get(parts[0], "scheme_other"), len(parts[0]) - 1),
(SCHEME_STYLES.get(parts[0], "scheme_other"), len(parts[0]) - 1),
('url_punctuation', 3), # ://
] + colorize_host(parts[2]) + colorize_req('/' + parts[3])
@ -699,6 +697,13 @@ def format_flow(
response_content_type = None
duration = None
scheme = f.request.scheme
if f.websocket is not None:
if scheme == "https":
scheme = "wss"
elif scheme == "http":
scheme = "ws"
if render_mode in (RenderMode.LIST, RenderMode.DETAILVIEW):
render_func = format_http_flow_list
else:
@ -709,7 +714,7 @@ def format_flow(
marked=f.marked,
is_replay=f.is_replay,
request_method=f.request.method,
request_scheme=f.request.scheme,
request_scheme=scheme,
request_host=f.request.pretty_host if hostheader else f.request.host,
request_path=f.request.path,
request_url=f.request.pretty_url if hostheader else f.request.url,

View File

@ -42,25 +42,6 @@ console_flowlist_layout = [
]
class UnsupportedLog:
"""
A small addon to dump info on flow types we don't support yet.
"""
def websocket_message(self, f):
message = f.messages[-1]
ctx.log.info(f.message_info(message))
ctx.log.debug(
message.content if isinstance(message.content, str) else strutils.bytes_to_escaped_str(message.content))
def websocket_end(self, f):
ctx.log.info("WebSocket connection closed by {}: {} {}, {}".format(
f.close_sender,
f.close_code,
f.close_message,
f.close_reason))
class ConsoleAddon:
"""
An addon that exposes console-specific commands, and hooks into required
@ -226,11 +207,11 @@ class ConsoleAddon:
@command.command("console.choose")
def console_choose(
self,
prompt: str,
choices: typing.Sequence[str],
cmd: mitmproxy.types.Cmd,
*args: mitmproxy.types.CmdArgs
self,
prompt: str,
choices: typing.Sequence[str],
cmd: mitmproxy.types.Cmd,
*args: mitmproxy.types.CmdArgs
) -> None:
"""
Prompt the user to choose from a specified list of strings, then
@ -252,11 +233,11 @@ class ConsoleAddon:
@command.command("console.choose.cmd")
def console_choose_cmd(
self,
prompt: str,
choicecmd: mitmproxy.types.Cmd,
subcmd: mitmproxy.types.Cmd,
*args: mitmproxy.types.CmdArgs
self,
prompt: str,
choicecmd: mitmproxy.types.Cmd,
subcmd: mitmproxy.types.Cmd,
*args: mitmproxy.types.CmdArgs
) -> None:
"""
Prompt the user to choose from a list of strings returned by a
@ -415,8 +396,8 @@ class ConsoleAddon:
flow.backup()
require_dummy_response = (
flow_part in ("response-headers", "response-body", "set-cookies") and
flow.response is None
flow_part in ("response-headers", "response-body", "set-cookies") and
flow.response is None
)
if require_dummy_response:
flow.response = http.Response.make()
@ -584,11 +565,11 @@ class ConsoleAddon:
@command.command("console.key.bind")
def key_bind(
self,
contexts: typing.Sequence[str],
key: str,
cmd: mitmproxy.types.Cmd,
*args: mitmproxy.types.CmdArgs
self,
contexts: typing.Sequence[str],
key: str,
cmd: mitmproxy.types.Cmd,
*args: mitmproxy.types.CmdArgs
) -> None:
"""
Bind a shortcut key.

View File

@ -60,14 +60,23 @@ class FlowDetails(tabs.Tabs):
return self.master.view.focus.flow
def focus_changed(self):
if self.flow:
if isinstance(self.flow, http.HTTPFlow):
self.tabs = [
(self.tab_http_request, self.view_request),
(self.tab_http_response, self.view_response),
(self.tab_details, self.view_details),
]
elif isinstance(self.flow, tcp.TCPFlow):
f = self.flow
if f:
if isinstance(f, http.HTTPFlow):
if f.websocket:
self.tabs = [
(self.tab_http_request, self.view_request),
(self.tab_http_response, self.view_response),
(self.tab_websocket_messages, self.view_websocket_messages),
(self.tab_details, self.view_details),
]
else:
self.tabs = [
(self.tab_http_request, self.view_request),
(self.tab_http_response, self.view_response),
(self.tab_details, self.view_details),
]
elif isinstance(f, tcp.TCPFlow):
self.tabs = [
(self.tab_tcp_stream, self.view_tcp_stream),
(self.tab_details, self.view_details),
@ -95,6 +104,9 @@ class FlowDetails(tabs.Tabs):
def tab_tcp_stream(self):
return "TCP Stream"
def tab_websocket_messages(self):
return "WebSocket Messages"
def tab_details(self):
return "Detail"
@ -128,6 +140,36 @@ class FlowDetails(tabs.Tabs):
contentview_status_bar = urwid.AttrWrap(urwid.Columns(cols), "heading")
return contentview_status_bar
def view_websocket_messages(self):
flow = self.flow
assert isinstance(flow, http.HTTPFlow)
assert flow.websocket is not None
if not flow.websocket.messages:
return searchable.Searchable([urwid.Text(("highlight", "No messages."))])
viewmode = self.master.commands.call("console.flowview.mode")
widget_lines = []
for m in flow.websocket.messages:
_, lines, _ = contentviews.get_message_content_view(viewmode, m, flow)
for line in lines:
if m.from_client:
line.insert(0, ("from_client", f"{common.SYMBOL_FROM_CLIENT} "))
else:
line.insert(0, ("to_client", f"{common.SYMBOL_TO_CLIENT} "))
widget_lines.append(urwid.Text(line))
if flow.intercepted:
markup = widget_lines[-1].get_text()[0]
widget_lines[-1].set_text(("intercept", markup))
widget_lines.insert(0, self._contentview_status_bar(viewmode.capitalize(), viewmode))
return searchable.Searchable(widget_lines)
def view_tcp_stream(self) -> urwid.Widget:
flow = self.flow
assert isinstance(flow, tcp.TCPFlow)

View File

@ -53,7 +53,6 @@ class ConsoleMaster(master.Master):
intercept.Intercept(),
self.view,
self.events,
consoleaddons.UnsupportedLog(),
readfile.ReadFile(),
consoleaddons.ConsoleAddon(self),
keymap.KeymapConfig(),

View File

@ -23,7 +23,7 @@ class Palette:
# List and Connections
'method_get', 'method_post', 'method_delete', 'method_other', 'method_head', 'method_put', 'method_http2_push',
'scheme_http', 'scheme_https', 'scheme_tcp', 'scheme_other',
'scheme_http', 'scheme_https', 'scheme_ws', 'scheme_wss', 'scheme_tcp', 'scheme_other',
'url_punctuation', 'url_domain', 'url_filename', 'url_extension', 'url_query_key', 'url_query_value',
'content_none', 'content_text', 'content_script', 'content_media', 'content_data', 'content_raw', 'content_other',
'focus',
@ -136,6 +136,8 @@ class LowDark(Palette):
scheme_http = ('dark cyan', 'default'),
scheme_https = ('dark green', 'default'),
scheme_ws=('brown', 'default'),
scheme_wss=('dark magenta', 'default'),
scheme_tcp=('dark magenta', 'default'),
scheme_other = ('dark magenta', 'default'),
@ -245,6 +247,8 @@ class LowLight(Palette):
scheme_http = ('dark cyan', 'default'),
scheme_https = ('light green', 'default'),
scheme_ws=('brown', 'default'),
scheme_wss=('light magenta', 'default'),
scheme_tcp=('light magenta', 'default'),
scheme_other = ('light magenta', 'default'),
@ -373,6 +377,8 @@ class SolarizedLight(LowLight):
scheme_http = (sol_cyan, 'default'),
scheme_https = ('light green', 'default'),
scheme_ws=(sol_orange, 'default'),
scheme_wss=('light magenta', 'default'),
scheme_tcp=('light magenta', 'default'),
scheme_other = ('light magenta', 'default'),

View File

@ -7,7 +7,7 @@ MITMPROXY = "mitmproxy " + VERSION
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one
# for each change in the file format.
FLOW_FORMAT_VERSION = 11
FLOW_FORMAT_VERSION = 12
def get_dev_version() -> str:

View File

@ -1,168 +1,126 @@
"""
*Deprecation Notice:* Mitmproxy's WebSocket API is going to change soon,
see <https://github.com/mitmproxy/mitmproxy/issues/4425>.
Mitmproxy used to have its own WebSocketFlow type until mitmproxy 6, but now WebSocket connections now are represented
as HTTP flows as well. They can be distinguished from regular HTTP requests by having the
`mitmproxy.http.HTTPFlow.websocket` attribute set.
This module only defines the classes for individual `WebSocketMessage`s and the `WebSocketData` container.
"""
import queue
import time
import warnings
from typing import List
from typing import List, Tuple, Union
from typing import Optional
from typing import Union
from mitmproxy import flow
from mitmproxy import stateobject
from mitmproxy.coretypes import serializable
from mitmproxy.net import websocket
from mitmproxy.utils import human
from mitmproxy.utils import strutils
from wsproto.frame_protocol import CloseReason
from wsproto.frame_protocol import Opcode
WebSocketMessageState = Tuple[int, bool, bytes, float, bool]
class WebSocketMessage(serializable.Serializable):
"""
A WebSocket message sent from one endpoint to the other.
A single WebSocket message sent from one peer to the other.
Fragmented WebSocket messages are reassembled by mitmproxy and the
represented as a single instance of this class.
The [WebSocket RFC](https://tools.ietf.org/html/rfc6455) specifies both
text and binary messages. To avoid a whole class of nasty type confusion bugs,
mitmproxy stores all message contents as binary. If you need text, you can decode the `content` property:
>>> from wsproto.frame_protocol import Opcode
>>> if message.type == Opcode.TEXT:
>>> text = message.content.decode()
Per the WebSocket spec, text messages always use UTF-8 encoding.
"""
type: Opcode
"""indicates either TEXT or BINARY (from wsproto.frame_protocol.Opcode)."""
from_client: bool
"""True if this messages was sent by the client."""
content: Union[bytes, str]
type: Opcode
"""
The message type, as per RFC 6455's [opcode](https://tools.ietf.org/html/rfc6455#section-5.2).
Note that mitmproxy will always store the message contents as *bytes*.
A dedicated `.text` property for text messages is planned, see https://github.com/mitmproxy/mitmproxy/pull/4486.
"""
content: bytes
"""A byte-string representing the content of this message."""
timestamp: float
"""Timestamp of when this message was received or created."""
killed: bool
"""True if this messages was killed and should not be sent to the other endpoint."""
"""True if the message has not been forwarded by mitmproxy, False otherwise."""
def __init__(
self,
type: int,
type: Union[int, Opcode],
from_client: bool,
content: Union[bytes, str],
content: bytes,
timestamp: Optional[float] = None,
killed: bool = False
killed: bool = False,
) -> None:
self.type = Opcode(type) # type: ignore
self.from_client = from_client
self.type = Opcode(type)
self.content = content
self.timestamp: float = timestamp or time.time()
self.killed = killed
@classmethod
def from_state(cls, state):
def from_state(cls, state: WebSocketMessageState):
return cls(*state)
def get_state(self):
def get_state(self) -> WebSocketMessageState:
return int(self.type), self.from_client, self.content, self.timestamp, self.killed
def set_state(self, state):
self.type, self.from_client, self.content, self.timestamp, self.killed = state
self.type = Opcode(self.type) # replace enum with bare int
def set_state(self, state: WebSocketMessageState) -> None:
typ, self.from_client, self.content, self.timestamp, self.killed = state
self.type = Opcode(typ)
def __repr__(self):
if self.type == Opcode.TEXT:
return "text message: {}".format(repr(self.content))
return repr(self.content.decode(errors="replace"))
else:
return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content))
return repr(self.content)
def kill(self): # pragma: no cover
"""
Kill this message.
It will not be sent to the other endpoint. This has no effect in streaming mode.
"""
warnings.warn(
"WebSocketMessage.kill is deprecated, set an empty content instead.",
DeprecationWarning,
stacklevel=2,
)
# empty str or empty bytes.
self.content = type(self.content)()
def kill(self):
# Likely to be replaced with .drop() in the future, see https://github.com/mitmproxy/mitmproxy/pull/4486
self.killed = True
class WebSocketFlow(flow.Flow):
class WebSocketData(stateobject.StateObject):
"""
A WebSocketFlow is a simplified representation of a WebSocket connection.
A data container for everything related to a single WebSocket connection.
This is typically accessed as `mitmproxy.http.HTTPFlow.websocket`.
"""
def __init__(self, client_conn, server_conn, handshake_flow, live=None):
super().__init__("websocket", client_conn, server_conn, live)
messages: List[WebSocketMessage]
"""All `WebSocketMessage`s transferred over this connection."""
self.messages: List[WebSocketMessage] = []
"""A list containing all WebSocketMessage's."""
self.close_sender = 'client'
"""'client' if the client initiated connection closing."""
self.close_code = CloseReason.NORMAL_CLOSURE
"""WebSocket close code."""
self.close_message = '(message missing)'
"""WebSocket close message."""
self.close_reason = 'unknown status code'
"""WebSocket close reason."""
self.stream = False
"""True of this connection is streaming directly to the other endpoint."""
self.handshake_flow = handshake_flow
"""The HTTP flow containing the initial WebSocket handshake."""
self.ended = False
"""True when the WebSocket connection has been closed."""
closed_by_client: Optional[bool] = None
"""
True if the client closed the connection,
False if the server closed the connection,
None if the connection is active.
"""
close_code: Optional[int] = None
"""[Close Code](https://tools.ietf.org/html/rfc6455#section-7.1.5)"""
close_reason: Optional[str] = None
"""[Close Reason](https://tools.ietf.org/html/rfc6455#section-7.1.6)"""
self._inject_messages_client = queue.Queue(maxsize=1)
self._inject_messages_server = queue.Queue(maxsize=1)
if handshake_flow:
self.client_key = websocket.get_client_key(handshake_flow.request.headers)
self.client_protocol = websocket.get_protocol(handshake_flow.request.headers)
self.client_extensions = websocket.get_extensions(handshake_flow.request.headers)
self.server_accept = websocket.get_server_accept(handshake_flow.response.headers)
self.server_protocol = websocket.get_protocol(handshake_flow.response.headers)
self.server_extensions = websocket.get_extensions(handshake_flow.response.headers)
else:
self.client_key = ''
self.client_protocol = ''
self.client_extensions = ''
self.server_accept = ''
self.server_protocol = ''
self.server_extensions = ''
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
# mypy doesn't support update with kwargs
_stateobject_attributes.update(dict(
_stateobject_attributes = dict(
messages=List[WebSocketMessage],
close_sender=str,
closed_by_client=bool,
close_code=int,
close_message=str,
close_reason=str,
client_key=str,
client_protocol=str,
client_extensions=str,
server_accept=str,
server_protocol=str,
server_extensions=str,
# Do not include handshake_flow, to prevent recursive serialization!
# Since mitmproxy-console currently only displays HTTPFlows,
# dumping the handshake_flow will include the WebSocketFlow too.
))
)
def get_state(self):
d = super().get_state()
d['close_code'] = int(d['close_code']) # replace enum with bare int
return d
def __init__(self):
self.messages = []
def __repr__(self):
return f"<WebSocketData ({len(self.messages)} messages)>"
@classmethod
def from_state(cls, state):
f = cls(None, None, None)
f.set_state(state)
return f
def __repr__(self):
return "<WebSocketFlow ({} messages)>".format(len(self.messages))
def message_info(self, message: WebSocketMessage) -> str:
return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format(
type=message.type,
client=human.format_address(self.client_conn.peername),
server=human.format_address(self.server_conn.address),
direction="->" if message.from_client else "<-",
endpoint=self.handshake_flow.request.path,
)
d = WebSocketData()
d.set_state(state)
return d

View File

@ -110,7 +110,7 @@ async def test_start_stop(tdata):
assert cp.count() == 1
cp.start_replay([tflow.twebsocketflow()])
await tctx.master.await_log("Can only replay HTTP flows.", level="warn")
await tctx.master.await_log("Can't replay WebSocket flows.", level="warn")
assert cp.count() == 1
cp.stop_replay()

View File

@ -233,7 +233,7 @@ def test_websocket():
d.websocket_end(f)
assert "WebSocket connection closed by" in sio.getvalue()
f = tflow.twebsocketflow(client_conn=True, err=True)
f = tflow.twebsocketflow(err=True)
d.websocket_error(f)
assert "Error in WebSocket" in sio_err.getvalue()

View File

@ -55,11 +55,11 @@ def test_websocket(tmpdir):
tctx.configure(sa, save_stream_file=p)
f = tflow.twebsocketflow()
sa.websocket_start(f)
sa.request(f)
sa.websocket_end(f)
f = tflow.twebsocketflow()
sa.websocket_start(f)
sa.request(f)
sa.websocket_error(f)
tctx.configure(sa, save_stream_file=None)

View File

@ -8,7 +8,7 @@ from mitmproxy import exceptions
["dumpfile-011.mitm", "https://example.com/", 1],
["dumpfile-018.mitm", "https://www.example.com/", 1],
["dumpfile-019.mitm", "https://webrv.rtb-seller.com/", 1],
["dumpfile-7-websocket.mitm", "https://echo.websocket.org/", 5],
["dumpfile-7-websocket.mitm", "https://echo.websocket.org/", 6],
])
def test_load(tdata, dumpfile, url, count):
with open(tdata.path("mitmproxy/data/" + dumpfile), "rb") as f:

View File

@ -1,39 +0,0 @@
from mitmproxy.net import websocket
def test_check_handshake():
assert not websocket.check_handshake({
"connection": "upgrade",
"upgrade": "webFOOsocket",
"sec-websocket-key": "foo",
})
assert websocket.check_handshake({
"connection": "upgrade",
"upgrade": "websocket",
"sec-websocket-key": "foo",
})
assert websocket.check_handshake({
"connection": "upgrade",
"upgrade": "websocket",
"sec-websocket-accept": "bar",
})
def test_get_extensions():
assert websocket.get_extensions({}) is None
assert websocket.get_extensions({"sec-websocket-extensions": "foo"}) == "foo"
def test_get_protocol():
assert websocket.get_protocol({}) is None
assert websocket.get_protocol({"sec-websocket-protocol": "foo"}) == "foo"
def test_get_client_key():
assert websocket.get_client_key({}) is None
assert websocket.get_client_key({"sec-websocket-key": "foo"}) == "foo"
def test_get_server_accept():
assert websocket.get_server_accept({}) is None
assert websocket.get_server_accept({"sec-websocket-accept": "foo"}) == "foo"

View File

@ -12,7 +12,6 @@ from mitmproxy.proxy.layers import TCPLayer, http, tls
from mitmproxy.proxy.layers.tcp import TcpStartHook
from mitmproxy.proxy.layers.websocket import WebsocketStartHook
from mitmproxy.tcp import TCPFlow
from mitmproxy.websocket import WebSocketFlow
from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply, reply_next_layer
@ -960,7 +959,7 @@ def test_upgrade(tctx, proto):
b"\r\n")
)
if proto == "websocket":
assert playbook << WebsocketStartHook(Placeholder(WebSocketFlow))
assert playbook << WebsocketStartHook(http_flow)
elif proto == "tcp":
assert playbook << TcpStartHook(Placeholder(TCPFlow))
else:

View File

@ -11,8 +11,9 @@ from mitmproxy.proxy.commands import SendData, CloseConnection, Log
from mitmproxy.connection import ConnectionState
from mitmproxy.proxy.events import DataReceived, ConnectionClosed
from mitmproxy.proxy.layers import http, websocket
from mitmproxy.websocket import WebSocketFlow
from mitmproxy.websocket import WebSocketData
from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply
from wsproto.frame_protocol import Opcode
@dataclass
@ -53,8 +54,7 @@ def test_upgrade(tctx):
"""Test a HTTP -> WebSocket upgrade"""
tctx.server.address = ("example.com", 80)
tctx.server.state = ConnectionState.OPEN
http_flow = Placeholder(HTTPFlow)
flow = Placeholder(WebSocketFlow)
flow = Placeholder(HTTPFlow)
assert (
Playbook(http.HttpLayer(tctx, HTTPMode.transparent))
>> DataReceived(tctx.client,
@ -63,9 +63,9 @@ def test_upgrade(tctx):
b"Upgrade: websocket\r\n"
b"Sec-WebSocket-Version: 13\r\n"
b"\r\n")
<< http.HttpRequestHeadersHook(http_flow)
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(http_flow)
<< http.HttpRequestHook(flow)
>> reply()
<< SendData(tctx.server, b"GET / HTTP/1.1\r\n"
b"Connection: upgrade\r\n"
@ -76,9 +76,9 @@ def test_upgrade(tctx):
b"Upgrade: websocket\r\n"
b"Connection: Upgrade\r\n"
b"\r\n")
<< http.HttpResponseHeadersHook(http_flow)
<< http.HttpResponseHeadersHook(flow)
>> reply()
<< http.HttpResponseHook(http_flow)
<< http.HttpResponseHook(flow)
>> reply()
<< SendData(tctx.client, b"HTTP/1.1 101 Switching Protocols\r\n"
b"Upgrade: websocket\r\n"
@ -95,12 +95,13 @@ def test_upgrade(tctx):
>> reply()
<< SendData(tctx.client, b"\x82\nhello back")
)
assert flow().handshake_flow == http_flow()
assert len(flow().messages) == 2
assert flow().messages[0].content == "hello world"
assert flow().messages[0].from_client
assert flow().messages[1].content == b"hello back"
assert flow().messages[1].from_client is False
assert len(flow().websocket.messages) == 2
assert flow().websocket.messages[0].content == b"hello world"
assert flow().websocket.messages[0].from_client
assert flow().websocket.messages[0].type == Opcode.TEXT
assert flow().websocket.messages[1].content == b"hello back"
assert flow().websocket.messages[1].from_client is False
assert flow().websocket.messages[1].type == Opcode.BINARY
@pytest.fixture()
@ -120,12 +121,12 @@ def ws_testdata(tctx):
"Connection": "upgrade",
"Upgrade": "websocket",
})
return tctx, Playbook(websocket.WebsocketLayer(tctx, flow))
flow.websocket = WebSocketData()
return tctx, Playbook(websocket.WebsocketLayer(tctx, flow)), flow
def test_modify_message(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
tctx, playbook, flow = ws_testdata
assert (
playbook
<< websocket.WebsocketStartHook(flow)
@ -133,7 +134,7 @@ def test_modify_message(ws_testdata):
>> DataReceived(tctx.server, b"\x81\x03foo")
<< websocket.WebsocketMessageHook(flow)
)
flow().messages[-1].content = flow().messages[-1].content.replace("foo", "foobar")
flow.websocket.messages[-1].content = flow.websocket.messages[-1].content.replace(b"foo", b"foobar")
assert (
playbook
>> reply()
@ -142,8 +143,7 @@ def test_modify_message(ws_testdata):
def test_drop_message(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
tctx, playbook, flow = ws_testdata
assert (
playbook
<< websocket.WebsocketStartHook(flow)
@ -151,7 +151,7 @@ def test_drop_message(ws_testdata):
>> DataReceived(tctx.server, b"\x81\x03foo")
<< websocket.WebsocketMessageHook(flow)
)
flow().messages[-1].content = ""
flow.websocket.messages[-1].kill()
assert (
playbook
>> reply()
@ -160,8 +160,7 @@ def test_drop_message(ws_testdata):
def test_fragmented(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
tctx, playbook, flow = ws_testdata
assert (
playbook
<< websocket.WebsocketStartHook(flow)
@ -173,12 +172,11 @@ def test_fragmented(ws_testdata):
<< SendData(tctx.client, b"\x01\x03foo")
<< SendData(tctx.client, b"\x80\x03bar")
)
assert flow().messages[-1].content == "foobar"
assert flow.websocket.messages[-1].content == b"foobar"
def test_protocol_error(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
tctx, playbook, flow = ws_testdata
assert (
playbook
<< websocket.WebsocketStartHook(flow)
@ -193,12 +191,11 @@ def test_protocol_error(ws_testdata):
>> reply()
)
assert not flow().messages
assert not flow.websocket.messages
def test_ping(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
tctx, playbook, flow = ws_testdata
assert (
playbook
<< websocket.WebsocketStartHook(flow)
@ -210,12 +207,11 @@ def test_ping(ws_testdata):
<< Log("Received WebSocket pong from server (payload: b'pong-with-payload')")
<< SendData(tctx.client, b"\x8a\x11pong-with-payload")
)
assert not flow().messages
assert not flow.websocket.messages
def test_close_normal(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
tctx, playbook, flow = ws_testdata
masked_close = Placeholder(bytes)
close = Placeholder(bytes)
assert (
@ -235,12 +231,11 @@ def test_close_normal(ws_testdata):
assert masked_close() == masked(b"\x88\x02\x03\xe8") or masked_close() == masked(b"\x88\x00")
assert close() == b"\x88\x02\x03\xe8" or close() == b"\x88\x00"
assert flow().close_code == 1005
assert flow.websocket.close_code == 1005
def test_close_disconnect(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
tctx, playbook, flow = ws_testdata
assert (
playbook
<< websocket.WebsocketStartHook(flow)
@ -253,12 +248,11 @@ def test_close_disconnect(ws_testdata):
>> reply()
>> ConnectionClosed(tctx.client)
)
assert "ABNORMAL_CLOSURE" in flow().error.msg
assert "ABNORMAL_CLOSURE" in flow.error.msg
def test_close_error(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
tctx, playbook, flow = ws_testdata
assert (
playbook
<< websocket.WebsocketStartHook(flow)
@ -271,15 +265,12 @@ def test_close_error(ws_testdata):
<< websocket.WebsocketErrorHook(flow)
>> reply()
)
assert "UNKNOWN_ERROR=4000" in flow().error.msg
assert "UNKNOWN_ERROR=4000" in flow.error.msg
def test_deflate(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
# noinspection PyUnresolvedReferences
http_flow: HTTPFlow = playbook.layer.flow.handshake_flow
http_flow.response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate; server_max_window_bits=10"
tctx, playbook, flow = ws_testdata
flow.response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate; server_max_window_bits=10"
assert (
playbook
<< websocket.WebsocketStartHook(flow)
@ -290,15 +281,12 @@ def test_deflate(ws_testdata):
>> reply()
<< SendData(tctx.client, bytes.fromhex("c1 07 f2 48 cd c9 c9 07 00"))
)
assert flow().messages[0].content == "Hello"
assert flow.websocket.messages[0].content == b"Hello"
def test_unknown_ext(ws_testdata):
tctx, playbook = ws_testdata
flow = Placeholder(WebSocketFlow)
# noinspection PyUnresolvedReferences
http_flow: HTTPFlow = playbook.layer.flow.handshake_flow
http_flow.response.headers["Sec-WebSocket-Extensions"] = "funky-bits; param=42"
tctx, playbook, flow = ws_testdata
flow.response.headers["Sec-WebSocket-Extensions"] = "funky-bits; param=42"
assert (
playbook
<< Log("Ignoring unknown WebSocket extension 'funky-bits'.")
@ -314,20 +302,20 @@ def test_websocket_connection_repr(tctx):
class TestFragmentizer:
def test_empty(self):
f = websocket.Fragmentizer([b"foo"])
f = websocket.Fragmentizer([b"foo"], False)
assert list(f(b"")) == []
def test_keep_sizes(self):
f = websocket.Fragmentizer([b"foo", b"bar"])
f = websocket.Fragmentizer([b"foo", b"bar"], True)
assert list(f(b"foobaz")) == [
wsproto.events.Message(b"foo", message_finished=False),
wsproto.events.Message(b"baz", message_finished=True),
wsproto.events.TextMessage("foo", message_finished=False),
wsproto.events.TextMessage("baz", message_finished=True),
]
def test_rechunk(self):
f = websocket.Fragmentizer([b"foo"])
f = websocket.Fragmentizer([b"foo"], False)
f.FRAGMENT_SIZE = 4
assert list(f(b"foobar")) == [
wsproto.events.Message(b"foob", message_finished=False),
wsproto.events.Message(b"ar", message_finished=True),
wsproto.events.BytesMessage(b"foob", message_finished=False),
wsproto.events.BytesMessage(b"ar", message_finished=True),
]

View File

@ -27,14 +27,20 @@ def test_http_flow(resp, err):
def test_websocket_flow(err):
f = tflow.twebsocketflow(err=err)
i = eventsequence.iterate(f)
assert isinstance(next(i), layers.http.HttpRequestHeadersHook)
assert isinstance(next(i), layers.http.HttpRequestHook)
assert isinstance(next(i), layers.http.HttpResponseHeadersHook)
assert isinstance(next(i), layers.http.HttpResponseHook)
assert isinstance(next(i), layers.websocket.WebsocketStartHook)
assert len(f.messages) == 0
assert len(f.websocket.messages) == 0
assert isinstance(next(i), layers.websocket.WebsocketMessageHook)
assert len(f.messages) == 1
assert len(f.websocket.messages) == 1
assert isinstance(next(i), layers.websocket.WebsocketMessageHook)
assert len(f.messages) == 2
assert len(f.websocket.messages) == 2
assert isinstance(next(i), layers.websocket.WebsocketMessageHook)
assert len(f.messages) == 3
assert len(f.websocket.messages) == 3
if err:
assert isinstance(next(i), layers.websocket.WebsocketErrorHook)
else:

View File

@ -122,20 +122,6 @@ class TestFlowMaster:
await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
@pytest.mark.asyncio
async def test_load_websocket_flow(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
s = State()
with taddons.context(s, options=opts) as ctx:
f = tflow.twebsocketflow()
await ctx.master.load_flow(f.handshake_flow)
await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages)
@pytest.mark.asyncio
async def test_all(self):
opts = options.Options(

View File

@ -4,7 +4,7 @@ from unittest.mock import patch
from mitmproxy.test import tflow
from mitmproxy import flowfilter
from mitmproxy import flowfilter, http
class TestParsing:
@ -424,10 +424,10 @@ class TestMatchingTCPFlow:
class TestMatchingWebSocketFlow:
def flow(self):
def flow(self) -> http.HTTPFlow:
return tflow.twebsocketflow()
def err(self):
def err(self) -> http.HTTPFlow:
return tflow.twebsocketflow(err=True)
def q(self, q, o):
@ -437,10 +437,10 @@ class TestMatchingWebSocketFlow:
f = self.flow()
assert self.q("~websocket", f)
assert not self.q("~tcp", f)
assert not self.q("~http", f)
assert self.q("~http", f)
def test_handshake(self):
f = self.flow().handshake_flow
f = self.flow()
assert self.q("~websocket", f)
assert not self.q("~tcp", f)
assert self.q("~http", f)
@ -465,9 +465,6 @@ class TestMatchingWebSocketFlow:
assert self.q("~u example.com/ws", q)
assert not self.q("~u moo/path", q)
q.handshake_flow = None
assert not self.q("~u example.com", q)
def test_body(self):
f = self.flow()

View File

@ -1,88 +1,28 @@
import io
import pytest
from mitmproxy import flowfilter
from mitmproxy.exceptions import ControlException
from mitmproxy.io import tnetstring
from mitmproxy import http
from mitmproxy import websocket
from mitmproxy.test import tflow
from wsproto.frame_protocol import Opcode
class TestWebSocketFlow:
def test_copy(self):
f = tflow.twebsocketflow()
f.get_state()
f2 = f.copy()
a = f.get_state()
b = f2.get_state()
del a["id"]
del b["id"]
assert a == b
assert not f == f2
assert f is not f2
assert f.client_key == f2.client_key
assert f.client_protocol == f2.client_protocol
assert f.client_extensions == f2.client_extensions
assert f.server_accept == f2.server_accept
assert f.server_protocol == f2.server_protocol
assert f.server_extensions == f2.server_extensions
assert f.messages is not f2.messages
assert f.handshake_flow is not f2.handshake_flow
for m in f.messages:
m2 = m.copy()
m2.set_state(m2.get_state())
assert m is not m2
assert m.get_state() == m2.get_state()
f = tflow.twebsocketflow(err=True)
f2 = f.copy()
assert f is not f2
assert f.handshake_flow is not f2.handshake_flow
assert f.error.get_state() == f2.error.get_state()
assert f.error is not f2.error
def test_kill(self):
f = tflow.twebsocketflow()
with pytest.raises(ControlException):
f.intercept()
f.resume()
f.kill()
f = tflow.twebsocketflow()
f.intercept()
assert f.killable
f.kill()
assert not f.killable
def test_match(self):
f = tflow.twebsocketflow()
assert not flowfilter.match("~b nonexistent", f)
assert flowfilter.match(None, f)
assert not flowfilter.match("~b nonexistent", f)
f = tflow.twebsocketflow(err=True)
assert flowfilter.match("~e", f)
with pytest.raises(ValueError):
flowfilter.match("~", f)
class TestWebSocketData:
def test_repr(self):
assert repr(tflow.twebsocketflow().websocket) == "<WebSocketData (3 messages)>"
def test_state(self):
f = tflow.twebsocketflow()
assert f.message_info(f.messages[0])
assert 'WebSocketFlow' in repr(f)
assert 'binary message: ' in repr(f.messages[0])
assert 'text message: ' in repr(f.messages[1])
f2 = http.HTTPFlow.from_state(f.get_state())
f2.set_state(f.get_state())
def test_serialize(self):
b = io.BytesIO()
d = tflow.twebsocketflow().get_state()
tnetstring.dump(d, b)
assert b.getvalue()
b = io.BytesIO()
d = tflow.twebsocketflow().handshake_flow.get_state()
tnetstring.dump(d, b)
assert b.getvalue()
class TestWebSocketMessage:
def test_basic(self):
m = websocket.WebSocketMessage(Opcode.TEXT, True, b"foo")
m.set_state(m.get_state())
assert m.content == b"foo"
assert repr(m) == "'foo'"
m.type = Opcode.BINARY
assert repr(m) == "b'foo'"
assert not m.killed
m.kill()
assert m.killed