diff --git a/examples/addons/websocket-simple.py b/examples/addons/websocket-simple.py index 836b12384..c4ec4d22f 100644 --- a/examples/addons/websocket-simple.py +++ b/examples/addons/websocket-simple.py @@ -5,17 +5,17 @@ from mitmproxy import ctx def websocket_message(flow): # get the latest message - message = flow.messages[-1] + message = flow.websocket.messages[-1] # was the message sent from the client or server? if message.from_client: - ctx.log.info(f"Client sent a message: {message.content}") + ctx.log.info(f"Client sent a message: {message.content!r}") else: - ctx.log.info(f"Server sent a message: {message.content}") + ctx.log.info(f"Server sent a message: {message.content!r}") # manipulate the message content - message.content = re.sub(r'^Hello', 'HAPPY', message.content) + message.content = re.sub(rb'^Hello', b'HAPPY', message.content) - if 'FOOBAR' in message.content: + if b'FOOBAR' in message.content: # kill the message and not send it to the other endpoint message.content = "" diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py index ac56fc8a1..e7d6603b2 100644 --- a/mitmproxy/addons/clientplayback.py +++ b/mitmproxy/addons/clientplayback.py @@ -145,6 +145,8 @@ class ClientPlayback: return "Can't replay flow with missing request." if f.request.raw_content is None: return "Can't replay flow with missing content." + if f.websocket is not None: + return "Can't replay WebSocket flows." else: return "Can only replay HTTP flows." 
return None diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py index 3b05c0945..7ebb128b5 100644 --- a/mitmproxy/addons/dumper.py +++ b/mitmproxy/addons/dumper.py @@ -13,7 +13,7 @@ from mitmproxy import http from mitmproxy.tcp import TCPFlow, TCPMessage from mitmproxy.utils import human from mitmproxy.utils import strutils -from mitmproxy.websocket import WebSocketFlow, WebSocketMessage +from mitmproxy.websocket import WebSocketMessage def indent(n: int, text: str) -> str: @@ -98,7 +98,7 @@ class Dumper: def _echo_message( self, message: Union[http.Message, TCPMessage, WebSocketMessage], - flow: Union[http.HTTPFlow, TCPFlow, WebSocketFlow] + flow: Union[http.HTTPFlow, TCPFlow] ): _, lines, error = contentviews.get_message_content_view( ctx.options.dumper_default_contentview, @@ -277,37 +277,36 @@ class Dumper: if self.match(f): self.echo_flow(f) - def websocket_error(self, f): + def websocket_error(self, f: http.HTTPFlow): self.echo_error( - "Error in WebSocket connection to {}: {}".format( - human.format_address(f.server_conn.address), f.error - ), + f"Error in WebSocket connection to {human.format_address(f.server_conn.address)}: {f.error}", fg="red" ) - def websocket_message(self, f): + def websocket_message(self, f: http.HTTPFlow): + assert f.websocket is not None # satisfy type checker if self.match(f): - message = f.messages[-1] - self.echo(f.message_info(message)) + message = f.websocket.messages[-1] + + direction = "->" if message.from_client else "<-" + self.echo( + f"{human.format_address(f.client_conn.peername)} " + f"{direction} WebSocket {message.type.name.lower()} message " + f"{direction} {human.format_address(f.server_conn.address)}{f.request.path}" + ) if ctx.options.flow_detail >= 3: - message = message.from_state(message.get_state()) - message.content = message.content.encode() if isinstance(message.content, str) else message.content self._echo_message(message, f) - def websocket_end(self, f): + def websocket_end(self, f: 
http.HTTPFlow): + assert f.websocket is not None # satisfy type checker if self.match(f): - self.echo("WebSocket connection closed by {}: {} {}, {}".format( - f.close_sender, - f.close_code, - f.close_message, - f.close_reason)) + c = 'client' if f.websocket.closed_by_client else 'server' + self.echo(f"WebSocket connection closed by {c}: {f.websocket.close_code} {f.websocket.close_reason}") def tcp_error(self, f): if self.match(f): self.echo_error( - "Error in TCP connection to {}: {}".format( - human.format_address(f.server_conn.address), f.error - ), + f"Error in TCP connection to {human.format_address(f.server_conn.address)}: {f.error}", fg="red" ) diff --git a/mitmproxy/addons/save.py b/mitmproxy/addons/save.py index 60118aff6..49468f1ea 100644 --- a/mitmproxy/addons/save.py +++ b/mitmproxy/addons/save.py @@ -7,6 +7,7 @@ from mitmproxy import flowfilter from mitmproxy import io from mitmproxy import ctx from mitmproxy import flow +from mitmproxy import http import mitmproxy.types @@ -88,28 +89,26 @@ class Save: def tcp_error(self, flow): self.tcp_end(flow) - def websocket_start(self, flow): - if self.stream: - self.active_flows.add(flow) - - def websocket_end(self, flow): + def websocket_end(self, flow: http.HTTPFlow): if self.stream: self.stream.add(flow) self.active_flows.discard(flow) - def websocket_error(self, flow): + def websocket_error(self, flow: http.HTTPFlow): self.websocket_end(flow) - def request(self, flow): + def request(self, flow: http.HTTPFlow): if self.stream: self.active_flows.add(flow) - def response(self, flow): - if self.stream: + def response(self, flow: http.HTTPFlow): + # websocket flows will receive either websocket_end or websocket_error, + # we don't want to persist them here already + if self.stream and flow.websocket is None: self.stream.add(flow) self.active_flows.discard(flow) - def error(self, flow): + def error(self, flow: http.HTTPFlow): self.response(flow) def done(self): diff --git a/mitmproxy/contentviews/__init__.py 
b/mitmproxy/contentviews/__init__.py index 165dc7bf2..564300978 100644 --- a/mitmproxy/contentviews/__init__.py +++ b/mitmproxy/contentviews/__init__.py @@ -25,7 +25,7 @@ from . import ( from .base import View, KEY_MAX, format_text, format_dict, TViewResult from ..http import HTTPFlow from ..tcp import TCPMessage, TCPFlow -from ..websocket import WebSocketMessage, WebSocketFlow +from ..websocket import WebSocketMessage views: List[View] = [] @@ -67,7 +67,7 @@ def safe_to_print(lines, encoding="utf8"): def get_message_content_view( viewname: str, message: Union[http.Message, TCPMessage, WebSocketMessage], - flow: Union[HTTPFlow, TCPFlow, WebSocketFlow], + flow: Union[HTTPFlow, TCPFlow], ): """ Like get_content_view, but also handles message encoding. @@ -79,7 +79,7 @@ def get_message_content_view( content: Optional[bytes] try: - content = message.content # type: ignore + content = message.content except ValueError: assert isinstance(message, http.Message) content = message.raw_content diff --git a/mitmproxy/eventsequence.py b/mitmproxy/eventsequence.py index 9c2b6e3db..0efb916b9 100644 --- a/mitmproxy/eventsequence.py +++ b/mitmproxy/eventsequence.py @@ -1,11 +1,10 @@ -from typing import Iterator, Any, Dict, Type, Callable +from typing import Any, Callable, Dict, Iterator, Type from mitmproxy import controller -from mitmproxy import hooks from mitmproxy import flow +from mitmproxy import hooks from mitmproxy import http from mitmproxy import tcp -from mitmproxy import websocket from mitmproxy.proxy import layers TEventGenerator = Iterator[hooks.Hook] @@ -18,24 +17,21 @@ def _iterate_http(f: http.HTTPFlow) -> TEventGenerator: if f.response: yield layers.http.HttpResponseHeadersHook(f) yield layers.http.HttpResponseHook(f) - if f.error: + if f.websocket: + message_queue = f.websocket.messages + f.websocket.messages = [] + yield layers.websocket.WebsocketStartHook(f) + for m in message_queue: + f.websocket.messages.append(m) + yield 
layers.websocket.WebsocketMessageHook(f) + if f.error: + yield layers.websocket.WebsocketErrorHook(f) + else: + yield layers.websocket.WebsocketEndHook(f) + elif f.error: yield layers.http.HttpErrorHook(f) -def _iterate_websocket(f: websocket.WebSocketFlow) -> TEventGenerator: - messages = f.messages - f.messages = [] - f.reply = controller.DummyReply() - yield layers.websocket.WebsocketStartHook(f) - while messages: - f.messages.append(messages.pop(0)) - yield layers.websocket.WebsocketMessageHook(f) - if f.error: - yield layers.websocket.WebsocketErrorHook(f) - else: - yield layers.websocket.WebsocketEndHook(f) - - def _iterate_tcp(f: tcp.TCPFlow) -> TEventGenerator: messages = f.messages f.messages = [] @@ -52,7 +48,6 @@ def _iterate_tcp(f: tcp.TCPFlow) -> TEventGenerator: _iterate_map: Dict[Type[flow.Flow], Callable[[Any], TEventGenerator]] = { http.HTTPFlow: _iterate_http, - websocket.WebSocketFlow: _iterate_websocket, tcp.TCPFlow: _iterate_tcp, } diff --git a/mitmproxy/flowfilter.py b/mitmproxy/flowfilter.py index b1448f285..384670ddf 100644 --- a/mitmproxy/flowfilter.py +++ b/mitmproxy/flowfilter.py @@ -39,8 +39,7 @@ from typing import Callable, ClassVar, Optional, Sequence, Type import pyparsing as pp -from mitmproxy import flow, http, tcp, websocket -from mitmproxy.net.websocket import check_handshake +from mitmproxy import flow, http, tcp def only(*types): @@ -102,15 +101,11 @@ class FHTTP(_Action): class FWebSocket(_Action): code = "websocket" - help = "Match WebSocket flows (and HTTP-WebSocket handshake flows)" + help = "Match WebSocket flows" - @only(http.HTTPFlow, websocket.WebSocketFlow) - def __call__(self, f): - m = ( - (isinstance(f, http.HTTPFlow) and f.request and check_handshake(f.request.headers)) - or isinstance(f, websocket.WebSocketFlow) - ) - return m + @only(http.HTTPFlow) + def __call__(self, f: http.HTTPFlow): + return f.websocket is not None class FTCP(_Action): @@ -258,7 +253,7 @@ class FBod(_Rex): help = "Body" flags = re.DOTALL - 
@only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow) + @only(http.HTTPFlow, tcp.TCPFlow) def __call__(self, f): if isinstance(f, http.HTTPFlow): if f.request and f.request.raw_content: @@ -267,7 +262,11 @@ class FBod(_Rex): if f.response and f.response.raw_content: if self.re.search(f.response.get_content(strict=False)): return True - elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow): + if f.websocket: + for msg in f.websocket.messages: + if self.re.search(msg.content): + return True + elif isinstance(f, tcp.TCPFlow): for msg in f.messages: if self.re.search(msg.content): return True @@ -279,13 +278,17 @@ class FBodRequest(_Rex): help = "Request body" flags = re.DOTALL - @only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow) + @only(http.HTTPFlow, tcp.TCPFlow) def __call__(self, f): if isinstance(f, http.HTTPFlow): if f.request and f.request.raw_content: if self.re.search(f.request.get_content(strict=False)): return True - elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow): + if f.websocket: + for msg in f.websocket.messages: + if msg.from_client and self.re.search(msg.content): + return True + elif isinstance(f, tcp.TCPFlow): for msg in f.messages: if msg.from_client and self.re.search(msg.content): return True @@ -296,13 +299,17 @@ class FBodResponse(_Rex): help = "Response body" flags = re.DOTALL - @only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow) + @only(http.HTTPFlow, tcp.TCPFlow) def __call__(self, f): if isinstance(f, http.HTTPFlow): if f.response and f.response.raw_content: if self.re.search(f.response.get_content(strict=False)): return True - elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow): + if f.websocket: + for msg in f.websocket.messages: + if not msg.from_client and self.re.search(msg.content): + return True + elif isinstance(f, tcp.TCPFlow): for msg in f.messages: if not msg.from_client and self.re.search(msg.content): return True @@ -324,10 +331,8 @@ class 
FDomain(_Rex): flags = re.IGNORECASE is_binary = False - @only(http.HTTPFlow, websocket.WebSocketFlow) + @only(http.HTTPFlow) def __call__(self, f): - if isinstance(f, websocket.WebSocketFlow): - f = f.handshake_flow return bool( self.re.search(f.request.host) or self.re.search(f.request.pretty_host) @@ -347,10 +352,8 @@ class FUrl(_Rex): toks = toks[1:] return klass(*toks) - @only(http.HTTPFlow, websocket.WebSocketFlow) + @only(http.HTTPFlow) def __call__(self, f): - if isinstance(f, websocket.WebSocketFlow): - f = f.handshake_flow if not f or not f.request: return False return self.re.search(f.request.pretty_url) @@ -482,9 +485,9 @@ def _make(): unicode_words = pp.CharsNotIn("()~'\"" + pp.ParserElement.DEFAULT_WHITE_CHARS) unicode_words.skipWhitespace = True regex = ( - unicode_words - | pp.QuotedString('"', escChar='\\') - | pp.QuotedString("'", escChar='\\') + unicode_words + | pp.QuotedString('"', escChar='\\') + | pp.QuotedString("'", escChar='\\') ) for cls in filter_rex: f = pp.Literal(f"~{cls.code}") + pp.WordEnd() + regex.copy() diff --git a/mitmproxy/http.py b/mitmproxy/http.py index 1893a9ef5..b3060bd31 100644 --- a/mitmproxy/http.py +++ b/mitmproxy/http.py @@ -18,6 +18,7 @@ from typing import Union from typing import cast from mitmproxy import flow +from mitmproxy.websocket import WebSocketData from mitmproxy.coretypes import multidict from mitmproxy.coretypes import serializable from mitmproxy.net import encoding @@ -1169,6 +1170,11 @@ class HTTPFlow(flow.Flow): from the server, but there was an error sending it back to the client. """ + websocket: Optional[WebSocketData] = None + """ + If this HTTP flow initiated a WebSocket connection, this attribute contains all associated WebSocket data. 
+ """ + def __init__(self, client_conn, server_conn, live=None, mode="regular"): super().__init__("http", client_conn, server_conn, live) self.mode = mode @@ -1178,12 +1184,13 @@ class HTTPFlow(flow.Flow): _stateobject_attributes.update(dict( request=Request, response=Response, + websocket=WebSocketData, mode=str )) def __repr__(self): s = " Any: if isinstance(o, dict): return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} @@ -308,6 +357,7 @@ converters = { 8: convert_8_9, 9: convert_9_10, 10: convert_10_11, + 11: convert_11_12, } @@ -325,8 +375,8 @@ def migrate_flow(flow_data: Dict[Union[bytes, str], Any]) -> Dict[Union[bytes, s flow_data = converters[flow_version](flow_data) else: should_upgrade = ( - isinstance(flow_version, int) - and flow_version > version.FLOW_FORMAT_VERSION + isinstance(flow_version, int) + and flow_version > version.FLOW_FORMAT_VERSION ) raise ValueError( "{} cannot read files with flow format version {}{}.".format( diff --git a/mitmproxy/io/io.py b/mitmproxy/io/io.py index e65c7ebbf..e9251c717 100644 --- a/mitmproxy/io/io.py +++ b/mitmproxy/io/io.py @@ -1,19 +1,16 @@ import os -from typing import Type, Iterable, Dict, Union, Any, cast # noqa +from typing import Any, Dict, Iterable, Type, Union, cast # noqa from mitmproxy import exceptions from mitmproxy import flow from mitmproxy import flowfilter from mitmproxy import http from mitmproxy import tcp -from mitmproxy import websocket - from mitmproxy.io import compat from mitmproxy.io import tnetstring FLOW_TYPES: Dict[str, Type[flow.Flow]] = dict( http=http.HTTPFlow, - websocket=websocket.WebSocketFlow, tcp=tcp.TCPFlow, ) diff --git a/mitmproxy/master.py b/mitmproxy/master.py index 4f4910b16..888406ae7 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -11,7 +11,6 @@ from mitmproxy import eventsequence from mitmproxy import http from mitmproxy import log from mitmproxy import options -from mitmproxy import websocket from mitmproxy.net import server_spec 
from . import ctx as mitmproxy_ctx @@ -34,7 +33,6 @@ class Master: self.commands = command.CommandManager(self) self.addons = addonmanager.AddonManager(self) self._server = None - self.waiting_flows = [] self.log = log.Log(self) mitmproxy_ctx.master = self @@ -111,24 +109,11 @@ class Master: async def load_flow(self, f): """ - Loads a flow and links websocket & handshake flows + Loads a flow """ if isinstance(f, http.HTTPFlow): self._change_reverse_host(f) - if 'websocket' in f.metadata: - self.waiting_flows.append(f) - - if isinstance(f, websocket.WebSocketFlow): - hfs = [hf for hf in self.waiting_flows if hf.id == f.metadata['websocket_handshake']] - if hfs: - hf = hfs[0] - f.handshake_flow = hf - self.waiting_flows.remove(hf) - self._change_reverse_host(f.handshake_flow) - else: - # this will fail - but at least it will load the remaining flows - f.handshake_flow = http.HTTPFlow(None, None) f.reply = controller.DummyReply() for e in eventsequence.iterate(f): diff --git a/mitmproxy/net/websocket.py b/mitmproxy/net/websocket.py deleted file mode 100644 index 4758db0ca..000000000 --- a/mitmproxy/net/websocket.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -Collection of WebSocket protocol utility functions (RFC6455) -Spec: https://tools.ietf.org/html/rfc6455 -""" - - -def check_handshake(headers): - return ( - "upgrade" in headers.get("connection", "").lower() and - headers.get("upgrade", "").lower() == "websocket" and - (headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None) - ) - - -def get_extensions(headers): - return headers.get("sec-websocket-extensions", None) - - -def get_protocol(headers): - return headers.get("sec-websocket-protocol", None) - - -def get_client_key(headers): - return headers.get("sec-websocket-key", None) - - -def get_server_accept(headers): - return headers.get("sec-websocket-accept", None) diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index 
e81d0e1e2..1671f15d5 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -4,6 +4,8 @@ import time from dataclasses import dataclass from typing import DefaultDict, Dict, List, Optional, Tuple, Union +import wsproto.handshake + from mitmproxy import flow, http from mitmproxy.connection import Connection, Server from mitmproxy.net import server_spec @@ -13,6 +15,7 @@ from mitmproxy.proxy.layers import tcp, tls, websocket from mitmproxy.proxy.layers.http import _upstream_proxy from mitmproxy.proxy.utils import expect from mitmproxy.utils import human +from mitmproxy.websocket import WebSocketData from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \ ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError @@ -308,6 +311,21 @@ class HttpStream(layer.Layer): """We have either consumed the entire response from the server or the response was set by an addon.""" assert self.flow.response self.flow.response.timestamp_end = time.time() + + is_websocket = ( + self.flow.response.status_code == 101 + and + self.flow.response.headers.get("upgrade", "").lower() == "websocket" + and + self.flow.request.headers.get("Sec-WebSocket-Version", "").encode() == wsproto.handshake.WEBSOCKET_VERSION + and + self.context.options.websocket + ) + if is_websocket: + # We need to set this before calling the response hook + # so that addons can determine if a WebSocket connection is following up. 
+ self.flow.websocket = WebSocketData() + yield HttpResponseHook(self.flow) self.server_state = self.state_done if (yield from self.check_killed(False)): @@ -322,12 +340,7 @@ class HttpStream(layer.Layer): yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client) if self.flow.response.status_code == 101: - is_websocket = ( - self.flow.response.headers.get("upgrade", "").lower() == "websocket" - and - self.flow.request.headers.get("Sec-WebSocket-Version", "") == "13" - ) - if is_websocket and self.context.options.websocket: + if is_websocket: self.child_layer = websocket.WebsocketLayer(self.context, self.flow) elif self.context.options.rawtcp: self.child_layer = tcp.TCPLayer(self.context) diff --git a/mitmproxy/proxy/layers/websocket.py b/mitmproxy/proxy/layers/websocket.py index a09a76aef..f83ef14f0 100644 --- a/mitmproxy/proxy/layers/websocket.py +++ b/mitmproxy/proxy/layers/websocket.py @@ -1,11 +1,11 @@ from dataclasses import dataclass -from typing import Union, List, Iterator +from typing import Iterator, List import wsproto import wsproto.extensions import wsproto.frame_protocol import wsproto.utilities -from mitmproxy import flow, websocket, http, connection +from mitmproxy import connection, flow, http, websocket from mitmproxy.proxy import commands, events, layer from mitmproxy.proxy.commands import StartHook from mitmproxy.proxy.context import Context @@ -19,7 +19,7 @@ class WebsocketStartHook(StartHook): """ A WebSocket connection has commenced. """ - flow: websocket.WebSocketFlow + flow: http.HTTPFlow @dataclass @@ -30,7 +30,7 @@ class WebsocketMessageHook(StartHook): message is user-modifiable. Currently there are two types of messages, corresponding to the BINARY and TEXT frame types. """ - flow: websocket.WebSocketFlow + flow: http.HTTPFlow @dataclass @@ -39,7 +39,7 @@ class WebsocketEndHook(StartHook): A WebSocket connection has ended. 
""" - flow: websocket.WebSocketFlow + flow: http.HTTPFlow @dataclass @@ -49,7 +49,7 @@ class WebsocketErrorHook(StartHook): Every WebSocket flow will receive either a websocket_error or a websocket_end event, but not both. """ - flow: websocket.WebSocketFlow + flow: http.HTTPFlow class WebsocketConnection(wsproto.Connection): @@ -61,7 +61,7 @@ class WebsocketConnection(wsproto.Connection): - we wrap .send() so that we can directly yield it. """ conn: connection.Connection - frame_buf: List[Union[str, bytes]] + frame_buf: List[bytes] def __init__(self, *args, conn: connection.Connection, **kwargs): super(WebsocketConnection, self).__init__(*args, **kwargs) @@ -80,13 +80,13 @@ class WebsocketLayer(layer.Layer): """ WebSocket layer that intercepts and relays messages. """ - flow: websocket.WebSocketFlow + flow: http.HTTPFlow client_ws: WebsocketConnection server_ws: WebsocketConnection - def __init__(self, context: Context, handshake_flow: http.HTTPFlow): + def __init__(self, context: Context, flow: http.HTTPFlow): super().__init__(context) - self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow) + self.flow = flow assert context.server.connected @expect(events.Start) @@ -96,7 +96,8 @@ class WebsocketLayer(layer.Layer): server_extensions = [] # Parse extension headers. We only support deflate at the moment and ignore everything else. 
- ext_header = self.flow.handshake_flow.response.headers.get("Sec-WebSocket-Extensions", "") + assert self.flow.response # satisfy type checker + ext_header = self.flow.response.headers.get("Sec-WebSocket-Extensions", "") if ext_header: for ext in wsproto.utilities.split_comma_header(ext_header.encode("ascii", "replace")): ext_name = ext.split(";", 1)[0].strip() @@ -115,15 +116,14 @@ class WebsocketLayer(layer.Layer): yield WebsocketStartHook(self.flow) - if self.flow.stream: # pragma: no cover - raise NotImplementedError("WebSocket streaming is not supported at the moment.") - self._handle_event = self.relay_messages _handle_event = start @expect(events.DataReceived, events.ConnectionClosed) def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]: + assert self.flow.websocket # satisfy type checker + from_client = event.connection == self.context.client from_str = 'client' if from_client else 'server' if from_client: @@ -142,27 +142,27 @@ class WebsocketLayer(layer.Layer): for ws_event in src_ws.events(): if isinstance(ws_event, wsproto.events.Message): - src_ws.frame_buf.append(ws_event.data) + is_text = isinstance(ws_event.data, str) + if is_text: + typ = Opcode.TEXT + src_ws.frame_buf.append(ws_event.data.encode()) + else: + typ = Opcode.BINARY + src_ws.frame_buf.append(ws_event.data) if ws_event.message_finished: - if isinstance(ws_event, wsproto.events.TextMessage): - frame_type = Opcode.TEXT - content = "".join(src_ws.frame_buf) # type: ignore - else: - frame_type = Opcode.BINARY - content = b"".join(src_ws.frame_buf) # type: ignore + content = b"".join(src_ws.frame_buf) - fragmentizer = Fragmentizer(src_ws.frame_buf) + fragmentizer = Fragmentizer(src_ws.frame_buf, is_text) src_ws.frame_buf.clear() - message = websocket.WebSocketMessage(frame_type, from_client, content) - self.flow.messages.append(message) + message = websocket.WebSocketMessage(typ, from_client, content) + self.flow.websocket.messages.append(message) yield 
WebsocketMessageHook(self.flow) - assert not message.killed # this is deprecated, instead we should have .content set to emptystr. - - for msg in fragmentizer(message.content): - yield dst_ws.send2(msg) + if not message.killed: + for msg in fragmentizer(message.content): + yield dst_ws.send2(msg) elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)): yield commands.Log( @@ -171,9 +171,9 @@ class WebsocketLayer(layer.Layer): ) yield dst_ws.send2(ws_event) elif isinstance(ws_event, wsproto.events.CloseConnection): - self.flow.close_sender = from_str - self.flow.close_code = ws_event.code - self.flow.close_reason = ws_event.reason + self.flow.websocket.closed_by_client = from_client + self.flow.websocket.close_code = ws_event.code + self.flow.websocket.close_reason = ws_event.reason for ws in [self.server_ws, self.client_ws]: if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}: @@ -215,27 +215,35 @@ class Fragmentizer: As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks. """ - # A bit less than 4kb to accomodate for headers. + # A bit less than 4kb to accommodate for headers. 
FRAGMENT_SIZE = 4000 - def __init__(self, fragments: List[Union[str, bytes]]): + def __init__(self, fragments: List[bytes], is_text: bool): assert fragments self.fragment_lengths = [len(x) for x in fragments] + self.is_text = is_text - def __call__(self, content: Union[str, bytes]) -> Iterator[wsproto.events.Message]: + def msg(self, data: bytes, message_finished: bool): + if self.is_text: + data_str = data.decode(errors="replace") + return wsproto.events.TextMessage(data_str, message_finished=message_finished) + else: + return wsproto.events.BytesMessage(data, message_finished=message_finished) + + def __call__(self, content: bytes) -> Iterator[wsproto.events.Message]: if not content: return if len(content) == sum(self.fragment_lengths): # message has the same length, we can reuse the same sizes offset = 0 for fl in self.fragment_lengths[:-1]: - yield wsproto.events.Message(content[offset:offset + fl], message_finished=False) + yield self.msg(content[offset:offset + fl], False) offset += fl - yield wsproto.events.Message(content[offset:], message_finished=True) + yield self.msg(content[offset:], True) else: offset = 0 total = len(content) - self.FRAGMENT_SIZE while offset < total: - yield wsproto.events.Message(content[offset:offset + self.FRAGMENT_SIZE], message_finished=False) + yield self.msg(content[offset:offset + self.FRAGMENT_SIZE], False) offset += self.FRAGMENT_SIZE - yield wsproto.events.Message(content[offset:], message_finished=True) + yield self.msg(content[offset:], True) diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index 5e745b6cb..9f76d34eb 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -6,7 +6,6 @@ from mitmproxy import flow from mitmproxy import http from mitmproxy import tcp from mitmproxy import websocket -from mitmproxy.net.http import status_codes from mitmproxy.test import tutils from wsproto.frame_protocol import Opcode @@ -31,68 +30,55 @@ def ttcpflow(client_conn=True, server_conn=True, 
messages=True, err=None): return f -def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, handshake_flow=True): - if client_conn is True: - client_conn = tclient_conn() - if server_conn is True: - server_conn = tserver_conn() - if handshake_flow is True: - req = http.Request( - "example.com", - 80, - b"GET", - b"http", - b"example.com", - b"/ws", - b"HTTP/1.1", - headers=http.Headers( - connection="upgrade", - upgrade="websocket", - sec_websocket_version="13", - sec_websocket_key="1234", - ), - content=b'', - trailers=None, - timestamp_start=946681200, - timestamp_end=946681201, +def twebsocketflow(messages=True, err=None) -> http.HTTPFlow: + flow = http.HTTPFlow(tclient_conn(), tserver_conn()) + flow.request = http.Request( + "example.com", + 80, + b"GET", + b"http", + b"example.com", + b"/ws", + b"HTTP/1.1", + headers=http.Headers( + connection="upgrade", + upgrade="websocket", + sec_websocket_version="13", + sec_websocket_key="1234", + ), + content=b'', + trailers=None, + timestamp_start=946681200, + timestamp_end=946681201, - ) - resp = http.Response( - b"HTTP/1.1", - 101, - reason=status_codes.RESPONSES.get(101), - headers=http.Headers( - connection='upgrade', - upgrade='websocket', - sec_websocket_accept=b'', - ), - content=b'', - trailers=None, - timestamp_start=946681202, - timestamp_end=946681203, - ) - handshake_flow = http.HTTPFlow(client_conn, server_conn) - handshake_flow.request = req - handshake_flow.response = resp - - f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow) - f.metadata['websocket_handshake'] = handshake_flow.id - handshake_flow.metadata['websocket_flow'] = f.id - handshake_flow.metadata['websocket'] = True + ) + flow.response = http.Response( + b"HTTP/1.1", + 101, + reason=b"Switching Protocols", + headers=http.Headers( + connection='upgrade', + upgrade='websocket', + sec_websocket_accept=b'', + ), + content=b'', + trailers=None, + timestamp_start=946681202, + timestamp_end=946681203, + ) + 
flow.websocket = websocket.WebSocketData() if messages is True: - messages = [ - websocket.WebSocketMessage(Opcode.BINARY, True, b"hello binary"), - websocket.WebSocketMessage(Opcode.TEXT, True, b"hello text"), - websocket.WebSocketMessage(Opcode.TEXT, False, b"it's me"), + flow.websocket.messages = [ + websocket.WebSocketMessage(Opcode.BINARY, True, b"hello binary", 946681203), + websocket.WebSocketMessage(Opcode.TEXT, True, b"hello text", 946681204), + websocket.WebSocketMessage(Opcode.TEXT, False, b"it's me", 946681205), ] if err is True: - err = terr() + flow.error = terr() - f.messages = messages - f.error = err - f.reply = controller.DummyReply() - return f + flow.reply = controller.DummyReply() + return flow def tflow(client_conn=True, server_conn=True, req=True, resp=None, err=None): diff --git a/mitmproxy/tools/console/common.py b/mitmproxy/tools/console/common.py index fe31a6375..edb4a0d1e 100644 --- a/mitmproxy/tools/console/common.py +++ b/mitmproxy/tools/console/common.py @@ -119,6 +119,8 @@ else: SCHEME_STYLES = { 'http': 'scheme_http', 'https': 'scheme_https', + 'ws': 'scheme_ws', + 'wss': 'scheme_wss', 'tcp': 'scheme_tcp', } HTTP_REQUEST_METHOD_STYLES = { @@ -297,12 +299,8 @@ def colorize_url(url): parts = url.split('/', 3) if len(parts) < 4 or len(parts[1]) > 0 or parts[0][-1:] != ':': return [('error', len(url))] # bad URL - schemes = { - 'http:': 'scheme_http', - 'https:': 'scheme_https', - } return [ - (schemes.get(parts[0], "scheme_other"), len(parts[0]) - 1), + (SCHEME_STYLES.get(parts[0], "scheme_other"), len(parts[0]) - 1), ('url_punctuation', 3), # :// ] + colorize_host(parts[2]) + colorize_req('/' + parts[3]) @@ -699,6 +697,13 @@ def format_flow( response_content_type = None duration = None + scheme = f.request.scheme + if f.websocket is not None: + if scheme == "https": + scheme = "wss" + elif scheme == "http": + scheme = "ws" + if render_mode in (RenderMode.LIST, RenderMode.DETAILVIEW): render_func = format_http_flow_list else: @@ -709,7 
+714,7 @@ def format_flow( marked=f.marked, is_replay=f.is_replay, request_method=f.request.method, - request_scheme=f.request.scheme, + request_scheme=scheme, request_host=f.request.pretty_host if hostheader else f.request.host, request_path=f.request.path, request_url=f.request.pretty_url if hostheader else f.request.url, diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index d737a2291..61d614977 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -42,25 +42,6 @@ console_flowlist_layout = [ ] -class UnsupportedLog: - """ - A small addon to dump info on flow types we don't support yet. - """ - - def websocket_message(self, f): - message = f.messages[-1] - ctx.log.info(f.message_info(message)) - ctx.log.debug( - message.content if isinstance(message.content, str) else strutils.bytes_to_escaped_str(message.content)) - - def websocket_end(self, f): - ctx.log.info("WebSocket connection closed by {}: {} {}, {}".format( - f.close_sender, - f.close_code, - f.close_message, - f.close_reason)) - - class ConsoleAddon: """ An addon that exposes console-specific commands, and hooks into required @@ -226,11 +207,11 @@ class ConsoleAddon: @command.command("console.choose") def console_choose( - self, - prompt: str, - choices: typing.Sequence[str], - cmd: mitmproxy.types.Cmd, - *args: mitmproxy.types.CmdArgs + self, + prompt: str, + choices: typing.Sequence[str], + cmd: mitmproxy.types.Cmd, + *args: mitmproxy.types.CmdArgs ) -> None: """ Prompt the user to choose from a specified list of strings, then @@ -252,11 +233,11 @@ class ConsoleAddon: @command.command("console.choose.cmd") def console_choose_cmd( - self, - prompt: str, - choicecmd: mitmproxy.types.Cmd, - subcmd: mitmproxy.types.Cmd, - *args: mitmproxy.types.CmdArgs + self, + prompt: str, + choicecmd: mitmproxy.types.Cmd, + subcmd: mitmproxy.types.Cmd, + *args: mitmproxy.types.CmdArgs ) -> None: """ Prompt the user to 
choose from a list of strings returned by a @@ -415,8 +396,8 @@ class ConsoleAddon: flow.backup() require_dummy_response = ( - flow_part in ("response-headers", "response-body", "set-cookies") and - flow.response is None + flow_part in ("response-headers", "response-body", "set-cookies") and + flow.response is None ) if require_dummy_response: flow.response = http.Response.make() @@ -584,11 +565,11 @@ class ConsoleAddon: @command.command("console.key.bind") def key_bind( - self, - contexts: typing.Sequence[str], - key: str, - cmd: mitmproxy.types.Cmd, - *args: mitmproxy.types.CmdArgs + self, + contexts: typing.Sequence[str], + key: str, + cmd: mitmproxy.types.Cmd, + *args: mitmproxy.types.CmdArgs ) -> None: """ Bind a shortcut key. diff --git a/mitmproxy/tools/console/flowview.py b/mitmproxy/tools/console/flowview.py index 3d3cf4766..4c6a11176 100644 --- a/mitmproxy/tools/console/flowview.py +++ b/mitmproxy/tools/console/flowview.py @@ -60,14 +60,23 @@ class FlowDetails(tabs.Tabs): return self.master.view.focus.flow def focus_changed(self): - if self.flow: - if isinstance(self.flow, http.HTTPFlow): - self.tabs = [ - (self.tab_http_request, self.view_request), - (self.tab_http_response, self.view_response), - (self.tab_details, self.view_details), - ] - elif isinstance(self.flow, tcp.TCPFlow): + f = self.flow + if f: + if isinstance(f, http.HTTPFlow): + if f.websocket: + self.tabs = [ + (self.tab_http_request, self.view_request), + (self.tab_http_response, self.view_response), + (self.tab_websocket_messages, self.view_websocket_messages), + (self.tab_details, self.view_details), + ] + else: + self.tabs = [ + (self.tab_http_request, self.view_request), + (self.tab_http_response, self.view_response), + (self.tab_details, self.view_details), + ] + elif isinstance(f, tcp.TCPFlow): self.tabs = [ (self.tab_tcp_stream, self.view_tcp_stream), (self.tab_details, self.view_details), @@ -95,6 +104,9 @@ class FlowDetails(tabs.Tabs): def tab_tcp_stream(self): return "TCP Stream" 
+ def tab_websocket_messages(self): + return "WebSocket Messages" + def tab_details(self): return "Detail" @@ -128,6 +140,36 @@ class FlowDetails(tabs.Tabs): contentview_status_bar = urwid.AttrWrap(urwid.Columns(cols), "heading") return contentview_status_bar + def view_websocket_messages(self): + flow = self.flow + assert isinstance(flow, http.HTTPFlow) + assert flow.websocket is not None + + if not flow.websocket.messages: + return searchable.Searchable([urwid.Text(("highlight", "No messages."))]) + + viewmode = self.master.commands.call("console.flowview.mode") + + widget_lines = [] + for m in flow.websocket.messages: + _, lines, _ = contentviews.get_message_content_view(viewmode, m, flow) + + for line in lines: + if m.from_client: + line.insert(0, ("from_client", f"{common.SYMBOL_FROM_CLIENT} ")) + else: + line.insert(0, ("to_client", f"{common.SYMBOL_TO_CLIENT} ")) + + widget_lines.append(urwid.Text(line)) + + if flow.intercepted: + markup = widget_lines[-1].get_text()[0] + widget_lines[-1].set_text(("intercept", markup)) + + widget_lines.insert(0, self._contentview_status_bar(viewmode.capitalize(), viewmode)) + + return searchable.Searchable(widget_lines) + def view_tcp_stream(self) -> urwid.Widget: flow = self.flow assert isinstance(flow, tcp.TCPFlow) diff --git a/mitmproxy/tools/console/master.py b/mitmproxy/tools/console/master.py index f3d9887b3..3910149a7 100644 --- a/mitmproxy/tools/console/master.py +++ b/mitmproxy/tools/console/master.py @@ -53,7 +53,6 @@ class ConsoleMaster(master.Master): intercept.Intercept(), self.view, self.events, - consoleaddons.UnsupportedLog(), readfile.ReadFile(), consoleaddons.ConsoleAddon(self), keymap.KeymapConfig(), diff --git a/mitmproxy/tools/console/palettes.py b/mitmproxy/tools/console/palettes.py index 15745df59..571358864 100644 --- a/mitmproxy/tools/console/palettes.py +++ b/mitmproxy/tools/console/palettes.py @@ -23,7 +23,7 @@ class Palette: # List and Connections 'method_get', 'method_post', 'method_delete', 
'method_other', 'method_head', 'method_put', 'method_http2_push', - 'scheme_http', 'scheme_https', 'scheme_tcp', 'scheme_other', + 'scheme_http', 'scheme_https', 'scheme_ws', 'scheme_wss', 'scheme_tcp', 'scheme_other', 'url_punctuation', 'url_domain', 'url_filename', 'url_extension', 'url_query_key', 'url_query_value', 'content_none', 'content_text', 'content_script', 'content_media', 'content_data', 'content_raw', 'content_other', 'focus', @@ -136,6 +136,8 @@ class LowDark(Palette): scheme_http = ('dark cyan', 'default'), scheme_https = ('dark green', 'default'), + scheme_ws=('brown', 'default'), + scheme_wss=('dark magenta', 'default'), scheme_tcp=('dark magenta', 'default'), scheme_other = ('dark magenta', 'default'), @@ -245,6 +247,8 @@ class LowLight(Palette): scheme_http = ('dark cyan', 'default'), scheme_https = ('light green', 'default'), + scheme_ws=('brown', 'default'), + scheme_wss=('light magenta', 'default'), scheme_tcp=('light magenta', 'default'), scheme_other = ('light magenta', 'default'), @@ -373,6 +377,8 @@ class SolarizedLight(LowLight): scheme_http = (sol_cyan, 'default'), scheme_https = ('light green', 'default'), + scheme_ws=(sol_orange, 'default'), + scheme_wss=('light magenta', 'default'), scheme_tcp=('light magenta', 'default'), scheme_other = ('light magenta', 'default'), diff --git a/mitmproxy/version.py b/mitmproxy/version.py index 883469ddd..18a12a595 100644 --- a/mitmproxy/version.py +++ b/mitmproxy/version.py @@ -7,7 +7,7 @@ MITMPROXY = "mitmproxy " + VERSION # Serialization format version. This is displayed nowhere, it just needs to be incremented by one # for each change in the file format. -FLOW_FORMAT_VERSION = 11 +FLOW_FORMAT_VERSION = 12 def get_dev_version() -> str: diff --git a/mitmproxy/websocket.py b/mitmproxy/websocket.py index e7137c0d8..841944058 100644 --- a/mitmproxy/websocket.py +++ b/mitmproxy/websocket.py @@ -1,168 +1,126 @@ """ -*Deprecation Notice:* Mitmproxy's WebSocket API is going to change soon, -see . 
+Mitmproxy used to have its own WebSocketFlow type until mitmproxy 6, but WebSocket connections are now represented +as HTTP flows as well. They can be distinguished from regular HTTP requests by having the +`mitmproxy.http.HTTPFlow.websocket` attribute set. + +This module only defines the classes for individual `WebSocketMessage`s and the `WebSocketData` container. """ -import queue import time -import warnings -from typing import List +from typing import List, Tuple, Union from typing import Optional -from typing import Union -from mitmproxy import flow +from mitmproxy import stateobject from mitmproxy.coretypes import serializable -from mitmproxy.net import websocket -from mitmproxy.utils import human -from mitmproxy.utils import strutils - -from wsproto.frame_protocol import CloseReason from wsproto.frame_protocol import Opcode +WebSocketMessageState = Tuple[int, bool, bytes, float, bool] + class WebSocketMessage(serializable.Serializable): """ - A WebSocket message sent from one endpoint to the other. + A single WebSocket message sent from one peer to the other. + + Fragmented WebSocket messages are reassembled by mitmproxy and then + represented as a single instance of this class. + + The [WebSocket RFC](https://tools.ietf.org/html/rfc6455) specifies both + text and binary messages. To avoid a whole class of nasty type confusion bugs, + mitmproxy stores all message contents as binary. If you need text, you can decode the `content` property: + + >>> from wsproto.frame_protocol import Opcode + >>> if message.type == Opcode.TEXT: + >>> text = message.content.decode() + + Per the WebSocket spec, text messages always use UTF-8 encoding. """ - type: Opcode - """indicates either TEXT or BINARY (from wsproto.frame_protocol.Opcode).""" from_client: bool """True if this messages was sent by the client.""" - content: Union[bytes, str] + type: Opcode + """ + The message type, as per RFC 6455's [opcode](https://tools.ietf.org/html/rfc6455#section-5.2).
+ + Note that mitmproxy will always store the message contents as *bytes*. + A dedicated `.text` property for text messages is planned, see https://github.com/mitmproxy/mitmproxy/pull/4486. + """ + content: bytes """A byte-string representing the content of this message.""" timestamp: float """Timestamp of when this message was received or created.""" - killed: bool - """True if this messages was killed and should not be sent to the other endpoint.""" + """True if the message has not been forwarded by mitmproxy, False otherwise.""" def __init__( self, - type: int, + type: Union[int, Opcode], from_client: bool, - content: Union[bytes, str], + content: bytes, timestamp: Optional[float] = None, - killed: bool = False + killed: bool = False, ) -> None: - self.type = Opcode(type) # type: ignore self.from_client = from_client + self.type = Opcode(type) self.content = content self.timestamp: float = timestamp or time.time() self.killed = killed @classmethod - def from_state(cls, state): + def from_state(cls, state: WebSocketMessageState): return cls(*state) - def get_state(self): + def get_state(self) -> WebSocketMessageState: return int(self.type), self.from_client, self.content, self.timestamp, self.killed - def set_state(self, state): - self.type, self.from_client, self.content, self.timestamp, self.killed = state - self.type = Opcode(self.type) # replace enum with bare int + def set_state(self, state: WebSocketMessageState) -> None: + typ, self.from_client, self.content, self.timestamp, self.killed = state + self.type = Opcode(typ) def __repr__(self): if self.type == Opcode.TEXT: - return "text message: {}".format(repr(self.content)) + return repr(self.content.decode(errors="replace")) else: - return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) + return repr(self.content) - def kill(self): # pragma: no cover - """ - Kill this message. - - It will not be sent to the other endpoint. This has no effect in streaming mode. 
- """ - warnings.warn( - "WebSocketMessage.kill is deprecated, set an empty content instead.", - DeprecationWarning, - stacklevel=2, - ) - # empty str or empty bytes. - self.content = type(self.content)() + def kill(self): + # Likely to be replaced with .drop() in the future, see https://github.com/mitmproxy/mitmproxy/pull/4486 + self.killed = True -class WebSocketFlow(flow.Flow): +class WebSocketData(stateobject.StateObject): """ - A WebSocketFlow is a simplified representation of a WebSocket connection. + A data container for everything related to a single WebSocket connection. + This is typically accessed as `mitmproxy.http.HTTPFlow.websocket`. """ - def __init__(self, client_conn, server_conn, handshake_flow, live=None): - super().__init__("websocket", client_conn, server_conn, live) + messages: List[WebSocketMessage] + """All `WebSocketMessage`s transferred over this connection.""" - self.messages: List[WebSocketMessage] = [] - """A list containing all WebSocketMessage's.""" - self.close_sender = 'client' - """'client' if the client initiated connection closing.""" - self.close_code = CloseReason.NORMAL_CLOSURE - """WebSocket close code.""" - self.close_message = '(message missing)' - """WebSocket close message.""" - self.close_reason = 'unknown status code' - """WebSocket close reason.""" - self.stream = False - """True of this connection is streaming directly to the other endpoint.""" - self.handshake_flow = handshake_flow - """The HTTP flow containing the initial WebSocket handshake.""" - self.ended = False - """True when the WebSocket connection has been closed.""" + closed_by_client: Optional[bool] = None + """ + True if the client closed the connection, + False if the server closed the connection, + None if the connection is active. 
+ """ + close_code: Optional[int] = None + """[Close Code](https://tools.ietf.org/html/rfc6455#section-7.1.5)""" + close_reason: Optional[str] = None + """[Close Reason](https://tools.ietf.org/html/rfc6455#section-7.1.6)""" - self._inject_messages_client = queue.Queue(maxsize=1) - self._inject_messages_server = queue.Queue(maxsize=1) - - if handshake_flow: - self.client_key = websocket.get_client_key(handshake_flow.request.headers) - self.client_protocol = websocket.get_protocol(handshake_flow.request.headers) - self.client_extensions = websocket.get_extensions(handshake_flow.request.headers) - self.server_accept = websocket.get_server_accept(handshake_flow.response.headers) - self.server_protocol = websocket.get_protocol(handshake_flow.response.headers) - self.server_extensions = websocket.get_extensions(handshake_flow.response.headers) - else: - self.client_key = '' - self.client_protocol = '' - self.client_extensions = '' - self.server_accept = '' - self.server_protocol = '' - self.server_extensions = '' - - _stateobject_attributes = flow.Flow._stateobject_attributes.copy() - # mypy doesn't support update with kwargs - _stateobject_attributes.update(dict( + _stateobject_attributes = dict( messages=List[WebSocketMessage], - close_sender=str, + closed_by_client=bool, close_code=int, - close_message=str, close_reason=str, - client_key=str, - client_protocol=str, - client_extensions=str, - server_accept=str, - server_protocol=str, - server_extensions=str, - # Do not include handshake_flow, to prevent recursive serialization! - # Since mitmproxy-console currently only displays HTTPFlows, - # dumping the handshake_flow will include the WebSocketFlow too. 
- )) + ) - def get_state(self): - d = super().get_state() - d['close_code'] = int(d['close_code']) # replace enum with bare int - return d + def __init__(self): + self.messages = [] + + def __repr__(self): + return f"" @classmethod def from_state(cls, state): - f = cls(None, None, None) - f.set_state(state) - return f - - def __repr__(self): - return "".format(len(self.messages)) - - def message_info(self, message: WebSocketMessage) -> str: - return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format( - type=message.type, - client=human.format_address(self.client_conn.peername), - server=human.format_address(self.server_conn.address), - direction="->" if message.from_client else "<-", - endpoint=self.handshake_flow.request.path, - ) + d = WebSocketData() + d.set_state(state) + return d diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index 0704b712f..682a097e5 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -110,7 +110,7 @@ async def test_start_stop(tdata): assert cp.count() == 1 cp.start_replay([tflow.twebsocketflow()]) - await tctx.master.await_log("Can only replay HTTP flows.", level="warn") + await tctx.master.await_log("Can't replay WebSocket flows.", level="warn") assert cp.count() == 1 cp.stop_replay() diff --git a/test/mitmproxy/addons/test_dumper.py b/test/mitmproxy/addons/test_dumper.py index 0f6c0353f..c4fa19ae4 100644 --- a/test/mitmproxy/addons/test_dumper.py +++ b/test/mitmproxy/addons/test_dumper.py @@ -233,7 +233,7 @@ def test_websocket(): d.websocket_end(f) assert "WebSocket connection closed by" in sio.getvalue() - f = tflow.twebsocketflow(client_conn=True, err=True) + f = tflow.twebsocketflow(err=True) d.websocket_error(f) assert "Error in WebSocket" in sio_err.getvalue() diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py index 65ec0d76d..754c755d1 100644 --- 
a/test/mitmproxy/addons/test_save.py +++ b/test/mitmproxy/addons/test_save.py @@ -55,11 +55,11 @@ def test_websocket(tmpdir): tctx.configure(sa, save_stream_file=p) f = tflow.twebsocketflow() - sa.websocket_start(f) + sa.request(f) sa.websocket_end(f) f = tflow.twebsocketflow() - sa.websocket_start(f) + sa.request(f) sa.websocket_error(f) tctx.configure(sa, save_stream_file=None) diff --git a/test/mitmproxy/data/dumpfile-7-websocket.mitm b/test/mitmproxy/data/dumpfile-7-websocket.mitm index b34e7d908..daef027e9 100644 Binary files a/test/mitmproxy/data/dumpfile-7-websocket.mitm and b/test/mitmproxy/data/dumpfile-7-websocket.mitm differ diff --git a/test/mitmproxy/io/test_compat.py b/test/mitmproxy/io/test_compat.py index 26b5c35e5..904064a7b 100644 --- a/test/mitmproxy/io/test_compat.py +++ b/test/mitmproxy/io/test_compat.py @@ -8,7 +8,7 @@ from mitmproxy import exceptions ["dumpfile-011.mitm", "https://example.com/", 1], ["dumpfile-018.mitm", "https://www.example.com/", 1], ["dumpfile-019.mitm", "https://webrv.rtb-seller.com/", 1], - ["dumpfile-7-websocket.mitm", "https://echo.websocket.org/", 5], + ["dumpfile-7-websocket.mitm", "https://echo.websocket.org/", 6], ]) def test_load(tdata, dumpfile, url, count): with open(tdata.path("mitmproxy/data/" + dumpfile), "rb") as f: diff --git a/test/mitmproxy/net/test_websocket.py b/test/mitmproxy/net/test_websocket.py deleted file mode 100644 index 06ea6581a..000000000 --- a/test/mitmproxy/net/test_websocket.py +++ /dev/null @@ -1,39 +0,0 @@ -from mitmproxy.net import websocket - - -def test_check_handshake(): - assert not websocket.check_handshake({ - "connection": "upgrade", - "upgrade": "webFOOsocket", - "sec-websocket-key": "foo", - }) - assert websocket.check_handshake({ - "connection": "upgrade", - "upgrade": "websocket", - "sec-websocket-key": "foo", - }) - assert websocket.check_handshake({ - "connection": "upgrade", - "upgrade": "websocket", - "sec-websocket-accept": "bar", - }) - - -def test_get_extensions(): - 
assert websocket.get_extensions({}) is None - assert websocket.get_extensions({"sec-websocket-extensions": "foo"}) == "foo" - - -def test_get_protocol(): - assert websocket.get_protocol({}) is None - assert websocket.get_protocol({"sec-websocket-protocol": "foo"}) == "foo" - - -def test_get_client_key(): - assert websocket.get_client_key({}) is None - assert websocket.get_client_key({"sec-websocket-key": "foo"}) == "foo" - - -def test_get_server_accept(): - assert websocket.get_server_accept({}) is None - assert websocket.get_server_accept({"sec-websocket-accept": "foo"}) == "foo" diff --git a/test/mitmproxy/proxy/layers/http/test_http.py b/test/mitmproxy/proxy/layers/http/test_http.py index c8608e3a1..574aebe0f 100644 --- a/test/mitmproxy/proxy/layers/http/test_http.py +++ b/test/mitmproxy/proxy/layers/http/test_http.py @@ -12,7 +12,6 @@ from mitmproxy.proxy.layers import TCPLayer, http, tls from mitmproxy.proxy.layers.tcp import TcpStartHook from mitmproxy.proxy.layers.websocket import WebsocketStartHook from mitmproxy.tcp import TCPFlow -from mitmproxy.websocket import WebSocketFlow from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply, reply_next_layer @@ -960,7 +959,7 @@ def test_upgrade(tctx, proto): b"\r\n") ) if proto == "websocket": - assert playbook << WebsocketStartHook(Placeholder(WebSocketFlow)) + assert playbook << WebsocketStartHook(http_flow) elif proto == "tcp": assert playbook << TcpStartHook(Placeholder(TCPFlow)) else: diff --git a/test/mitmproxy/proxy/layers/test_websocket.py b/test/mitmproxy/proxy/layers/test_websocket.py index f4a654d47..5d372a2a3 100644 --- a/test/mitmproxy/proxy/layers/test_websocket.py +++ b/test/mitmproxy/proxy/layers/test_websocket.py @@ -11,8 +11,9 @@ from mitmproxy.proxy.commands import SendData, CloseConnection, Log from mitmproxy.connection import ConnectionState from mitmproxy.proxy.events import DataReceived, ConnectionClosed from mitmproxy.proxy.layers import http, websocket -from mitmproxy.websocket 
import WebSocketFlow +from mitmproxy.websocket import WebSocketData from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply +from wsproto.frame_protocol import Opcode @dataclass @@ -53,8 +54,7 @@ def test_upgrade(tctx): """Test a HTTP -> WebSocket upgrade""" tctx.server.address = ("example.com", 80) tctx.server.state = ConnectionState.OPEN - http_flow = Placeholder(HTTPFlow) - flow = Placeholder(WebSocketFlow) + flow = Placeholder(HTTPFlow) assert ( Playbook(http.HttpLayer(tctx, HTTPMode.transparent)) >> DataReceived(tctx.client, @@ -63,9 +63,9 @@ def test_upgrade(tctx): b"Upgrade: websocket\r\n" b"Sec-WebSocket-Version: 13\r\n" b"\r\n") - << http.HttpRequestHeadersHook(http_flow) + << http.HttpRequestHeadersHook(flow) >> reply() - << http.HttpRequestHook(http_flow) + << http.HttpRequestHook(flow) >> reply() << SendData(tctx.server, b"GET / HTTP/1.1\r\n" b"Connection: upgrade\r\n" @@ -76,9 +76,9 @@ def test_upgrade(tctx): b"Upgrade: websocket\r\n" b"Connection: Upgrade\r\n" b"\r\n") - << http.HttpResponseHeadersHook(http_flow) + << http.HttpResponseHeadersHook(flow) >> reply() - << http.HttpResponseHook(http_flow) + << http.HttpResponseHook(flow) >> reply() << SendData(tctx.client, b"HTTP/1.1 101 Switching Protocols\r\n" b"Upgrade: websocket\r\n" @@ -95,12 +95,13 @@ def test_upgrade(tctx): >> reply() << SendData(tctx.client, b"\x82\nhello back") ) - assert flow().handshake_flow == http_flow() - assert len(flow().messages) == 2 - assert flow().messages[0].content == "hello world" - assert flow().messages[0].from_client - assert flow().messages[1].content == b"hello back" - assert flow().messages[1].from_client is False + assert len(flow().websocket.messages) == 2 + assert flow().websocket.messages[0].content == b"hello world" + assert flow().websocket.messages[0].from_client + assert flow().websocket.messages[0].type == Opcode.TEXT + assert flow().websocket.messages[1].content == b"hello back" + assert flow().websocket.messages[1].from_client is False + 
assert flow().websocket.messages[1].type == Opcode.BINARY @pytest.fixture() @@ -120,12 +121,12 @@ def ws_testdata(tctx): "Connection": "upgrade", "Upgrade": "websocket", }) - return tctx, Playbook(websocket.WebsocketLayer(tctx, flow)) + flow.websocket = WebSocketData() + return tctx, Playbook(websocket.WebsocketLayer(tctx, flow)), flow def test_modify_message(ws_testdata): - tctx, playbook = ws_testdata - flow = Placeholder(WebSocketFlow) + tctx, playbook, flow = ws_testdata assert ( playbook << websocket.WebsocketStartHook(flow) @@ -133,7 +134,7 @@ def test_modify_message(ws_testdata): >> DataReceived(tctx.server, b"\x81\x03foo") << websocket.WebsocketMessageHook(flow) ) - flow().messages[-1].content = flow().messages[-1].content.replace("foo", "foobar") + flow.websocket.messages[-1].content = flow.websocket.messages[-1].content.replace(b"foo", b"foobar") assert ( playbook >> reply() @@ -142,8 +143,7 @@ def test_modify_message(ws_testdata): def test_drop_message(ws_testdata): - tctx, playbook = ws_testdata - flow = Placeholder(WebSocketFlow) + tctx, playbook, flow = ws_testdata assert ( playbook << websocket.WebsocketStartHook(flow) @@ -151,7 +151,7 @@ def test_drop_message(ws_testdata): >> DataReceived(tctx.server, b"\x81\x03foo") << websocket.WebsocketMessageHook(flow) ) - flow().messages[-1].content = "" + flow.websocket.messages[-1].kill() assert ( playbook >> reply() @@ -160,8 +160,7 @@ def test_drop_message(ws_testdata): def test_fragmented(ws_testdata): - tctx, playbook = ws_testdata - flow = Placeholder(WebSocketFlow) + tctx, playbook, flow = ws_testdata assert ( playbook << websocket.WebsocketStartHook(flow) @@ -173,12 +172,11 @@ def test_fragmented(ws_testdata): << SendData(tctx.client, b"\x01\x03foo") << SendData(tctx.client, b"\x80\x03bar") ) - assert flow().messages[-1].content == "foobar" + assert flow.websocket.messages[-1].content == b"foobar" def test_protocol_error(ws_testdata): - tctx, playbook = ws_testdata - flow = Placeholder(WebSocketFlow) + 
tctx, playbook, flow = ws_testdata assert ( playbook << websocket.WebsocketStartHook(flow) @@ -193,12 +191,11 @@ def test_protocol_error(ws_testdata): >> reply() ) - assert not flow().messages + assert not flow.websocket.messages def test_ping(ws_testdata): - tctx, playbook = ws_testdata - flow = Placeholder(WebSocketFlow) + tctx, playbook, flow = ws_testdata assert ( playbook << websocket.WebsocketStartHook(flow) @@ -210,12 +207,11 @@ def test_ping(ws_testdata): << Log("Received WebSocket pong from server (payload: b'pong-with-payload')") << SendData(tctx.client, b"\x8a\x11pong-with-payload") ) - assert not flow().messages + assert not flow.websocket.messages def test_close_normal(ws_testdata): - tctx, playbook = ws_testdata - flow = Placeholder(WebSocketFlow) + tctx, playbook, flow = ws_testdata masked_close = Placeholder(bytes) close = Placeholder(bytes) assert ( @@ -235,12 +231,11 @@ def test_close_normal(ws_testdata): assert masked_close() == masked(b"\x88\x02\x03\xe8") or masked_close() == masked(b"\x88\x00") assert close() == b"\x88\x02\x03\xe8" or close() == b"\x88\x00" - assert flow().close_code == 1005 + assert flow.websocket.close_code == 1005 def test_close_disconnect(ws_testdata): - tctx, playbook = ws_testdata - flow = Placeholder(WebSocketFlow) + tctx, playbook, flow = ws_testdata assert ( playbook << websocket.WebsocketStartHook(flow) @@ -253,12 +248,11 @@ def test_close_disconnect(ws_testdata): >> reply() >> ConnectionClosed(tctx.client) ) - assert "ABNORMAL_CLOSURE" in flow().error.msg + assert "ABNORMAL_CLOSURE" in flow.error.msg def test_close_error(ws_testdata): - tctx, playbook = ws_testdata - flow = Placeholder(WebSocketFlow) + tctx, playbook, flow = ws_testdata assert ( playbook << websocket.WebsocketStartHook(flow) @@ -271,15 +265,12 @@ def test_close_error(ws_testdata): << websocket.WebsocketErrorHook(flow) >> reply() ) - assert "UNKNOWN_ERROR=4000" in flow().error.msg + assert "UNKNOWN_ERROR=4000" in flow.error.msg def 
test_deflate(ws_testdata): - tctx, playbook = ws_testdata - flow = Placeholder(WebSocketFlow) - # noinspection PyUnresolvedReferences - http_flow: HTTPFlow = playbook.layer.flow.handshake_flow - http_flow.response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate; server_max_window_bits=10" + tctx, playbook, flow = ws_testdata + flow.response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate; server_max_window_bits=10" assert ( playbook << websocket.WebsocketStartHook(flow) @@ -290,15 +281,12 @@ def test_deflate(ws_testdata): >> reply() << SendData(tctx.client, bytes.fromhex("c1 07 f2 48 cd c9 c9 07 00")) ) - assert flow().messages[0].content == "Hello" + assert flow.websocket.messages[0].content == b"Hello" def test_unknown_ext(ws_testdata): - tctx, playbook = ws_testdata - flow = Placeholder(WebSocketFlow) - # noinspection PyUnresolvedReferences - http_flow: HTTPFlow = playbook.layer.flow.handshake_flow - http_flow.response.headers["Sec-WebSocket-Extensions"] = "funky-bits; param=42" + tctx, playbook, flow = ws_testdata + flow.response.headers["Sec-WebSocket-Extensions"] = "funky-bits; param=42" assert ( playbook << Log("Ignoring unknown WebSocket extension 'funky-bits'.") @@ -314,20 +302,20 @@ def test_websocket_connection_repr(tctx): class TestFragmentizer: def test_empty(self): - f = websocket.Fragmentizer([b"foo"]) + f = websocket.Fragmentizer([b"foo"], False) assert list(f(b"")) == [] def test_keep_sizes(self): - f = websocket.Fragmentizer([b"foo", b"bar"]) + f = websocket.Fragmentizer([b"foo", b"bar"], True) assert list(f(b"foobaz")) == [ - wsproto.events.Message(b"foo", message_finished=False), - wsproto.events.Message(b"baz", message_finished=True), + wsproto.events.TextMessage("foo", message_finished=False), + wsproto.events.TextMessage("baz", message_finished=True), ] def test_rechunk(self): - f = websocket.Fragmentizer([b"foo"]) + f = websocket.Fragmentizer([b"foo"], False) f.FRAGMENT_SIZE = 4 assert list(f(b"foobar")) == [ - 
wsproto.events.Message(b"foob", message_finished=False), - wsproto.events.Message(b"ar", message_finished=True), + wsproto.events.BytesMessage(b"foob", message_finished=False), + wsproto.events.BytesMessage(b"ar", message_finished=True), ] diff --git a/test/mitmproxy/test_eventsequence.py b/test/mitmproxy/test_eventsequence.py index 866168556..cbfbb7087 100644 --- a/test/mitmproxy/test_eventsequence.py +++ b/test/mitmproxy/test_eventsequence.py @@ -27,14 +27,20 @@ def test_http_flow(resp, err): def test_websocket_flow(err): f = tflow.twebsocketflow(err=err) i = eventsequence.iterate(f) + + assert isinstance(next(i), layers.http.HttpRequestHeadersHook) + assert isinstance(next(i), layers.http.HttpRequestHook) + assert isinstance(next(i), layers.http.HttpResponseHeadersHook) + assert isinstance(next(i), layers.http.HttpResponseHook) + assert isinstance(next(i), layers.websocket.WebsocketStartHook) - assert len(f.messages) == 0 + assert len(f.websocket.messages) == 0 assert isinstance(next(i), layers.websocket.WebsocketMessageHook) - assert len(f.messages) == 1 + assert len(f.websocket.messages) == 1 assert isinstance(next(i), layers.websocket.WebsocketMessageHook) - assert len(f.messages) == 2 + assert len(f.websocket.messages) == 2 assert isinstance(next(i), layers.websocket.WebsocketMessageHook) - assert len(f.messages) == 3 + assert len(f.websocket.messages) == 3 if err: assert isinstance(next(i), layers.websocket.WebsocketErrorHook) else: diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 95428261f..a09ba852f 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -122,20 +122,6 @@ class TestFlowMaster: await ctx.master.load_flow(f) assert s.flows[0].request.host == "use-this-domain" - @pytest.mark.asyncio - async def test_load_websocket_flow(self): - opts = options.Options( - mode="reverse:https://use-this-domain" - ) - s = State() - with taddons.context(s, options=opts) as ctx: - f = tflow.twebsocketflow() - await 
ctx.master.load_flow(f.handshake_flow) - await ctx.master.load_flow(f) - assert s.flows[0].request.host == "use-this-domain" - assert s.flows[1].handshake_flow == f.handshake_flow - assert len(s.flows[1].messages) == len(f.messages) - @pytest.mark.asyncio async def test_all(self): opts = options.Options( diff --git a/test/mitmproxy/test_flowfilter.py b/test/mitmproxy/test_flowfilter.py index 52e80c27c..243748419 100644 --- a/test/mitmproxy/test_flowfilter.py +++ b/test/mitmproxy/test_flowfilter.py @@ -4,7 +4,7 @@ from unittest.mock import patch from mitmproxy.test import tflow -from mitmproxy import flowfilter +from mitmproxy import flowfilter, http class TestParsing: @@ -424,10 +424,10 @@ class TestMatchingTCPFlow: class TestMatchingWebSocketFlow: - def flow(self): + def flow(self) -> http.HTTPFlow: return tflow.twebsocketflow() - def err(self): + def err(self) -> http.HTTPFlow: return tflow.twebsocketflow(err=True) def q(self, q, o): @@ -437,10 +437,10 @@ class TestMatchingWebSocketFlow: f = self.flow() assert self.q("~websocket", f) assert not self.q("~tcp", f) - assert not self.q("~http", f) + assert self.q("~http", f) def test_handshake(self): - f = self.flow().handshake_flow + f = self.flow() assert self.q("~websocket", f) assert not self.q("~tcp", f) assert self.q("~http", f) @@ -465,9 +465,6 @@ class TestMatchingWebSocketFlow: assert self.q("~u example.com/ws", q) assert not self.q("~u moo/path", q) - q.handshake_flow = None - assert not self.q("~u example.com", q) - def test_body(self): f = self.flow() diff --git a/test/mitmproxy/test_websocket.py b/test/mitmproxy/test_websocket.py index 1ff98962e..26d7b0f6a 100644 --- a/test/mitmproxy/test_websocket.py +++ b/test/mitmproxy/test_websocket.py @@ -1,88 +1,28 @@ -import io - -import pytest - -from mitmproxy import flowfilter -from mitmproxy.exceptions import ControlException -from mitmproxy.io import tnetstring +from mitmproxy import http +from mitmproxy import websocket from mitmproxy.test import tflow +from 
wsproto.frame_protocol import Opcode -class TestWebSocketFlow: - - def test_copy(self): - f = tflow.twebsocketflow() - f.get_state() - f2 = f.copy() - a = f.get_state() - b = f2.get_state() - del a["id"] - del b["id"] - assert a == b - assert not f == f2 - assert f is not f2 - - assert f.client_key == f2.client_key - assert f.client_protocol == f2.client_protocol - assert f.client_extensions == f2.client_extensions - assert f.server_accept == f2.server_accept - assert f.server_protocol == f2.server_protocol - assert f.server_extensions == f2.server_extensions - assert f.messages is not f2.messages - assert f.handshake_flow is not f2.handshake_flow - - for m in f.messages: - m2 = m.copy() - m2.set_state(m2.get_state()) - assert m is not m2 - assert m.get_state() == m2.get_state() - - f = tflow.twebsocketflow(err=True) - f2 = f.copy() - assert f is not f2 - assert f.handshake_flow is not f2.handshake_flow - assert f.error.get_state() == f2.error.get_state() - assert f.error is not f2.error - - def test_kill(self): - f = tflow.twebsocketflow() - with pytest.raises(ControlException): - f.intercept() - f.resume() - f.kill() - - f = tflow.twebsocketflow() - f.intercept() - assert f.killable - f.kill() - assert not f.killable - - def test_match(self): - f = tflow.twebsocketflow() - assert not flowfilter.match("~b nonexistent", f) - assert flowfilter.match(None, f) - assert not flowfilter.match("~b nonexistent", f) - - f = tflow.twebsocketflow(err=True) - assert flowfilter.match("~e", f) - - with pytest.raises(ValueError): - flowfilter.match("~", f) - +class TestWebSocketData: def test_repr(self): + assert repr(tflow.twebsocketflow().websocket) == "" + + def test_state(self): f = tflow.twebsocketflow() - assert f.message_info(f.messages[0]) - assert 'WebSocketFlow' in repr(f) - assert 'binary message: ' in repr(f.messages[0]) - assert 'text message: ' in repr(f.messages[1]) + f2 = http.HTTPFlow.from_state(f.get_state()) + f2.set_state(f.get_state()) - def 
test_serialize(self): - b = io.BytesIO() - d = tflow.twebsocketflow().get_state() - tnetstring.dump(d, b) - assert b.getvalue() - b = io.BytesIO() - d = tflow.twebsocketflow().handshake_flow.get_state() - tnetstring.dump(d, b) - assert b.getvalue() +class TestWebSocketMessage: + def test_basic(self): + m = websocket.WebSocketMessage(Opcode.TEXT, True, b"foo") + m.set_state(m.get_state()) + assert m.content == b"foo" + assert repr(m) == "'foo'" + m.type = Opcode.BINARY + assert repr(m) == "b'foo'" + + assert not m.killed + m.kill() + assert m.killed