From 7d9e38ffb10e92b5127f203c2d8a524da8698b00 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 1 May 2015 10:09:35 +1200 Subject: [PATCH] websockets: A progressive masker. --- netlib/websockets.py | 32 ++++++++++++++++++-------------- test/test_websockets.py | 16 ++++++++++++++++ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/netlib/websockets.py b/netlib/websockets.py index 1d02d6841..84eb03bae 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -35,21 +35,25 @@ OPCODE = utils.BiDi( ) -def apply_mask(message, masking_key): +class Masker: """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns - This method both encodes and decodes strings with the provided mask - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 """ - masks = [utils.bytes_to_int(byte) for byte in masking_key] - result = "" - for char in message: - result += chr(ord(char) ^ masks[len(result) % 4]) - return result + def __init__(self, key): + self.key = key + self.masks = [utils.bytes_to_int(byte) for byte in key] + self.offset = 0 + + def __call__(self, data): + result = "" + for c in data: + result += chr(ord(c) ^ self.masks[self.offset % 4]) + self.offset += 1 + return result def client_handshake_headers(key=None, version=VERSION): @@ -324,7 +328,7 @@ class Frame(object): """ b = self.header.to_bytes() if self.header.masking_key: - b += apply_mask(self.payload, self.header.masking_key) + b += Masker(self.header.masking_key)(self.payload) else: b += self.payload return b @@ -345,7 +349,7 @@ class Frame(object): payload = fp.read(header.payload_length) if header.mask == 1 and header.masking_key: - payload = apply_mask(payload, header.masking_key) + payload = Masker(header.masking_key)(payload) return cls( payload, diff --git a/test/test_websockets.py b/test/test_websockets.py index d8e56a8fd..428f7c61d 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -232,3 +232,19 @@ class TestFrame: def test_human_readable(self): f = websockets.Frame() assert f.human_readable() + + +def test_masker(): + tests = [ + ["a"], + ["four"], + ["fourf"], + ["fourfive"], + ["a", "aasdfasdfa", "asdf"], + ["a"*50, "aasdfasdfa", "asdf"], + ] + for i in tests: + m = websockets.Masker("abcd") + data = "".join([m(t) for t in i]) + data2 = websockets.Masker("abcd")(data) + assert data2 == "".join(i)