diff --git a/.travis.yml b/.travis.yml index 2edd25581..c5634a10c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,7 +17,7 @@ matrix: - libssl-dev - python: 3.5 script: - - py.test3 -n 4 -k "not http2 and not websockets and not wsgi and not models" . + - py.test -n 4 -k "not http2" . - python: pypy - python: pypy env: OPENSSL=1.0.2 diff --git a/netlib/encoding.py b/netlib/encoding.py index 8ac599051..4c11273bc 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -8,27 +8,25 @@ import zlib from .utils import always_byte_args -ENCODINGS = {b"identity", b"gzip", b"deflate"} +ENCODINGS = {"identity", "gzip", "deflate"} -@always_byte_args("ascii", "ignore") def decode(e, content): encoding_map = { - b"identity": identity, - b"gzip": decode_gzip, - b"deflate": decode_deflate, + "identity": identity, + "gzip": decode_gzip, + "deflate": decode_deflate, } if e not in encoding_map: return None return encoding_map[e](content) -@always_byte_args("ascii", "ignore") def encode(e, content): encoding_map = { - b"identity": identity, - b"gzip": encode_gzip, - b"deflate": encode_deflate, + "identity": identity, + "gzip": encode_gzip, + "deflate": encode_deflate, } if e not in encoding_map: return None diff --git a/netlib/http/models.py b/netlib/http/models.py index ff854b135..3c360a371 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -3,7 +3,7 @@ import copy from ..odict import ODict from .. import utils, encoding -from ..utils import always_bytes, always_byte_args +from ..utils import always_bytes, always_byte_args, native from . import cookies import six @@ -254,7 +254,7 @@ class Request(Message): def __repr__(self): if self.host and self.port: - hostport = "{}:{}".format(self.host, self.port) + hostport = "{}:{}".format(native(self.host,"idna"), self.port) else: hostport = "" path = self.path or "" @@ -279,14 +279,14 @@ class Request(Message): Modifies this request to remove headers that will compress the resource's data. """ - self.headers["Accept-Encoding"] = b"identity" + self.headers["Accept-Encoding"] = "identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - accept_encoding = self.headers.get(b"Accept-Encoding") + accept_encoding = native(self.headers.get("Accept-Encoding"), "ascii") if accept_encoding: self.headers["Accept-Encoding"] = ( ', '.join( @@ -309,9 +309,9 @@ class Request(Message): indicates non-form data. """ if self.body: - if HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): + if HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower(): return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): + elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower(): return self.get_form_multipart() return ODict([]) @@ -321,12 +321,12 @@ class Request(Message): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower(): + if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", b"").lower(): return ODict(utils.urldecode(self.body)) return ODict([]) def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower(): + if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", b"").lower(): return ODict( utils.multipartdecode( self.headers, @@ -351,7 +351,7 @@ class Request(Message): Components are unquoted. """ _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split(b"/") if i] + return [urllib.parse.unquote(native(i,"ascii")) for i in path.split(b"/") if i] def set_path_components(self, lst): """ @@ -360,7 +360,7 @@ class Request(Message): Components are quoted. """ lst = [urllib.parse.quote(i, safe="") for i in lst] - path = b"/" + b"/".join(lst) + path = always_bytes("/" + "/".join(lst)) scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) self.url = urllib.parse.urlunparse( [scheme, netloc, path, params, query, fragment] @@ -408,11 +408,11 @@ class Request(Message): def pretty_url(self, hostheader): if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.pretty_host(hostheader), self.port) + return b"%s:%d" % (always_bytes(self.pretty_host(hostheader)), self.port) return utils.unparse_url(self.scheme, self.pretty_host(hostheader), self.port, - self.path).encode('ascii') + self.path) def get_cookies(self): """ @@ -420,7 +420,7 @@ class Request(Message): """ ret = ODict() for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) + ret.extend(cookies.parse_cookie_header(native(i,"ascii"))) return ret def set_cookies(self, odict): @@ -441,7 +441,7 @@ class Request(Message): self.host, self.port, self.path - ).encode('ascii') + ) @url.setter def url(self, url): @@ -499,7 +499,7 @@ class Response(Message): """ ret = [] for header in self.headers.get_all("Set-Cookie"): - v = cookies.parse_set_cookie_header(header) + v = cookies.parse_set_cookie_header(native(header, "ascii")) if v: name, value, attrs = v ret.append([name, [value, attrs]]) diff --git a/netlib/tutils.py b/netlib/tutils.py index 4903d63b0..1665a7929 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -7,7 +7,7 @@ from contextlib import contextmanager import six import sys -from . import utils +from . import utils, tcp from .http import Request, Response, Headers @@ -15,7 +15,6 @@ def treader(bytes): """ Construct a tcp.Read object from bytes. """ - from . import tcp # TODO: move to top once cryptography is on Python 3.5 fp = BytesIO(bytes) return tcp.Reader(fp) @@ -106,7 +105,7 @@ def treq(**kwargs): port=22, path=b"/path", http_version=b"HTTP/1.1", - headers=Headers(header=b"qvalue"), + headers=Headers(header="qvalue"), body=b"content" ) default.update(kwargs) diff --git a/netlib/utils.py b/netlib/utils.py index 799b0d428..8d11bd5b2 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -9,6 +9,41 @@ import six from six.moves import urllib +def always_bytes(unicode_or_bytes, *encode_args): + if isinstance(unicode_or_bytes, six.text_type): + return unicode_or_bytes.encode(*encode_args) + return unicode_or_bytes + + +def always_byte_args(*encode_args): + """Decorator that transparently encodes all arguments passed as unicode""" + def decorator(fun): + def _fun(*args, **kwargs): + args = [always_bytes(arg, *encode_args) for arg in args] + kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} + return fun(*args, **kwargs) + return _fun + return decorator + + +def native(s, encoding="latin-1"): + """ + Convert :py:class:`bytes` or :py:class:`unicode` to the native + :py:class:`str` type, using latin1 encoding if conversion is necessary. + + https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types + """ + if not isinstance(s, (six.binary_type, six.text_type)): + raise TypeError("%r is neither bytes nor unicode" % s) + if six.PY3: + if isinstance(s, six.binary_type): + return s.decode(encoding) + else: + if isinstance(s, six.text_type): + return s.encode(encoding) + return s + + def isascii(bytes): try: bytes.decode("ascii") @@ -238,6 +273,7 @@ def get_header_tokens(headers, key): return [token.strip() for token in tokens] +@always_byte_args() def hostport(scheme, host, port): """ Returns the host component, with a port specifcation if needed. @@ -323,20 +359,3 @@ def multipartdecode(headers, content): r.append((key, value)) return r return [] - - -def always_bytes(unicode_or_bytes, *encode_args): - if isinstance(unicode_or_bytes, six.text_type): - return unicode_or_bytes.encode(*encode_args) - return unicode_or_bytes - - -def always_byte_args(*encode_args): - """Decorator that transparently encodes all arguments passed as unicode""" - def decorator(fun): - def _fun(*args, **kwargs): - args = [always_bytes(arg, *encode_args) for arg in args] - kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} - return fun(*args, **kwargs) - return _fun - return decorator diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index ceddd2733..55eeaf416 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -2,13 +2,14 @@ from __future__ import absolute_import import os import struct import io +import warnings + import six from .protocol import Masker from netlib import tcp from netlib import utils -DEFAULT = object() MAX_16_BIT_INT = (1 << 16) MAX_64_BIT_INT = (1 << 64) @@ -33,9 +34,9 @@ class FrameHeader(object): rsv1=False, rsv2=False, rsv3=False, - masking_key=DEFAULT, - mask=DEFAULT, - length_code=DEFAULT + masking_key=None, + mask=None, + length_code=None ): if not 0 <= opcode < 2 ** 4: raise ValueError("opcode must be 0-16") @@ -46,18 +47,18 @@ class FrameHeader(object): self.rsv2 = rsv2 self.rsv3 = rsv3 - if length_code is DEFAULT: + if length_code is None: self.length_code = self._make_length_code(self.payload_length) else: self.length_code = length_code - if mask is DEFAULT and masking_key is DEFAULT: + if mask is None and masking_key is None: self.mask = False - self.masking_key = "" - elif mask is DEFAULT: + self.masking_key = b"" + elif mask is None: self.mask = 1 self.masking_key = masking_key - elif masking_key is DEFAULT: + elif masking_key is None: self.mask = mask self.masking_key = os.urandom(4) else: @@ -81,7 +82,7 @@ class FrameHeader(object): else: return 127 - def human_readable(self): + def __repr__(self): vals = [ "ws frame:", OPCODE.get_name(self.opcode, hex(self.opcode)).lower() @@ -98,7 +99,11 @@ class FrameHeader(object): vals.append(" %s" % utils.pretty_size(self.payload_length)) return "".join(vals) - def to_bytes(self): + def human_readable(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return repr(self) + + def __bytes__(self): first_byte = utils.setbit(0, 7, self.fin) first_byte = utils.setbit(first_byte, 6, self.rsv1) first_byte = utils.setbit(first_byte, 5, self.rsv2) @@ -107,7 +112,7 @@ class FrameHeader(object): second_byte = utils.setbit(self.length_code, 7, self.mask) - b = chr(first_byte) + chr(second_byte) + b = six.int2byte(first_byte) + six.int2byte(second_byte) if self.payload_length < 126: pass @@ -119,10 +124,17 @@ class FrameHeader(object): # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length b += struct.pack('!Q', self.payload_length) - if self.masking_key is not None: + if self.masking_key: b += self.masking_key return b + if six.PY2: + __str__ = __bytes__ + + def to_bytes(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return bytes(self) + @classmethod def from_file(cls, fp): """ @@ -154,7 +166,7 @@ class FrameHeader(object): if mask_bit == 1: masking_key = fp.safe_read(4) else: - masking_key = None + masking_key = False return cls( fin=fin, @@ -169,7 +181,9 @@ class FrameHeader(object): ) def __eq__(self, other): - return self.to_bytes() == other.to_bytes() + if isinstance(other, FrameHeader): + return bytes(self) == bytes(other) + return False class Frame(object): @@ -200,7 +214,7 @@ class Frame(object): +---------------------------------------------------------------+ """ - def __init__(self, payload="", **kwargs): + def __init__(self, payload=b"", **kwargs): self.payload = payload kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) self.header = FrameHeader(**kwargs) @@ -216,7 +230,7 @@ class Frame(object): masking_key = os.urandom(4) else: mask_bit = 0 - masking_key = None + masking_key = False return cls( message, @@ -234,28 +248,37 @@ class Frame(object): """ return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) - def human_readable(self): - ret = self.header.human_readable() + def __repr__(self): + ret = repr(self.header) if self.payload: - ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload) + ret = ret + "\nPayload:\n" + utils.clean_bin(self.payload).decode("ascii") return ret - def __repr__(self): - return self.header.human_readable() + def human_readable(self): + warnings.warn("Frame.to_bytes is deprecated, use bytes(frame) instead.", DeprecationWarning) + return repr(self) - def to_bytes(self): + def __bytes__(self): """ Serialize the frame to wire format. Returns a string. """ - b = self.header.to_bytes() + b = bytes(self.header) if self.header.masking_key: b += Masker(self.header.masking_key)(self.payload) else: b += self.payload return b + if six.PY2: + __str__ = __bytes__ + + def to_bytes(self): + warnings.warn("FrameHeader.to_bytes is deprecated, use bytes(frame_header) instead.", DeprecationWarning) + return bytes(self) + def to_file(self, writer): - writer.write(self.to_bytes()) + warnings.warn("Frame.to_file is deprecated, use wfile.write(bytes(frame)) instead.", DeprecationWarning) + writer.write(bytes(self)) writer.flush() @classmethod @@ -286,4 +309,6 @@ class Frame(object): ) def __eq__(self, other): - return self.to_bytes() == other.to_bytes() + if isinstance(other, Frame): + return bytes(self) == bytes(other) + return False \ No newline at end of file diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 68d827a57..778fe7e74 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -17,11 +17,12 @@ from __future__ import absolute_import import base64 import hashlib import os + +import binascii import six from ..http import Headers -from .. import utils -websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +websockets_magic = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" HEADER_WEBSOCKET_KEY = 'sec-websocket-key' @@ -41,14 +42,21 @@ class Masker(object): def __init__(self, key): self.key = key - self.masks = [six.byte2int(byte) for byte in key] self.offset = 0 def mask(self, offset, data): - result = "" - for c in data: - result += chr(ord(c) ^ self.masks[offset % 4]) - offset += 1 + result = bytearray(data) + if six.PY2: + for i in range(len(data)): + result[i] ^= ord(self.key[offset % 4]) + offset += 1 + result = str(result) + else: + + for i in range(len(data)): + result[i] ^= self.key[offset % 4] + offset += 1 + result = bytes(result) return result def __call__(self, data): @@ -73,37 +81,35 @@ class WebsocketsProtocol(object): """ if not key: key = base64.b64encode(os.urandom(16)).decode('utf-8') - return Headers([ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - (HEADER_WEBSOCKET_KEY, key), - (HEADER_WEBSOCKET_VERSION, version) - ]) + return Headers(**{ + HEADER_WEBSOCKET_KEY: key, + HEADER_WEBSOCKET_VERSION: version, + "Connection": "Upgrade", + "Upgrade": "websocket", + }) @classmethod def server_handshake_headers(self, key): """ The server response is a valid HTTP 101 response. """ - return Headers( - [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - (HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key)) - ] - ) + return Headers(**{ + HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key), + "Connection": "Upgrade", + "Upgrade": "websocket", + }) @classmethod def check_client_handshake(self, headers): - if headers.get("upgrade") != "websocket": + if headers.get("upgrade") != b"websocket": return return headers.get(HEADER_WEBSOCKET_KEY) @classmethod def check_server_handshake(self, headers): - if headers.get("upgrade") != "websocket": + if headers.get("upgrade") != b"websocket": return return headers.get(HEADER_WEBSOCKET_ACCEPT) @@ -111,5 +117,5 @@ class WebsocketsProtocol(object): @classmethod def create_server_nonce(self, client_nonce): return base64.b64encode( - hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') + binascii.unhexlify(hashlib.sha1(client_nonce + websockets_magic).hexdigest()) ) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index fba9f3885..8fb09008d 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,14 +1,15 @@ from __future__ import (absolute_import, print_function, division) -from io import BytesIO +from io import BytesIO, StringIO import urllib import time import traceback import six +from six.moves import urllib +from netlib.utils import always_bytes, native from . import http, tcp - class ClientConn(object): def __init__(self, address): @@ -24,9 +25,10 @@ class Flow(object): class Request(object): - def __init__(self, scheme, method, path, headers, body): + def __init__(self, scheme, method, path, http_version, headers, body): self.scheme, self.method, self.path = scheme, method, path self.headers, self.body = headers, body + self.http_version = http_version def date_time_string(): @@ -53,38 +55,38 @@ class WSGIAdaptor(object): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion def make_environ(self, flow, errsoc, **extra): - if '?' in flow.request.path: - path_info, query = flow.request.path.split('?', 1) + path = native(flow.request.path) + if '?' in path: + path_info, query = native(path).split('?', 1) else: - path_info = flow.request.path + path_info = path query = '' environ = { 'wsgi.version': (1, 0), - 'wsgi.url_scheme': flow.request.scheme, + 'wsgi.url_scheme': native(flow.request.scheme), 'wsgi.input': BytesIO(flow.request.body or b""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, 'wsgi.run_once': False, 'SERVER_SOFTWARE': self.sversion, - 'REQUEST_METHOD': flow.request.method, + 'REQUEST_METHOD': native(flow.request.method), 'SCRIPT_NAME': '', - 'PATH_INFO': urllib.unquote(path_info), + 'PATH_INFO': urllib.parse.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': flow.request.headers.get('Content-Type', ''), - 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', ''), + 'CONTENT_TYPE': native(flow.request.headers.get('Content-Type', '')), + 'CONTENT_LENGTH': native(flow.request.headers.get('Content-Length', '')), 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), - # FIXME: We need to pick up the protocol read from the request. - 'SERVER_PROTOCOL': "HTTP/1.1", + 'SERVER_PROTOCOL': native(flow.request.http_version), } environ.update(extra) if flow.client_conn.address: - environ["REMOTE_ADDR"], environ[ - "REMOTE_PORT"] = flow.client_conn.address() + environ["REMOTE_ADDR"] = native(flow.client_conn.address.host) + environ["REMOTE_PORT"] = flow.client_conn.address.port for key, value in flow.request.headers.items(): - key = 'HTTP_' + key.upper().replace('-', '_') + key = 'HTTP_' + native(key).upper().replace('-', '_') if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): environ[key] = value return environ @@ -99,7 +101,7 @@ class WSGIAdaptor(object):

