From 5214f544e7b690dea2a45cb4cda44bbffec9a77e Mon Sep 17 00:00:00 2001 From: Ujjwal Verma Date: Thu, 17 Aug 2017 21:12:07 +0530 Subject: [PATCH] Use wsproto for websockets --- mitmproxy/addons/dumper.py | 2 + mitmproxy/proxy/protocol/websocket.py | 159 ++++++++---------- mitmproxy/tools/console/consoleaddons.py | 2 +- setup.cfg | 9 +- .../proxy/protocol/test_websocket.py | 122 ++++++++++++-- 5 files changed, 185 insertions(+), 109 deletions(-) diff --git a/mitmproxy/addons/dumper.py b/mitmproxy/addons/dumper.py index 54526d5b2..48bc81187 100644 --- a/mitmproxy/addons/dumper.py +++ b/mitmproxy/addons/dumper.py @@ -234,6 +234,8 @@ class Dumper: message = f.messages[-1] self.echo(f.message_info(message)) if ctx.options.flow_detail >= 3: + message = message.from_state(message.get_state()) + message.content = message.content.encode() if isinstance(message.content, str) else message.content self._echo_message(message) def websocket_end(self, f): diff --git a/mitmproxy/proxy/protocol/websocket.py b/mitmproxy/proxy/protocol/websocket.py index d1abd1346..54d8120de 100644 --- a/mitmproxy/proxy/protocol/websocket.py +++ b/mitmproxy/proxy/protocol/websocket.py @@ -1,19 +1,18 @@ -import os import socket -import struct from OpenSSL import SSL from wsproto import events from wsproto.connection import ConnectionType, WSConnection from wsproto.extensions import PerMessageDeflate +from wsproto.frame_protocol import Opcode from mitmproxy import exceptions from mitmproxy import flow from mitmproxy.proxy.protocol import base -from mitmproxy.net import http from mitmproxy.net import tcp from mitmproxy.net import websockets from mitmproxy.websocket import WebSocketFlow, WebSocketMessage +from mitmproxy.utils import strutils class WebSocketLayer(base.Layer): @@ -54,14 +53,16 @@ class WebSocketLayer(base.Layer): extensions = [] if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers: if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']: - extensions = [PerMessageDeflate.name] - + extensions = [PerMessageDeflate()] self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER, extensions=extensions) self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT, host=handshake_flow.request.host, resource=handshake_flow.request.path, extensions=extensions) + if extensions: + for conn in self.connections.values(): + conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions']) data = self.connections[self.server_conn].bytes_to_send() self.connections[self.client_conn].receive_bytes(data) @@ -80,28 +81,78 @@ class WebSocketLayer(base.Layer): return self._handle_ping_received(event, source_conn, other_conn, is_server) elif isinstance(event, events.PongReceived): return self._handle_pong_received(event, source_conn, other_conn, is_server) - elif isinstance(event, events.ConnectionFailed): + elif isinstance(event, events.ConnectionClosed): return self._handle_connection_closed(event, source_conn, other_conn, is_server) - elif isinstance(event, events.ConnectionFailed): - return self._handle_connection_failed(event) # fail-safe for unhandled events - return True + return True # pragma: no cover def _handle_data_received(self, event, source_conn, other_conn, is_server): + fb = self.server_frame_buffer if is_server else self.client_frame_buffer + fb.append(event.data) + + if event.message_finished: + original_chunk_sizes = [len(f) for f in fb] + message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY + if message_type == Opcode.TEXT: + payload = ''.join(fb) + else: + payload = b''.join(fb) + fb.clear() + + websocket_message = WebSocketMessage(message_type, not is_server, payload) + length = len(websocket_message.content) + self.flow.messages.append(websocket_message) + self.channel.ask("websocket_message", self.flow) + + if not self.flow.stream: + def get_chunk(payload): + if len(payload) == length: + # message has the same length, we can reuse the same sizes + pos = 0 + for s in original_chunk_sizes: + yield (payload[pos:pos + s], True if pos + s == length else False) + pos += s + else: + # just re-chunk everything into 4kB frames + # header len = 4 bytes without masking key and 8 bytes with masking key + chunk_size = 4092 if is_server else 4088 + chunks = range(0, len(payload), chunk_size) + for i in chunks: + yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False) + + for chunk, final in get_chunk(websocket_message.content): + self.connections[other_conn].send_data(chunk, final) + other_conn.send(self.connections[other_conn].bytes_to_send()) + + else: + self.connections[other_conn].send_data(event.data, event.message_finished) + other_conn.send(self.connections[other_conn].bytes_to_send()) + + elif self.flow.stream: + self.connections[other_conn].send_data(event.data, event.message_finished) + other_conn.send(self.connections[other_conn].bytes_to_send()) + return True def _handle_ping_received(self, event, source_conn, other_conn, is_server): # PING is automatically answered with a PONG by wsproto - # TODO: log this PING and its payload - self.connections[other_conn].ping(event.payload) + self.connections[other_conn].ping() other_conn.send(self.connections[other_conn].bytes_to_send()) + source_conn.send(self.connections[source_conn].bytes_to_send()) + self.log( + "Ping Received from {}".format("server" if is_server else "client"), + "info", + [strutils.bytes_to_escaped_str(bytes(event.payload))] + ) return True def _handle_pong_received(self, event, source_conn, other_conn, is_server): - # TODO: log this PONG and its payload - self.connections[other_conn].pong(event.payload) - other_conn.send(self.connections[other_conn].bytes_to_send()) + self.log( + "Pong Received from {}".format("server" if is_server else "client"), + "info", + [strutils.bytes_to_escaped_str(bytes(event.payload))] + ) return True def _handle_connection_closed(self, event, source_conn, other_conn, is_server): @@ -109,80 +160,12 @@ class WebSocketLayer(base.Layer): self.flow.close_code = event.code self.flow.close_reason = event.reason - print(self.connections[other_conn]) self.connections[other_conn].close(event.code, event.reason) + other_conn.send(self.connections[other_conn].bytes_to_send()) + source_conn.send(self.connections[source_conn].bytes_to_send()) - # initiate close handshake return False - def _handle_connection_failed(self, event): - raise exceptions.TcpException(repr(event)) - - # def _handle_data_frame(self, frame, source_conn, other_conn, is_server): - # - # fb = self.server_frame_buffer if is_server else self.client_frame_buffer - # fb.append(frame) - # - # if frame.header.fin: - # payload = b''.join(f.payload for f in fb) - # original_chunk_sizes = [len(f.payload) for f in fb] - # message_type = fb[0].header.opcode - # compressed_message = fb[0].header.rsv1 - # fb.clear() - # - # websocket_message = WebSocketMessage(message_type, not is_server, payload) - # length = len(websocket_message.content) - # self.flow.messages.append(websocket_message) - # self.channel.ask("websocket_message", self.flow) - # - # if not self.flow.stream: - # def get_chunk(payload): - # if len(payload) == length: - # # message has the same length, we can reuse the same sizes - # pos = 0 - # for s in original_chunk_sizes: - # yield payload[pos:pos + s] - # pos += s - # else: - # # just re-chunk everything into 4kB frames - # # header len = 4 bytes without masking key and 8 bytes with masking key - # chunk_size = 4092 if is_server else 4088 - # chunks = range(0, len(payload), chunk_size) - # for i in chunks: - # yield payload[i:i + chunk_size] - # - # frms = [ - # websockets.Frame( - # payload=chunk, - # opcode=frame.header.opcode, - # mask=(False if is_server else 1), - # masking_key=(b'' if is_server else os.urandom(4))) - # for chunk in get_chunk(websocket_message.content) - # ] - # - # if len(frms) > 0: - # frms[-1].header.fin = True - # else: - # frms.append(websockets.Frame( - # fin=True, - # opcode=websockets.OPCODE.CONTINUE, - # mask=(False if is_server else 1), - # masking_key=(b'' if is_server else os.urandom(4)))) - # - # frms[0].header.opcode = message_type - # frms[0].header.rsv1 = compressed_message - # - # for frm in frms: - # other_conn.send(bytes(frm)) - # - # else: - # other_conn.send(bytes(frame)) - # - # elif self.flow.stream: - # other_conn.send(bytes(frame)) - # - # return True - def __call__(self): self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self) self.flow.metadata['websocket_handshake'] = self.handshake_flow.id @@ -204,12 +187,12 @@ class WebSocketLayer(base.Layer): self.connections[source_conn].receive_bytes(bytes(frame)) source_conn.send(self.connections[source_conn].bytes_to_send()) + if close_received: + return + for event in self.connections[source_conn].events(): - print('is_server:', is_server, 'event:', event) if not self._handle_event(event, source_conn, other_conn, is_server): - if close_received: - break - else: + if not close_received: close_received = True except (socket.error, exceptions.TcpException, SSL.Error) as e: s = 'server' if is_server else 'client' diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index 1bda219f3..8233d45e4 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -49,7 +49,7 @@ class UnsupportedLog: def websocket_message(self, f): message = f.messages[-1] signals.add_log(f.message_info(message), "info") - signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug") + signals.add_log(message.content if isinstance(message.content, str) else strutils.bytes_to_escaped_str(message.content), "debug") def websocket_end(self, f): signals.add_log("WebSocket connection closed by {}: {} {}, {}".format( diff --git a/setup.cfg b/setup.cfg index eaabfa12c..fd31d15b5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,7 +21,13 @@ exclude_lines = [tool:full_coverage] exclude = - mitmproxy/proxy/protocol/ + mitmproxy/proxy/protocol/base.py + mitmproxy/proxy/protocol/http.py + mitmproxy/proxy/protocol/http1.py + mitmproxy/proxy/protocol/http2.py + mitmproxy/proxy/protocol/http_replay.py + mitmproxy/proxy/protocol/rawtcp.py + mitmproxy/proxy/protocol/tls.py mitmproxy/proxy/root_context.py mitmproxy/proxy/server.py mitmproxy/tools/ @@ -64,7 +70,6 @@ exclude = mitmproxy/proxy/protocol/http_replay.py mitmproxy/proxy/protocol/rawtcp.py mitmproxy/proxy/protocol/tls.py - mitmproxy/proxy/protocol/websocket.py mitmproxy/proxy/root_context.py mitmproxy/proxy/server.py mitmproxy/stateobject.py diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 14dd74056..a7acdc4db 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -1,5 +1,6 @@ import pytest import os +import struct import tempfile import traceback @@ -33,6 +34,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase): connection='upgrade', upgrade='websocket', sec_websocket_accept=b'', + sec_websocket_extensions='permessage-deflate' if "permessage-deflate" in request.headers.values() else '' ), content=b'', ) @@ -80,7 +82,7 @@ class _WebSocketTestBase: if self.client: self.client.close() - def setup_connection(self): + def setup_connection(self, extension=False): self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) self.client.connect() @@ -115,6 +117,7 @@ class _WebSocketTestBase: upgrade="websocket", sec_websocket_version="13", sec_websocket_key="1234", + sec_websocket_extensions="permessage-deflate" if extension else "" ), content=b'') self.client.wfile.write(http.http1.assemble_request(request)) @@ -145,11 +148,11 @@ class TestSimple(_WebSocketTest): wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() @pytest.mark.parametrize('streaming', [True, False]) @@ -183,17 +186,40 @@ class TestSimple(_WebSocketTest): assert isinstance(self.master.state.flows[0], HTTPFlow) assert isinstance(self.master.state.flows[1], WebSocketFlow) assert len(self.master.state.flows[1].messages) == 5 - assert self.master.state.flows[1].messages[0].content == b'server-foobar' + assert self.master.state.flows[1].messages[0].content == 'server-foobar' assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[1].content == b'self.client-foobar' + assert self.master.state.flows[1].messages[1].content == 'self.client-foobar' assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[2].content == b'self.client-foobar' + assert self.master.state.flows[1].messages[2].content == 'self.client-foobar' assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY + def test_change_payload(self): + class Addon: + def websocket_message(self, f): + f.messages[-1].content = "foo" + + self.master.addons.add(Addon()) + self.setup_connection() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'foo' + class TestSimpleTLS(_WebSocketTest): ssl = True @@ -204,7 +230,7 @@ class TestSimpleTLS(_WebSocketTest): wfile.flush() frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.flush() def test_simple_tls(self): @@ -237,19 +263,21 @@ class TestPing(_WebSocketTest): wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done'))) wfile.flush() + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + wfile.flush() + websockets.Frame.from_file(rfile) + def test_ping(self): self.setup_connection() frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == websockets.OPCODE.PING - assert frame.payload == b'foobar' - - self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + websockets.Frame.from_file(self.client.rfile) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) self.client.wfile.flush() + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'' # We don't send payload to other end - frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == websockets.OPCODE.PONG - assert frame.payload == b'done' + assert self.master.has_log("Pong Received from server", "info") class TestPong(_WebSocketTest): @@ -258,11 +286,15 @@ class TestPong(_WebSocketTest): def handle_websockets(cls, rfile, wfile): frame = websockets.Frame.from_file(rfile) assert frame.header.opcode == websockets.OPCODE.PING - assert frame.payload == b'foobar' + assert frame.payload == b'' wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) wfile.flush() + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + wfile.flush() + websockets.Frame.from_file(rfile) + def test_pong(self): self.setup_connection() @@ -270,8 +302,13 @@ class TestPong(_WebSocketTest): self.client.wfile.flush() frame = websockets.Frame.from_file(self.client.rfile) + websockets.Frame.from_file(self.client.rfile) + self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() + assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' + assert self.master.has_log("Pong Received from server", "info") class TestClose(_WebSocketTest): @@ -279,7 +316,7 @@ class TestClose(_WebSocketTest): @classmethod def handle_websockets(cls, rfile, wfile): frame = websockets.Frame.from_file(rfile) - wfile.write(bytes(frame)) + wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload))) wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) wfile.flush() @@ -329,8 +366,9 @@ class TestInvalidFrame(_WebSocketTest): # with pytest.raises(exceptions.TcpDisconnect): frame = websockets.Frame.from_file(self.client.rfile) - assert frame.header.opcode == 15 - assert frame.payload == b'foobar' + code, = struct.unpack('!H', frame.payload[:2]) + assert code == 1002 + assert frame.payload[2:].startswith(b'Invalid opcode') class TestStreaming(_WebSocketTest): @@ -360,3 +398,51 @@ class TestStreaming(_WebSocketTest): assert frame assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received + + +class TestExtension(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00') + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.rsv1 + wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00') + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.rsv1 + wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00') + wfile.flush() + + def test_extension(self): + self.setup_connection(True) + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v') + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c') + self.client.wfile.flush() + + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.header.rsv1 + + assert len(self.master.state.flows[1].messages) == 5 + assert self.master.state.flows[1].messages[0].content == 'server-foobar' + assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[1].content == 'client-foobar' + assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[2].content == 'client-foobar' + assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT + assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' + assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY + assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef' + assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY