From 3e0a71ea345131a5f2dcc9581a7d93b8ebe09b13 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 21 Apr 2015 22:39:45 +1200 Subject: [PATCH] websockets: refactor to use http and header functions in http.py --- netlib/http.py | 126 ++++++++++++++++++++++------------------ netlib/websockets.py | 108 +++++++++++----------------------- test/test_websockets.py | 112 +++++++++++++++-------------------- 3 files changed, 152 insertions(+), 194 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index b925fe874..fe27240a8 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -4,7 +4,7 @@ import string import urlparse import binascii import sys -from . import odict, utils, tcp +from . import odict, utils, tcp, http_status class HttpError(Exception): @@ -314,62 +314,6 @@ def parse_response_line(line): return (proto, code, msg) -Response = collections.namedtuple( - "Response", - [ - "httpversion", - "code", - "msg", - "headers", - "content" - ] -) - - -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Return an (httpversion, code, msg, headers, content) tuple. - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return Response(httpversion, code, msg, headers, content) - - def read_http_body(*args, **kwargs): return "".join( content for _, content, _ in read_http_body_chunked(*args, **kwargs) @@ -579,3 +523,71 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): headers, content ) + + +Response = collections.namedtuple( + "Response", + [ + "httpversion", + "code", + "msg", + "headers", + "content" + ] +) + + +def read_response(rfile, request_method, body_size_limit, include_body=True): + """ + Return an (httpversion, code, msg, headers, content) tuple. + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ + line = rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = read_headers(rfile) + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = read_http_body( + rfile, + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return Response(httpversion, code, msg, headers, content) + + +def request_preamble(method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + +def response_preamble(code, message=None, http_major="1", http_minor="1"): + if message is None: + message = http_status.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/websockets.py b/netlib/websockets.py index f2d467a56..a03185fae 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -2,13 +2,11 @@ from __future__ import absolute_import import base64 import hashlib -import mimetools -import StringIO import os import struct import io -from . import utils +from . import utils, odict # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -23,6 +21,7 @@ from . import utils # The magic sha that websocket servers must know to prove they understand # RFC6455 websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +VERSION = "13" class CONST(object): @@ -151,9 +150,9 @@ class Frame(object): ("opcode - " + str(self.opcode)), ("mask_bit - " + str(self.mask_bit)), ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + str(self.masking_key)), - ("payload - " + str(self.payload)), - ("decoded_payload - " + str(self.decoded_payload)), + ("masking_key - " + repr(str(self.masking_key))), + ("payload - " + repr(str(self.payload))), + ("decoded_payload - " + repr(str(self.decoded_payload))), ("actual_payload_length - " + str(self.actual_payload_length)) ]) @@ -198,24 +197,24 @@ class Frame(object): second_byte = (self.mask_bit << 7) | self.payload_length_code - bytes = chr(first_byte) + chr(second_byte) + b = chr(first_byte) + chr(second_byte) if self.actual_payload_length < 126: pass elif self.actual_payload_length < CONST.MAX_16_BIT_INT: # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length - bytes += struct.pack('!H', self.actual_payload_length) + b += struct.pack('!H', self.actual_payload_length) elif self.actual_payload_length < CONST.MAX_64_BIT_INT: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length - bytes += struct.pack('!Q', self.actual_payload_length) + b += struct.pack('!Q', self.actual_payload_length) if self.masking_key is not None: - bytes += self.masking_key + b += self.masking_key - bytes += self.payload # already will be encoded if neccessary - return bytes + b += self.payload # already will be encoded if neccessary + return b def to_file(self, writer): writer.write(self.to_bytes()) @@ -313,58 +312,35 @@ def random_masking_key(): return os.urandom(4) -def create_client_handshake(host, port, key, version, resource): +def client_handshake_headers(key=None, version=VERSION): """ - WebSockets connections are intiated by the client with a valid HTTP - upgrade request + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of ODictCaseless """ - headers = [ - ('Host', '%s:%s' % (host, port)), + if not key: + key = base64.b64encode(os.urandom(16)).decode('utf-8') + return odict.ODictCaseless([ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), ('Sec-WebSocket-Key', key), ('Sec-WebSocket-Version', version) - ] - request = "GET %s HTTP/1.1" % resource - return build_handshake(headers, request) + ]) -def create_server_handshake(key): +def server_handshake_headers(key): """ The server response is a valid HTTP 101 response. """ - headers = [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nonce(key)) - ] - request = "HTTP/1.1 101 Switching Protocols" - return build_handshake(headers, request) - - -def build_handshake(headers, request): - handshake = [request.encode('utf-8')] - for header, value in headers: - handshake.append(("%s: %s" % (header, value)).encode('utf-8')) - handshake.append(b'\r\n') - return b'\r\n'.join(handshake) - - -def read_handshake(reader, num_bytes_per_read): - """ - From provided function that reads bytes, read in a - complete HTTP request, which terminates with a CLRF - """ - response = b'' - doubleCLRF = b'\r\n\r\n' - while True: - bytes = reader.read(num_bytes_per_read) - if not bytes: - break - response += bytes - if doubleCLRF in response: - break - return response + return odict.ODictCaseless( + [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', create_server_nonce(key)) + ] + ) def get_payload_length_pair(payload_bytestring): @@ -384,33 +360,19 @@ def get_payload_length_pair(payload_bytestring): return (length_code, actual_length) -def process_handshake_from_client(handshake): - headers = headers_from_http_message(handshake) - if headers.get("Upgrade", None) != "websocket": +def check_client_handshake(req): + if req.headers.get_first("upgrade", None) != "websocket": return - key = headers['Sec-WebSocket-Key'] - return key + return req.headers.get_first('sec-websocket-key') -def process_handshake_from_server(handshake): - headers = headers_from_http_message(handshake) - if headers.get("Upgrade", None) != "websocket": +def check_server_handshake(resp): + if resp.headers.get_first("upgrade", None) != "websocket": return - key = headers['Sec-WebSocket-Accept'] - return key - - -def headers_from_http_message(http_message): - return mimetools.Message( - StringIO.StringIO(http_message.split('\r\n', 1)[1]) - ) + return resp.headers.get_first('sec-websocket-accept') def create_server_nonce(client_nonce): return base64.b64encode( hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') ) - - -def create_client_nonce(): - return base64.b64encode(os.urandom(16)).decode('utf-8') diff --git a/test/test_websockets.py b/test/test_websockets.py index 1f2025bfb..9b27e810d 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,6 +1,4 @@ -from netlib import tcp -from netlib import test -from netlib import websockets +from netlib import tcp, test, websockets, http, odict import io import os from nose.tools import raises @@ -21,18 +19,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.read_next_message() def read_next_message(self): - decoded = websockets.Frame.from_file(self.rfile).decoded_payload - self.on_message(decoded) + frame = websockets.Frame.from_file(self.rfile) + self.on_message(frame.decoded_payload) def send_message(self, message): frame = websockets.Frame.default(message, from_client = False) frame.to_file(self.wfile) def handshake(self): - client_hs = websockets.read_handshake(self.rfile, 1) - key = websockets.process_handshake_from_client(client_hs) - response = websockets.create_server_handshake(key) - self.wfile.write(response) + req = http.read_request(self.rfile) + key = websockets.check_client_handshake(req) + + self.wfile.write(http.response_preamble(101) + "\r\n") + headers = websockets.server_handshake_headers(key) + self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True @@ -44,28 +44,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.client_nonce = websockets.create_client_nonce() - self.resource = "/" + self.client_nonce = None def connect(self): super(WebSocketsClient, self).connect() - handshake = websockets.create_client_handshake( - self.address.host, - self.address.port, - self.client_nonce, - self.version, - self.resource - ) - - self.wfile.write(handshake) + preamble = http.request_preamble("GET", "/") + self.wfile.write(preamble + "\r\n") + headers = websockets.client_handshake_headers() + self.client_nonce = headers.get_first("sec-websocket-key") + self.wfile.write(headers.format() + "\r\n") self.wfile.flush() - server_handshake = websockets.read_handshake(self.rfile, 1) - server_nonce = websockets.process_handshake_from_server( - server_handshake - ) + resp = http.read_response(self.rfile, "get", None) + server_nonce = websockets.check_server_handshake(resp) if not server_nonce == websockets.create_server_nonce(self.client_nonce): self.close() @@ -140,51 +132,43 @@ class TestWebSockets(test.ServerTestBase): frame.actual_payload_length = 1 # corrupt the frame frame.safe_to_bytes() - def test_handshake(self): - bad_upgrade = "not_websockets" - bad_header_handshake = websockets.build_handshake([ - ('Host', '%s:%s' % ("a", "b")), - ('Connection', "c"), - ('Upgrade', bad_upgrade), - ('Sec-WebSocket-Key', "d"), - ('Sec-WebSocket-Version', "e") - ], "f") - - # check behavior when required header values are missing - assert None is websockets.process_handshake_from_server( - bad_header_handshake + def test_check_server_handshake(self): + resp = http.Response( + (1, 1), + 101, + "Switching Protocols", + websockets.server_handshake_headers("key"), + "" ) - assert None is websockets.process_handshake_from_client( - bad_header_handshake + assert websockets.check_server_handshake(resp) + resp.headers["Upgrade"] = ["not_websocket"] + assert not websockets.check_server_handshake(resp) + + def test_check_client_handshake(self): + resp = http.Request( + "relative", + "get", + "http", + "host", + 22, + "/", + (1, 1), + websockets.client_handshake_headers("key"), + "" ) - - key = "test_key" - - client_handshake = websockets.create_client_handshake( - "a", "b", key, "d", "e" - ) - assert key == websockets.process_handshake_from_client( - client_handshake - ) - - server_handshake = websockets.create_server_handshake(key) - assert websockets.create_server_nonce(key) == websockets.process_handshake_from_server(server_handshake) - - handshake = websockets.create_client_handshake("a", "b", "c", "d", "e") - stream = io.BytesIO(handshake) - assert handshake == websockets.read_handshake(stream, 1) - - # ensure readhandshake doesn't loop forever on empty stream - empty_stream = io.BytesIO("") - assert "" == websockets.read_handshake(empty_stream, 1) + assert websockets.check_client_handshake(resp) == "key" + resp.headers["Upgrade"] = ["not_websocket"] + assert not websockets.check_client_handshake(resp) class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - client_hs = websockets.read_handshake(self.rfile, 1) - websockets.process_handshake_from_client(client_hs) - response = websockets.create_server_handshake("malformed_key") - self.wfile.write(response) + client_hs = http.read_request(self.rfile) + websockets.check_client_handshake(client_hs) + + self.wfile.write(http.response_preamble(101) + "\r\n") + headers = websockets.server_handshake_headers("malformed key") + self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True