small cleanups, working on tests
This commit is contained in:
parent
e41e5cbfdd
commit
0edc04814e
|
@ -26,8 +26,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||||
self.on_message(decoded)
|
self.on_message(decoded)
|
||||||
|
|
||||||
def send_message(self, message):
|
def send_message(self, message):
|
||||||
frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False)
|
frame = ws.WebSocketsFrame.default(message, from_client = False)
|
||||||
self.wfile.write(frame.to_bytes())
|
self.wfile.write(frame.safe_to_bytes())
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
|
||||||
def handshake(self):
|
def handshake(self):
|
||||||
|
@ -47,7 +47,7 @@ class WebSocketsClient(tcp.TCPClient):
|
||||||
def __init__(self, address, source_address=None):
|
def __init__(self, address, source_address=None):
|
||||||
super(WebSocketsClient, self).__init__(address, source_address)
|
super(WebSocketsClient, self).__init__(address, source_address)
|
||||||
self.version = "13"
|
self.version = "13"
|
||||||
self.key = b64encode(os.urandom(16)).decode('utf-8')
|
self.key = ws.generate_client_nounce()
|
||||||
self.resource = "/"
|
self.resource = "/"
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
|
@ -76,6 +76,6 @@ class WebSocketsClient(tcp.TCPClient):
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def send_message(self, message):
|
def send_message(self, message):
|
||||||
frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True)
|
frame = ws.WebSocketsFrame.default(message, from_client = True)
|
||||||
self.wfile.write(frame.to_bytes())
|
self.wfile.write(frame.safe_to_bytes())
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
|
|
@ -65,7 +65,6 @@ class WebSocketsFrame(object):
|
||||||
payload = None, # bytestring
|
payload = None, # bytestring
|
||||||
masking_key = None, # 32 bit byte string
|
masking_key = None, # 32 bit byte string
|
||||||
actual_payload_length = None, # any decimal integer
|
actual_payload_length = None, # any decimal integer
|
||||||
use_validation = True # indicates whether or not you care if this frame adheres to the spec
|
|
||||||
):
|
):
|
||||||
self.fin = fin
|
self.fin = fin
|
||||||
self.rsv1 = rsv1
|
self.rsv1 = rsv1
|
||||||
|
@ -78,21 +77,18 @@ class WebSocketsFrame(object):
|
||||||
self.payload = payload
|
self.payload = payload
|
||||||
self.decoded_payload = decoded_payload
|
self.decoded_payload = decoded_payload
|
||||||
self.actual_payload_length = actual_payload_length
|
self.actual_payload_length = actual_payload_length
|
||||||
self.use_validation = use_validation
|
|
||||||
|
|
||||||
if self.use_validation:
|
|
||||||
self.validate_frame()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_bytes(cls, bytestring):
|
def from_bytes(cls, bytestring):
|
||||||
"""
|
"""
|
||||||
Construct a websocket frame from an in-memory bytestring
|
Construct a websocket frame from an in-memory bytestring
|
||||||
to construct a frame from a stream of bytes, use read_frame() directly
|
to construct a frame from a stream of bytes, use from_byte_stream() directly
|
||||||
"""
|
"""
|
||||||
self.from_byte_stream(io.BytesIO(bytestring).read)
|
self.from_byte_stream(io.BytesIO(bytestring).read)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_frame_from_message(cls, message, from_client = False):
|
def default(cls, message, from_client = False):
|
||||||
"""
|
"""
|
||||||
Construct a basic websocket frame from some default values.
|
Construct a basic websocket frame from some default values.
|
||||||
Creates a non-fragmented text frame.
|
Creates a non-fragmented text frame.
|
||||||
|
@ -119,7 +115,7 @@ class WebSocketsFrame(object):
|
||||||
actual_payload_length = actual_length
|
actual_payload_length = actual_length
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_frame(self):
|
def frame_is_valid(self):
|
||||||
"""
|
"""
|
||||||
Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame
|
Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame
|
||||||
has not been corrupted.
|
has not been corrupted.
|
||||||
|
@ -141,10 +137,11 @@ class WebSocketsFrame(object):
|
||||||
assert self.actual_payload_length == len(self.payload)
|
assert self.actual_payload_length == len(self.payload)
|
||||||
|
|
||||||
if self.payload is not None and self.masking_key is not None:
|
if self.payload is not None and self.masking_key is not None:
|
||||||
apply_mask(self.payload, self.masking_key) == self.decoded_payload
|
assert apply_mask(self.payload, self.masking_key) == self.decoded_payload
|
||||||
|
|
||||||
|
return True
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
raise WebSocketFrameValidationException()
|
return False
|
||||||
|
|
||||||
def human_readable(self):
|
def human_readable(self):
|
||||||
return "\n".join([
|
return "\n".join([
|
||||||
|
@ -161,15 +158,19 @@ class WebSocketsFrame(object):
|
||||||
("actual_payload_length - " + str(self.actual_payload_length)),
|
("actual_payload_length - " + str(self.actual_payload_length)),
|
||||||
("use_validation - " + str(self.use_validation))])
|
("use_validation - " + str(self.use_validation))])
|
||||||
|
|
||||||
|
def safe_to_bytes(self):
|
||||||
|
try:
|
||||||
|
assert self.frame_is_valid()
|
||||||
|
return self.to_bytes()
|
||||||
|
except:
|
||||||
|
raise WebSocketFrameValidationException()
|
||||||
|
|
||||||
def to_bytes(self):
|
def to_bytes(self):
|
||||||
"""
|
"""
|
||||||
Serialize the frame back into the wire format, returns a bytestring
|
Serialize the frame back into the wire format, returns a bytestring
|
||||||
|
If you haven't checked is_valid_frame() then there's no guarentees that the
|
||||||
|
serialized bytes will be correct. see safe_to_bytes()
|
||||||
"""
|
"""
|
||||||
# validate enforces all the assumptions made by this serializer
|
|
||||||
# in the spritit of mitmproxy, it's possible to create and serialize invalid frames
|
|
||||||
# by skipping validation.
|
|
||||||
if self.use_validation:
|
|
||||||
self.validate_frame()
|
|
||||||
|
|
||||||
max_16_bit_int = (1 << 16)
|
max_16_bit_int = (1 << 16)
|
||||||
max_64_bit_int = (1 << 63)
|
max_64_bit_int = (1 << 63)
|
||||||
|
@ -198,6 +199,7 @@ class WebSocketsFrame(object):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif self.actual_payload_length < max_16_bit_int:
|
elif self.actual_payload_length < max_16_bit_int:
|
||||||
|
|
||||||
# '!H' pack as 16 bit unsigned short
|
# '!H' pack as 16 bit unsigned short
|
||||||
bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length
|
bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length
|
||||||
|
|
||||||
|
@ -284,9 +286,6 @@ def apply_mask(message, masking_key):
|
||||||
def random_masking_key():
|
def random_masking_key():
|
||||||
return os.urandom(4)
|
return os.urandom(4)
|
||||||
|
|
||||||
def masking_key_list(masking_key):
|
|
||||||
return [utils.bytes_to_int(byte) for byte in masking_key]
|
|
||||||
|
|
||||||
def create_client_handshake(host, port, key, version, resource):
|
def create_client_handshake(host, port, key, version, resource):
|
||||||
"""
|
"""
|
||||||
WebSockets connections are intiated by the client with a valid HTTP upgrade request
|
WebSockets connections are intiated by the client with a valid HTTP upgrade request
|
||||||
|
|
|
@ -1,15 +1,29 @@
|
||||||
from netlib import test
|
from netlib import test
|
||||||
from netlib.websockets import implementations as ws
|
from netlib.websockets import implementations as impl
|
||||||
|
from netlib.websockets import websockets as ws
|
||||||
|
import os
|
||||||
|
|
||||||
class TestWebSockets(test.ServerTestBase):
|
class TestWebSockets(test.ServerTestBase):
|
||||||
handler = ws.WebSocketsEchoHandler
|
handler = impl.WebSocketsEchoHandler
|
||||||
|
|
||||||
def test_websockets_echo(self):
|
def echo(self, msg):
|
||||||
msg = "hello I'm the client"
|
client = impl.WebSocketsClient(("127.0.0.1", self.port))
|
||||||
client = ws.WebSocketsClient(("127.0.0.1", self.port))
|
|
||||||
client.connect()
|
client.connect()
|
||||||
client.send_message(msg)
|
client.send_message(msg)
|
||||||
response = client.read_next_message()
|
response = client.read_next_message()
|
||||||
print "Assert response: " + response + " == msg: " + msg
|
print "Assert response: " + response + " == msg: " + msg
|
||||||
assert response == msg
|
assert response == msg
|
||||||
|
|
||||||
|
def test_simple_echo(self):
|
||||||
|
self.echo("hello I'm the client")
|
||||||
|
|
||||||
|
def test_frame_sizes(self):
|
||||||
|
small_string = os.urandom(100) # length can fit in the the 7 bit payload length
|
||||||
|
medium_string = os.urandom(50000) # 50kb, sligthly larger than can fit in a 7 bit int
|
||||||
|
large_string = os.urandom(150000) # 150kb, slightly larger than can fit in a 16 bit int
|
||||||
|
|
||||||
|
self.echo(small_string)
|
||||||
|
self.echo(medium_string)
|
||||||
|
self.echo(large_string)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue