Use wsproto for websockets

This commit is contained in:
Ujjwal Verma 2017-08-17 21:12:07 +05:30 committed by Thomas Kriechbaumer
parent 130021b76d
commit 5214f544e7
5 changed files with 185 additions and 109 deletions

View File

@ -234,6 +234,8 @@ class Dumper:
message = f.messages[-1]
self.echo(f.message_info(message))
if ctx.options.flow_detail >= 3:
message = message.from_state(message.get_state())
message.content = message.content.encode() if isinstance(message.content, str) else message.content
self._echo_message(message)
def websocket_end(self, f):

View File

@ -1,19 +1,18 @@
import os
import socket
import struct
from OpenSSL import SSL
from wsproto import events
from wsproto.connection import ConnectionType, WSConnection
from wsproto.extensions import PerMessageDeflate
from wsproto.frame_protocol import Opcode
from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy.proxy.protocol import base
from mitmproxy.net import http
from mitmproxy.net import tcp
from mitmproxy.net import websockets
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage
from mitmproxy.utils import strutils
class WebSocketLayer(base.Layer):
@ -54,14 +53,16 @@ class WebSocketLayer(base.Layer):
extensions = []
if 'Sec-WebSocket-Extensions' in handshake_flow.response.headers:
if PerMessageDeflate.name in handshake_flow.response.headers['Sec-WebSocket-Extensions']:
extensions = [PerMessageDeflate.name]
extensions = [PerMessageDeflate()]
self.connections[self.client_conn] = WSConnection(ConnectionType.SERVER,
extensions=extensions)
self.connections[self.server_conn] = WSConnection(ConnectionType.CLIENT,
host=handshake_flow.request.host,
resource=handshake_flow.request.path,
extensions=extensions)
if extensions:
for conn in self.connections.values():
conn.extensions[0].finalize(conn, handshake_flow.response.headers['Sec-WebSocket-Extensions'])
data = self.connections[self.server_conn].bytes_to_send()
self.connections[self.client_conn].receive_bytes(data)
@ -80,28 +81,78 @@ class WebSocketLayer(base.Layer):
return self._handle_ping_received(event, source_conn, other_conn, is_server)
elif isinstance(event, events.PongReceived):
return self._handle_pong_received(event, source_conn, other_conn, is_server)
elif isinstance(event, events.ConnectionFailed):
elif isinstance(event, events.ConnectionClosed):
return self._handle_connection_closed(event, source_conn, other_conn, is_server)
elif isinstance(event, events.ConnectionFailed):
return self._handle_connection_failed(event)
# fail-safe for unhandled events
return True
return True # pragma: no cover
def _handle_data_received(self, event, source_conn, other_conn, is_server):
fb = self.server_frame_buffer if is_server else self.client_frame_buffer
fb.append(event.data)
if event.message_finished:
original_chunk_sizes = [len(f) for f in fb]
message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY
if message_type == Opcode.TEXT:
payload = ''.join(fb)
else:
payload = b''.join(fb)
fb.clear()
websocket_message = WebSocketMessage(message_type, not is_server, payload)
length = len(websocket_message.content)
self.flow.messages.append(websocket_message)
self.channel.ask("websocket_message", self.flow)
if not self.flow.stream:
def get_chunk(payload):
if len(payload) == length:
# message has the same length, we can reuse the same sizes
pos = 0
for s in original_chunk_sizes:
yield (payload[pos:pos + s], True if pos + s == length else False)
pos += s
else:
# just re-chunk everything into 4kB frames
# header len = 4 bytes without masking key and 8 bytes with masking key
chunk_size = 4092 if is_server else 4088
chunks = range(0, len(payload), chunk_size)
for i in chunks:
yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
for chunk, final in get_chunk(websocket_message.content):
self.connections[other_conn].send_data(chunk, final)
other_conn.send(self.connections[other_conn].bytes_to_send())
else:
self.connections[other_conn].send_data(event.data, event.message_finished)
other_conn.send(self.connections[other_conn].bytes_to_send())
elif self.flow.stream:
self.connections[other_conn].send_data(event.data, event.message_finished)
other_conn.send(self.connections[other_conn].bytes_to_send())
return True
def _handle_ping_received(self, event, source_conn, other_conn, is_server):
# PING is automatically answered with a PONG by wsproto
# TODO: log this PING and its payload
self.connections[other_conn].ping(event.payload)
self.connections[other_conn].ping()
other_conn.send(self.connections[other_conn].bytes_to_send())
source_conn.send(self.connections[source_conn].bytes_to_send())
self.log(
"Ping Received from {}".format("server" if is_server else "client"),
"info",
[strutils.bytes_to_escaped_str(bytes(event.payload))]
)
return True
def _handle_pong_received(self, event, source_conn, other_conn, is_server):
# TODO: log this PONG and its payload
self.connections[other_conn].pong(event.payload)
other_conn.send(self.connections[other_conn].bytes_to_send())
self.log(
"Pong Received from {}".format("server" if is_server else "client"),
"info",
[strutils.bytes_to_escaped_str(bytes(event.payload))]
)
return True
def _handle_connection_closed(self, event, source_conn, other_conn, is_server):
@ -109,80 +160,12 @@ class WebSocketLayer(base.Layer):
self.flow.close_code = event.code
self.flow.close_reason = event.reason
print(self.connections[other_conn])
self.connections[other_conn].close(event.code, event.reason)
other_conn.send(self.connections[other_conn].bytes_to_send())
source_conn.send(self.connections[source_conn].bytes_to_send())
# initiate close handshake
return False
def _handle_connection_failed(self, event):
raise exceptions.TcpException(repr(event))
# def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
#
# fb = self.server_frame_buffer if is_server else self.client_frame_buffer
# fb.append(frame)
#
# if frame.header.fin:
# payload = b''.join(f.payload for f in fb)
# original_chunk_sizes = [len(f.payload) for f in fb]
# message_type = fb[0].header.opcode
# compressed_message = fb[0].header.rsv1
# fb.clear()
#
# websocket_message = WebSocketMessage(message_type, not is_server, payload)
# length = len(websocket_message.content)
# self.flow.messages.append(websocket_message)
# self.channel.ask("websocket_message", self.flow)
#
# if not self.flow.stream:
# def get_chunk(payload):
# if len(payload) == length:
# # message has the same length, we can reuse the same sizes
# pos = 0
# for s in original_chunk_sizes:
# yield payload[pos:pos + s]
# pos += s
# else:
# # just re-chunk everything into 4kB frames
# # header len = 4 bytes without masking key and 8 bytes with masking key
# chunk_size = 4092 if is_server else 4088
# chunks = range(0, len(payload), chunk_size)
# for i in chunks:
# yield payload[i:i + chunk_size]
#
# frms = [
# websockets.Frame(
# payload=chunk,
# opcode=frame.header.opcode,
# mask=(False if is_server else 1),
# masking_key=(b'' if is_server else os.urandom(4)))
# for chunk in get_chunk(websocket_message.content)
# ]
#
# if len(frms) > 0:
# frms[-1].header.fin = True
# else:
# frms.append(websockets.Frame(
# fin=True,
# opcode=websockets.OPCODE.CONTINUE,
# mask=(False if is_server else 1),
# masking_key=(b'' if is_server else os.urandom(4))))
#
# frms[0].header.opcode = message_type
# frms[0].header.rsv1 = compressed_message
#
# for frm in frms:
# other_conn.send(bytes(frm))
#
# else:
# other_conn.send(bytes(frame))
#
# elif self.flow.stream:
# other_conn.send(bytes(frame))
#
# return True
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
@ -204,12 +187,12 @@ class WebSocketLayer(base.Layer):
self.connections[source_conn].receive_bytes(bytes(frame))
source_conn.send(self.connections[source_conn].bytes_to_send())
if close_received:
return
for event in self.connections[source_conn].events():
print('is_server:', is_server, 'event:', event)
if not self._handle_event(event, source_conn, other_conn, is_server):
if close_received:
break
else:
if not close_received:
close_received = True
except (socket.error, exceptions.TcpException, SSL.Error) as e:
s = 'server' if is_server else 'client'