Internal Server Error

%s"
- """.strip() % s + """.strip() % s.encode() if not headers_sent: soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") soc.write(b"Content-Type: text/html\r\n") @@ -117,7 +119,7 @@ class WSGIAdaptor(object): def write(data): if not state["headers_sent"]: - soc.write(b"HTTP/1.1 %s\r\n" % state["status"]) + soc.write(b"HTTP/1.1 %s\r\n" % state["status"].encode()) headers = state["headers"] if 'server' not in headers: headers["Server"] = self.sversion @@ -132,18 +134,17 @@ class WSGIAdaptor(object): def start_response(status, headers, exc_info=None): if exc_info: - try: - if state["headers_sent"]: - six.reraise(*exc_info) - finally: - exc_info = None + if state["headers_sent"]: + six.reraise(*exc_info) elif state["status"]: raise AssertionError('Response already started') state["status"] = status - state["headers"] = http.Headers(headers) - return write + state["headers"] = http.Headers([[always_bytes(k), always_bytes(v)] for k,v in headers]) + if exc_info: + self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2])) + state["headers_sent"] = True - errs = BytesIO() + errs = six.BytesIO() try: dataiter = self.app( self.make_environ(request, errs, **env), start_response @@ -155,7 +156,7 @@ class WSGIAdaptor(object): except Exception as e: try: s = traceback.format_exc() - errs.write(s) + errs.write(s.encode("utf-8", "replace")) self.error_page(soc, state["headers_sent"], s) except Exception: # pragma: no cover pass diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/http/test_models.py b/test/http/test_models.py index 6970a6e42..d420b22bc 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -58,20 +58,20 @@ class TestRequest(object): req = tutils.treq() req.headers["Accept-Encoding"] = "foobar" req.anticomp() - assert req.headers["Accept-Encoding"] == "identity" + assert req.headers["Accept-Encoding"] == b"identity" def test_constrain_encoding(self): req = tutils.treq() req.headers["Accept-Encoding"] = "identity, gzip, foo" req.constrain_encoding() - assert "foo" not in req.headers["Accept-Encoding"] + assert b"foo" not in req.headers["Accept-Encoding"] def test_update_host(self): req = tutils.treq() req.headers["Host"] = "" req.host = "foobar" req.update_host_header() - assert req.headers["Host"] == "foobar" + assert req.headers["Host"] == b"foobar" def test_get_form(self): req = tutils.treq() @@ -132,7 +132,7 @@ class TestRequest(object): def test_set_path_components(self): req = tutils.treq() - req.set_path_components(["foo", "bar"]) + req.set_path_components([b"foo", b"bar"]) # TODO: add meaningful assertions def test_get_query(self): @@ -140,7 +140,7 @@ class TestRequest(object): assert req.get_query().lst == [] req.url = "http://localhost:80/foo?bar=42" - assert req.get_query().lst == [("bar", "42")] + assert req.get_query().lst == [(b"bar", b"42")] def test_set_query(self): req = tutils.treq() @@ -167,12 +167,12 @@ class TestRequest(object): def test_pretty_url(self): req = tutils.treq() req.form_out = "authority" - assert req.pretty_url(True) == "address:22" - assert req.pretty_url(False) == "address:22" + assert req.pretty_url(True) == b"address:22" + assert req.pretty_url(False) == b"address:22" req.form_out = "relative" - assert req.pretty_url(True) == "http://address:22/path" - assert req.pretty_url(False) == "http://address:22/path" + assert req.pretty_url(True) == b"http://address:22/path" + assert req.pretty_url(False) == b"http://address:22/path" def test_get_cookies_none(self): headers = Headers() @@ -213,11 +213,11 @@ class TestRequest(object): def test_set_url(self): r = tutils.treq(form_in="absolute") - r.url = "https://otheraddress:42/ORLY" - assert r.scheme == "https" - assert r.host == "otheraddress" + r.url = b"https://otheraddress:42/ORLY" + assert r.scheme == b"https" + assert r.host == b"otheraddress" assert r.port == 42 - assert r.path == "/ORLY" + assert r.path == b"/ORLY" try: r.url = "//localhost:80/foo@bar" @@ -374,8 +374,8 @@ class TestResponse(object): def test_get_cookies_twocookies(self): resp = tutils.tresp() resp.headers = Headers([ - ["Set-Cookie", "cookiename=cookievalue"], - ["Set-Cookie", "othercookie=othervalue"] + [b"Set-Cookie", b"cookiename=cookievalue"], + [b"Set-Cookie", b"othercookie=othervalue"] ]) result = resp.get_cookies() assert len(result) == 2 @@ -399,8 +399,8 @@ class TestHeaders(object): def _2host(self): return Headers( [ - ["Host", "example.com"], - ["host", "example.org"] + [b"Host", b"example.com"], + [b"host", b"example.org"] ] ) @@ -408,37 +408,37 @@ class TestHeaders(object): headers = Headers() assert len(headers) == 0 - headers = Headers([["Host", "example.com"]]) + headers = Headers([[b"Host", b"example.com"]]) assert len(headers) == 1 - assert headers["Host"] == "example.com" + assert headers["Host"] == b"example.com" headers = Headers(Host="example.com") assert len(headers) == 1 - assert headers["Host"] == "example.com" + assert headers["Host"] == b"example.com" headers = Headers( - [["Host", "invalid"]], + [[b"Host", b"invalid"]], Host="example.com" ) assert len(headers) == 1 - assert headers["Host"] == "example.com" + assert headers["Host"] == b"example.com" headers = Headers( - [["Host", "invalid"], ["Accept", "text/plain"]], + [[b"Host", b"invalid"], [b"Accept", b"text/plain"]], Host="example.com" ) assert len(headers) == 2 - assert headers["Host"] == "example.com" - assert headers["Accept"] == "text/plain" + assert headers["Host"] == b"example.com" + assert headers["Accept"] == b"text/plain" def test_getitem(self): headers = Headers(Host="example.com") - assert headers["Host"] == "example.com" - assert headers["host"] == "example.com" + assert headers["Host"] == b"example.com" + assert headers["host"] == b"example.com" tutils.raises(KeyError, headers.__getitem__, "Accept") headers = self._2host() - assert headers["Host"] == "example.com, example.org" + assert headers["Host"] == b"example.com, example.org" def test_str(self): headers = Headers(Host="example.com") @@ -458,12 +458,12 @@ class TestHeaders(object): headers["Host"] = "example.com" assert "Host" in headers assert "host" in headers - assert headers["Host"] == "example.com" + assert headers["Host"] == b"example.com" headers["host"] = "example.org" assert "Host" in headers assert "host" in headers - assert headers["Host"] == "example.org" + assert headers["Host"] == b"example.org" headers["accept"] = "text/plain" assert len(headers) == 2 @@ -494,12 +494,10 @@ class TestHeaders(object): def test_keys(self): headers = Headers(Host="example.com") - assert len(headers.keys()) == 1 - assert headers.keys()[0] == "Host" + assert list(headers.keys()) == [b"Host"] headers = self._2host() - assert len(headers.keys()) == 1 - assert headers.keys()[0] == "Host" + assert list(headers.keys()) == [b"Host"] def test_eq_ne(self): headers1 = Headers(Host="example.com") @@ -516,7 +514,7 @@ class TestHeaders(object): def test_get_all(self): headers = self._2host() - assert headers.get_all("host") == ["example.com", "example.org"] + assert headers.get_all("host") == [b"example.com", b"example.org"] assert headers.get_all("accept") == [] def test_set_all(self): @@ -527,10 +525,10 @@ class TestHeaders(object): headers = self._2host() headers.set_all("Host", ["example.org"]) - assert headers["host"] == "example.org" + assert headers["host"] == b"example.org" headers.set_all("Host", ["example.org", "example.net"]) - assert headers["host"] == "example.org, example.net" + assert headers["host"] == b"example.org, example.net" def test_state(self): headers = self._2host() diff --git a/test/test_encoding.py b/test/test_encoding.py index 90f99338b..0ff1aad10 100644 --- a/test/test_encoding.py +++ b/test/test_encoding.py @@ -4,8 +4,6 @@ from netlib import encoding def test_identity(): assert b"string" == encoding.decode("identity", b"string") assert b"string" == encoding.encode("identity", b"string") - assert b"string" == encoding.encode(b"identity", b"string") - assert b"string" == encoding.decode(b"identity", b"string") assert not encoding.encode("nonexistent", b"string") assert not encoding.decode("nonexistent encoding", b"string") diff --git a/test/test_wsgi.py b/test/test_wsgi.py index 856967aff..fe6f09b52 100644 --- a/test/test_wsgi.py +++ b/test/test_wsgi.py @@ -5,8 +5,8 @@ from netlib.http import Headers def tflow(): - headers = Headers(test="value") - req = wsgi.Request("http", "GET", "/", headers, "") + headers = Headers(test=b"value") + req = wsgi.Request("http", "GET", "/", "HTTP/1.1", headers, "") return wsgi.Flow(("127.0.0.1", 8888), req) @@ -20,7 +20,7 @@ class TestApp: status = '200 OK' response_headers = [('Content-type', 'text/plain')] start_response(status, response_headers) - return ['Hello', ' world!\n'] + return [b'Hello', b' world!\n'] class TestWSGI: @@ -47,8 +47,8 @@ class TestWSGI: assert not err val = wfile.getvalue() - assert "Hello world" in val - assert "Server:" in val + assert b"Hello world" in val + assert b"Server:" in val def _serve(self, app): w = wsgi.WSGIAdaptor(app, "foo", 80, "version") @@ -77,7 +77,7 @@ class TestWSGI: response_headers = [('Content-type', 'text/plain')] start_response(status, response_headers) start_response(status, response_headers) - assert "Internal Server Error" in self._serve(app) + assert b"Internal Server Error" in self._serve(app) def test_serve_single_err(self): def app(environ, start_response): @@ -88,7 +88,8 @@ class TestWSGI: status = '200 OK' response_headers = [('Content-type', 'text/plain')] start_response(status, response_headers, ei) - assert "Internal Server Error" in self._serve(app) + yield b"" + assert b"Internal Server Error" in self._serve(app) def test_serve_double_err(self): def app(environ, start_response): @@ -99,7 +100,7 @@ class TestWSGI: status = '200 OK' response_headers = [('Content-type', 'text/plain')] start_response(status, response_headers) - yield "aaa" + yield b"aaa" start_response(status, response_headers, ei) - yield "bbb" - assert "Internal Server Error" in self._serve(app) + yield b"bbb" + assert b"Internal Server Error" in self._serve(app) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 3af5dc9c2..6f67b84d8 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -41,7 +41,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): key = self.protocol.check_client_handshake(req.headers) preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble + "\r\n") + self.wfile.write(preamble.encode() + b"\r\n") headers = self.protocol.server_handshake_headers(key) self.wfile.write(str(headers) + "\r\n") self.wfile.flush() @@ -62,11 +62,11 @@ class WebSocketsClient(tcp.TCPClient): def connect(self): super(WebSocketsClient, self).connect() - preamble = 'GET / HTTP/1.1' - self.wfile.write(preamble + "\r\n") + preamble = b'GET / HTTP/1.1' + self.wfile.write(preamble + b"\r\n") headers = self.protocol.client_handshake_headers() self.client_nonce = headers["sec-websocket-key"] - self.wfile.write(str(headers) + "\r\n") + self.wfile.write(bytes(headers) + b"\r\n") self.wfile.flush() resp = read_response(self.rfile, treq(method="GET")) @@ -101,7 +101,7 @@ class TestWebSockets(tservers.ServerTestBase): assert response == msg def test_simple_echo(self): - self.echo("hello I'm the client") + self.echo(b"hello I'm the client") def test_frame_sizes(self): # length can fit in the the 7 bit payload length @@ -161,10 +161,10 @@ class BadHandshakeHandler(WebSocketsEchoHandler): client_hs = read_request(self.rfile) self.protocol.check_client_handshake(client_hs.headers) - preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble + "\r\n") - headers = self.protocol.server_handshake_headers("malformed key") - self.wfile.write(str(headers) + "\r\n") + preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101) + self.wfile.write(preamble.encode()) + headers = self.protocol.server_handshake_headers(b"malformed key") + self.wfile.write(bytes(headers) + b"\r\n") self.wfile.flush() self.handshake_done = True @@ -180,7 +180,7 @@ class TestBadHandshake(tservers.ServerTestBase): def test(self): client = WebSocketsClient(("127.0.0.1", self.port)) client.connect() - client.send_message("hello") + client.send_message(b"hello") class TestFrameHeader: @@ -188,8 +188,7 @@ class TestFrameHeader: def test_roundtrip(self): def round(*args, **kwargs): f = websockets.FrameHeader(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f))) assert f == f2 round() round(fin=1) @@ -201,11 +200,11 @@ class TestFrameHeader: round(payload_length=1000) round(payload_length=10000) round(opcode=websockets.OPCODE.PING) - round(masking_key="test") + round(masking_key=b"test") def test_human_readable(self): f = websockets.FrameHeader( - masking_key="test", + masking_key=b"test", fin=True, payload_length=10 ) @@ -214,23 +213,23 @@ class TestFrameHeader: assert f.human_readable() def test_funky(self): - f = websockets.FrameHeader(masking_key="test", mask=False) + f = websockets.FrameHeader(masking_key=b"test", mask=False) bytes = f.to_bytes() f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) assert not f2.mask def test_violations(self): tutils.raises("opcode", websockets.FrameHeader, opcode=17) - tutils.raises("masking key", websockets.FrameHeader, masking_key="x") + tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") def test_automask(self): f = websockets.FrameHeader(mask=True) assert f.masking_key - f = websockets.FrameHeader(masking_key="foob") + f = websockets.FrameHeader(masking_key=b"foob") assert f.mask - f = websockets.FrameHeader(masking_key="foob", mask=0) + f = websockets.FrameHeader(masking_key=b"foob", mask=0) assert not f.mask assert f.masking_key @@ -240,31 +239,31 @@ class TestFrame: def test_roundtrip(self): def round(*args, **kwargs): f = websockets.Frame(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.Frame.from_file(tutils.treader(bytes)) + raw = bytes(f) + f2 = websockets.Frame.from_file(tutils.treader(raw)) assert f == f2 - round("test") - round("test", fin=1) - round("test", rsv1=1) - round("test", opcode=websockets.OPCODE.PING) - round("test", masking_key="test") + round(b"test") + round(b"test", fin=1) + round(b"test", rsv1=1) + round(b"test", opcode=websockets.OPCODE.PING) + round(b"test", masking_key=b"test") def test_human_readable(self): f = websockets.Frame() - assert f.human_readable() + assert repr(f) def test_masker(): tests = [ - ["a"], - ["four"], - ["fourf"], - ["fourfive"], - ["a", "aasdfasdfa", "asdf"], - ["a" * 50, "aasdfasdfa", "asdf"], + [b"a"], + [b"four"], + [b"fourf"], + [b"fourfive"], + [b"a", b"aasdfasdfa", b"asdf"], + [b"a" * 50, b"aasdfasdfa", b"asdf"], ] for i in tests: - m = websockets.Masker("abcd") - data = "".join([m(t) for t in i]) - data2 = websockets.Masker("abcd")(data) - assert data2 == "".join(i) + m = websockets.Masker(b"abcd") + data = b"".join([m(t) for t in i]) + data2 = websockets.Masker(b"abcd")(data) + assert data2 == b"".join(i)