View File

@ -49,7 +49,7 @@ class UnsupportedLog:
def websocket_message(self, f):
message = f.messages[-1]
signals.add_log(f.message_info(message), "info")
signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
signals.add_log(message.content if isinstance(message.content, str) else strutils.bytes_to_escaped_str(message.content), "debug")
def websocket_end(self, f):
signals.add_log("WebSocket connection closed by {}: {} {}, {}".format(

View File

@ -21,7 +21,13 @@ exclude_lines =
[tool:full_coverage]
exclude =
mitmproxy/proxy/protocol/
mitmproxy/proxy/protocol/base.py
mitmproxy/proxy/protocol/http.py
mitmproxy/proxy/protocol/http1.py
mitmproxy/proxy/protocol/http2.py
mitmproxy/proxy/protocol/http_replay.py
mitmproxy/proxy/protocol/rawtcp.py
mitmproxy/proxy/protocol/tls.py
mitmproxy/proxy/root_context.py
mitmproxy/proxy/server.py
mitmproxy/tools/
@ -64,7 +70,6 @@ exclude =
mitmproxy/proxy/protocol/http_replay.py
mitmproxy/proxy/protocol/rawtcp.py
mitmproxy/proxy/protocol/tls.py
mitmproxy/proxy/protocol/websocket.py
mitmproxy/proxy/root_context.py
mitmproxy/proxy/server.py
mitmproxy/stateobject.py

View File

@ -1,5 +1,6 @@
import pytest
import os
import struct
import tempfile
import traceback
@ -33,6 +34,7 @@ class _WebSocketServerBase(net_tservers.ServerTestBase):
connection='upgrade',
upgrade='websocket',
sec_websocket_accept=b'',
sec_websocket_extensions='permessage-deflate' if "permessage-deflate" in request.headers.values() else ''
),
content=b'',
)
@ -80,7 +82,7 @@ class _WebSocketTestBase:
if self.client:
self.client.close()
def setup_connection(self):
def setup_connection(self, extension=False):
self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port))
self.client.connect()
@ -115,6 +117,7 @@ class _WebSocketTestBase:
upgrade="websocket",
sec_websocket_version="13",
sec_websocket_key="1234",
sec_websocket_extensions="permessage-deflate" if extension else ""
),
content=b'')
self.client.wfile.write(http.http1.assemble_request(request))
@ -145,11 +148,11 @@ class TestSimple(_WebSocketTest):
wfile.flush()
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.flush()
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.flush()
@pytest.mark.parametrize('streaming', [True, False])
@ -183,17 +186,40 @@ class TestSimple(_WebSocketTest):
assert isinstance(self.master.state.flows[0], HTTPFlow)
assert isinstance(self.master.state.flows[1], WebSocketFlow)
assert len(self.master.state.flows[1].messages) == 5
assert self.master.state.flows[1].messages[0].content == b'server-foobar'
assert self.master.state.flows[1].messages[0].content == 'server-foobar'
assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT
assert self.master.state.flows[1].messages[1].content == b'self.client-foobar'
assert self.master.state.flows[1].messages[1].content == 'self.client-foobar'
assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT
assert self.master.state.flows[1].messages[2].content == b'self.client-foobar'
assert self.master.state.flows[1].messages[2].content == 'self.client-foobar'
assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT
assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY
def test_change_payload(self):
class Addon:
def websocket_message(self, f):
f.messages[-1].content = "foo"
self.master.addons.add(Addon())
self.setup_connection()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'foo'
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'foo'
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'foo'
class TestSimpleTLS(_WebSocketTest):
ssl = True
@ -204,7 +230,7 @@ class TestSimpleTLS(_WebSocketTest):
wfile.flush()
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.flush()
def test_simple_tls(self):
@ -237,19 +263,21 @@ class TestPing(_WebSocketTest):
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=b'done')))
wfile.flush()
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
wfile.flush()
websockets.Frame.from_file(rfile)
def test_ping(self):
self.setup_connection()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'foobar'
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
websockets.Frame.from_file(self.client.rfile)
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'' # We don't send payload to other end
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'done'
assert self.master.has_log("Pong Received from server", "info")
class TestPong(_WebSocketTest):
@ -258,11 +286,15 @@ class TestPong(_WebSocketTest):
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'foobar'
assert frame.payload == b''
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
wfile.flush()
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
wfile.flush()
websockets.Frame.from_file(rfile)
def test_pong(self):
self.setup_connection()
@ -270,8 +302,13 @@ class TestPong(_WebSocketTest):
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
websockets.Frame.from_file(self.client.rfile)
self.client.wfile.write(bytes(websockets.Frame(fin=1, mask=1, opcode=websockets.OPCODE.CLOSE)))
self.client.wfile.flush()
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
assert self.master.has_log("Pong Received from server", "info")
class TestClose(_WebSocketTest):
@ -279,7 +316,7 @@ class TestClose(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
frame = websockets.Frame.from_file(rfile)
wfile.write(bytes(frame))
wfile.write(bytes(websockets.Frame(fin=1, opcode=frame.header.opcode, payload=frame.payload)))
wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
wfile.flush()
@ -329,8 +366,9 @@ class TestInvalidFrame(_WebSocketTest):
# with pytest.raises(exceptions.TcpDisconnect):
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == 15
assert frame.payload == b'foobar'
code, = struct.unpack('!H', frame.payload[:2])
assert code == 1002
assert frame.payload[2:].startswith(b'Invalid opcode')
class TestStreaming(_WebSocketTest):
@ -360,3 +398,51 @@ class TestStreaming(_WebSocketTest):
assert frame
assert self.master.state.flows[1].messages == [] # Message not appended as the final frame isn't received
class TestExtension(_WebSocketTest):
@classmethod
def handle_websockets(cls, rfile, wfile):
wfile.write(b'\xc1\x0f*N-*K-\xd2M\xcb\xcfOJ,\x02\x00')
wfile.flush()
frame = websockets.Frame.from_file(rfile)
assert frame.header.rsv1
wfile.write(b'\xc1\nJ\xce\xc9L\xcd+\x81r\x00\x00')
wfile.flush()
frame = websockets.Frame.from_file(rfile)
assert frame.header.rsv1
wfile.write(b'\xc2\x07\xba\xb7v\xdf{\x00\x00')
wfile.flush()
def test_extension(self):
self.setup_connection(True)
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.rsv1
self.client.wfile.write(b'\xc1\x8fQ\xb7vX\x1by\xbf\x14\x9c\x9c\xa7\x15\x9ax9\x12}\xb5v')
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.rsv1
self.client.wfile.write(b'\xc2\x87\xeb\xbb\x0csQ\x0cz\xac\x90\xbb\x0c')
self.client.wfile.flush()
frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.rsv1
assert len(self.master.state.flows[1].messages) == 5
assert self.master.state.flows[1].messages[0].content == 'server-foobar'
assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT
assert self.master.state.flows[1].messages[1].content == 'client-foobar'
assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT
assert self.master.state.flows[1].messages[2].content == 'client-foobar'
assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT
assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[4].type == websockets.OPCODE.BINARY