diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 0b1a0bc5f..9303de09c 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,7 +1,9 @@ -from .models import Request, Response, Headers, CONTENT_MISSING +from .models import Request, Response, Headers +from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ - "Request", "Response", "Headers", "CONTENT_MISSING" + "Request", "Response", "Headers", + "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", "http1", "http2" ] diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index 4d223f975..a72c2e05c 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -1,7 +1,7 @@ from .read import ( read_request, read_request_head, read_response, read_response_head, - read_message_body, read_message_body_chunked, + read_body, connection_close, expected_http_body_size, ) @@ -14,7 +14,7 @@ from .assemble import ( __all__ = [ "read_request", "read_request_head", "read_response", "read_response_head", - "read_message_body", "read_message_body_chunked", + "read_body", "connection_close", "expected_http_body_size", "assemble_request", "assemble_request_head", diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index a3269eed0..47c7e95a3 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -31,8 +31,6 @@ def assemble_response_head(response): return b"%s\r\n%s\r\n" % (first_line, headers) - - def _assemble_request_line(request, form=None): if form is None: form = request.form_out @@ -50,7 +48,7 @@ def _assemble_request_line(request, form=None): request.httpversion ) elif form == "absolute": - return b"%s %s://%s:%s%s %s" % ( + return b"%s %s://%s:%d%s %s" % ( request.method, request.scheme, request.host, @@ -78,11 +76,11 @@ def _assemble_request_headers(request): if request.body or request.body == b"": headers[b"Content-Length"] = str(len(request.body)).encode("ascii") - return str(headers) + return bytes(headers) def _assemble_response_line(response): - return b"%s %s %s" % ( + return b"%s %d %s" % ( response.httpversion, response.status_code, response.msg, diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 573bc7399..4c423c4c3 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -7,12 +7,13 @@ from ... import utils from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException from .. import Request, Response, Headers -ALPN_PROTO_HTTP1 = 'http/1.1' +ALPN_PROTO_HTTP1 = b'http/1.1' def read_request(rfile, body_size_limit=None): request = read_request_head(rfile) - request.body = read_message_body(rfile, request, limit=body_size_limit) + expected_body_size = expected_http_body_size(request) + request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) request.timestamp_end = time.time() return request @@ -23,15 +24,14 @@ def read_request_head(rfile): Args: rfile: The input stream - body_size_limit (bool): Maximum body size Returns: - The HTTP request object + The HTTP request object (without body) Raises: - HttpReadDisconnect: If no bytes can be read from rfile. - HttpSyntaxException: If the input is invalid. - HttpException: A different error occured. + HttpReadDisconnect: No bytes can be read from rfile. + HttpSyntaxException: The input is malformed HTTP. + HttpException: Any other error occured. """ timestamp_start = time.time() if hasattr(rfile, "reset_timestamps"): @@ -51,12 +51,28 @@ def read_request_head(rfile): def read_response(rfile, request, body_size_limit=None): response = read_response_head(rfile) - response.body = read_message_body(rfile, request, response, body_size_limit) + expected_body_size = expected_http_body_size(request, response) + response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) response.timestamp_end = time.time() return response def read_response_head(rfile): + """ + Parse an HTTP response head (response line + headers) from an input stream + + Args: + rfile: The input stream + + Returns: + The HTTP request object (without body) + + Raises: + HttpReadDisconnect: No bytes can be read from rfile. + HttpSyntaxException: The input is malformed HTTP. + HttpException: Any other error occured. + """ + timestamp_start = time.time() if hasattr(rfile, "reset_timestamps"): rfile.reset_timestamps() @@ -68,50 +84,33 @@ def read_response_head(rfile): # more accurate timestamp_start timestamp_start = rfile.first_byte_timestamp - return Response( - http_version, - status_code, - message, - headers, - None, - timestamp_start - ) + return Response(http_version, status_code, message, headers, None, timestamp_start) -def read_message_body(*args, **kwargs): - chunks = read_message_body_chunked(*args, **kwargs) - return b"".join(chunks) - - -def read_message_body_chunked(rfile, request, response=None, limit=None, max_chunk_size=None): +def read_body(rfile, expected_size, limit=None, max_chunk_size=4096): """ - Read an HTTP message body: + Read an HTTP message body Args: - If a request body should be read, only request should be passed. - If a response body should be read, both request and response should be passed. + rfile: The input stream + expected_size: The expected body size (see :py:meth:`expected_body_size`) + limit: Maximum body size + max_chunk_size: Maximium chunk size that gets yielded + + Returns: + A generator that yields byte chunks of the content. Raises: - HttpException - """ - if not response: - headers = request.headers - response_code = None - is_request = True - else: - headers = response.headers - response_code = response.status_code - is_request = False + HttpException, if an error occurs + Caveats: + max_chunk_size is not considered if the transfer encoding is chunked. + """ if not limit or limit < 0: limit = sys.maxsize if not max_chunk_size: max_chunk_size = limit - expected_size = expected_http_body_size( - headers, is_request, request.method, response_code - ) - if expected_size is None: for x in _read_chunked(rfile, limit): yield x @@ -125,6 +124,8 @@ def read_message_body_chunked(rfile, request, response=None, limit=None, max_chu while bytes_left: chunk_size = min(bytes_left, max_chunk_size) content = rfile.read(chunk_size) + if len(content) < chunk_size: + raise HttpException("Unexpected EOF") yield content bytes_left -= chunk_size else: @@ -148,10 +149,10 @@ def connection_close(http_version, headers): """ # At first, check if we have an explicit Connection header. if b"connection" in headers: - toks = utils.get_header_tokens(headers, "connection") - if b"close" in toks: + tokens = utils.get_header_tokens(headers, "connection") + if b"close" in tokens: return True - elif b"keep-alive" in toks: + elif b"keep-alive" in tokens: return False # If we don't have a Connection header, HTTP 1.1 connections are assumed to @@ -159,37 +160,41 @@ def connection_close(http_version, headers): return http_version != (1, 1) -def expected_http_body_size( - headers, - is_request, - request_method, - response_code, -): +def expected_http_body_size(request, response=False): """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding) - - -1, if all data should be read until end of stream. + Returns: + The expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding) + - -1, if all data should be read until end of stream. Raises: HttpSyntaxException, if the content length header is invalid """ # Determine response size according to # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() + if not response: + headers = request.headers + response_code = None + is_request = True + else: + headers = response.headers + response_code = response.status_code + is_request = False - is_empty_response = (not is_request and ( - request_method == b"HEAD" or - 100 <= response_code <= 199 or - (response_code == 200 and request_method == b"CONNECT") or - response_code in (204, 304) - )) + if is_request: + if headers.get(b"expect", b"").lower() == b"100-continue": + return 0 + else: + if request.method.upper() == b"HEAD": + return 0 + if 100 <= response_code <= 199: + return 0 + if response_code == 200 and request.method.upper() == b"CONNECT": + return 0 + if response_code in (204, 304): + return 0 - if is_empty_response: - return 0 - if is_request and headers.get(b"expect", b"").lower() == b"100-continue": - return 0 if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): return None if b"content-length" in headers: @@ -212,18 +217,22 @@ def _get_first_line(rfile): line = rfile.readline() if not line: raise HttpReadDisconnect() - return line + line = line.strip() + try: + line.decode("ascii") + except ValueError: + raise HttpSyntaxException("Non-ascii characters in first line: {}".format(line)) + return line.strip() def _read_request_line(rfile): line = _get_first_line(rfile) try: - method, path, http_version = line.strip().split(b" ") + method, path, http_version = line.split(b" ") if path == b"*" or path.startswith(b"/"): form = "relative" - path.decode("ascii") # should not raise a ValueError scheme, host, port = None, None, None elif method == b"CONNECT": form = "authority" @@ -233,6 +242,7 @@ def _read_request_line(rfile): form = "absolute" scheme, host, port, path = utils.parse_url(path) + _check_http_version(http_version) except ValueError: raise HttpSyntaxException("Bad HTTP request line: {}".format(line)) @@ -253,7 +263,7 @@ def _parse_authority_form(hostport): if not utils.is_valid_host(host) or not utils.is_valid_port(port): raise ValueError() except ValueError: - raise ValueError("Invalid host specification: {}".format(hostport)) + raise HttpSyntaxException("Invalid host specification: {}".format(hostport)) return host, port @@ -263,7 +273,7 @@ def _read_response_line(rfile): try: - parts = line.strip().split(b" ") + parts = line.split(b" ", 2) if len(parts) == 2: # handle missing message gracefully parts.append(b"") @@ -278,7 +288,7 @@ def _read_response_line(rfile): def _check_http_version(http_version): - if not re.match(rb"^HTTP/\d\.\d$", http_version): + if not re.match(br"^HTTP/\d\.\d$", http_version): raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) @@ -313,7 +323,7 @@ def _read_headers(rfile): return Headers(ret) -def _read_chunked(rfile, limit): +def _read_chunked(rfile, limit=sys.maxsize): """ Read a HTTP body with chunked transfer encoding. diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index b6d376d33..036bf68ff 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -4,7 +4,7 @@ import time from hpack.hpack import Encoder, Decoder from netlib import http, utils -from netlib.http import semantics +from netlib.http import models as semantics from . import frame @@ -15,7 +15,7 @@ class TCPHandler(object): self.wfile = wfile -class HTTP2Protocol(semantics.ProtocolMixin): +class HTTP2Protocol(object): ERROR_CODES = utils.BiDi( NO_ERROR=0x0, diff --git a/netlib/http/http2/frames.py b/netlib/http/http2/frame.py similarity index 95% rename from netlib/http/http2/frames.py rename to netlib/http/http2/frame.py index b36b3adf3..cb2cde994 100644 --- a/netlib/http/http2/frames.py +++ b/netlib/http/http2/frame.py @@ -1,12 +1,31 @@ -import sys +from __future__ import absolute_import, print_function, division import struct from hpack.hpack import Encoder, Decoder -from .. import utils +from ...utils import BiDi +from ...exceptions import HttpSyntaxException -class FrameSizeError(Exception): - pass +ERROR_CODES = BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd +) + +CLIENT_CONNECTION_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + +ALPN_PROTO_H2 = b'h2' class Frame(object): @@ -30,7 +49,9 @@ class Frame(object): length=0, flags=FLAG_NO_FLAGS, stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + valid_flags = 0 + for flag in self.VALID_FLAGS: + valid_flags |= flag if flags | valid_flags != valid_flags: raise ValueError('invalid flags detected.') @@ -61,7 +82,7 @@ class Frame(object): SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] if length > max_frame_size: - raise FrameSizeError( + raise HttpSyntaxException( "Frame size exceeded: %d, but only %d allowed." % ( length, max_frame_size)) @@ -80,7 +101,7 @@ class Frame(object): stream_id = fields[4] if raw_header[:4] == b'HTTP': # pragma no cover - print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" + raise HttpSyntaxException("Expected HTTP2 Frame, got HTTP/1 connection") cls._check_frame_size(length, state) @@ -339,7 +360,7 @@ class SettingsFrame(Frame): TYPE = 0x4 VALID_FLAGS = [Frame.FLAG_ACK] - SETTINGS = utils.BiDi( + SETTINGS = BiDi( SETTINGS_HEADER_TABLE_SIZE=0x1, SETTINGS_ENABLE_PUSH=0x2, SETTINGS_MAX_CONCURRENT_STREAMS=0x3, @@ -366,7 +387,7 @@ class SettingsFrame(Frame): def from_bytes(cls, state, length, flags, stream_id, payload): f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - for i in xrange(0, len(payload), 6): + for i in range(0, len(payload), 6): identifier, value = struct.unpack("!HL", payload[i:i + 6]) f.settings[identifier] = value diff --git a/netlib/http/models.py b/netlib/http/models.py index bd5863b1c..572d66c95 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -474,7 +474,6 @@ class Response(object): msg=None, headers=None, body=None, - sslinfo=None, timestamp_start=None, timestamp_end=None, ): @@ -487,7 +486,6 @@ class Response(object): self.msg = msg self.headers = headers self.body = body - self.sslinfo = sslinfo self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end diff --git a/netlib/tutils.py b/netlib/tutils.py index 65c4a3138..758f84107 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -7,13 +7,15 @@ from contextlib import contextmanager import six import sys -from netlib import tcp, utils, http +from . import utils +from .http import Request, Response, Headers def treader(bytes): """ Construct a tcp.Read object from bytes. """ + from . import tcp # TODO: move to top once cryptography is on Python 3.5 fp = BytesIO(bytes) return tcp.Reader(fp) @@ -91,55 +93,39 @@ class RaisesContext(object): test_data = utils.Data(__name__) -def treq(content="content", scheme="http", host="address", port=22): +def treq(**kwargs): """ - @return: libmproxy.protocol.http.HTTPRequest + Returns: + netlib.http.Request """ - headers = http.Headers() - headers["header"] = "qvalue" - req = http.Request( - "relative", - "GET", - scheme, - host, - port, - "/path", - (1, 1), - headers, - content, - None, - None, + default = dict( + form_in="relative", + method=b"GET", + scheme=b"http", + host=b"address", + port=22, + path=b"/path", + httpversion=b"HTTP/1.1", + headers=Headers(header=b"qvalue"), + body=b"content" ) - return req + default.update(kwargs) + return Request(**default) -def treq_absolute(content="content"): +def tresp(**kwargs): """ - @return: libmproxy.protocol.http.HTTPRequest + Returns: + netlib.http.Response """ - r = treq(content) - r.form_in = r.form_out = "absolute" - r.host = "address" - r.port = 22 - r.scheme = "http" - return r - - -def tresp(content="message"): - """ - @return: libmproxy.protocol.http.HTTPResponse - """ - - headers = http.Headers() - headers["header_response"] = "svalue" - - resp = http.semantics.Response( - (1, 1), - 200, - "OK", - headers, - content, + default = dict( + httpversion=b"HTTP/1.1", + status_code=200, + msg=b"OK", + headers=Headers(header_response=b"svalue"), + body=b"message", timestamp_start=time.time(), - timestamp_end=time.time(), + timestamp_end=time.time() ) - return resp + default.update(kwargs) + return Response(**default) diff --git a/netlib/utils.py b/netlib/utils.py index fb579cac6..a86b80195 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -40,9 +40,9 @@ def clean_bin(s, keep_spacing=True): ) else: if keep_spacing: - keep = b"\n\r\t" + keep = (9, 10, 13) # \t, \n, \r, else: - keep = b"" + keep = () return b"".join( six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." for ch in six.iterbytes(s) @@ -251,7 +251,7 @@ def hostport(scheme, host, port): if (port, scheme) in [(80, "http"), (443, "https")]: return host else: - return b"%s:%s" % (host, port) + return b"%s:%d" % (host, port) def unparse_url(scheme, host, port, path=""): diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py new file mode 100644 index 000000000..8a0a54f16 --- /dev/null +++ b/test/http/http1/test_assemble.py @@ -0,0 +1,91 @@ +from __future__ import absolute_import, print_function, division +from netlib.exceptions import HttpException +from netlib.http import CONTENT_MISSING, Headers +from netlib.http.http1.assemble import ( + assemble_request, assemble_request_head, assemble_response, + assemble_response_head, _assemble_request_line, _assemble_request_headers, + _assemble_response_headers +) +from netlib.tutils import treq, raises, tresp + + +def test_assemble_request(): + c = assemble_request(treq()) == ( + b"GET /path HTTP/1.1\r\n" + b"header: qvalue\r\n" + b"Host: address:22\r\n" + b"Content-Length: 7\r\n" + b"\r\n" + b"content" + ) + + with raises(HttpException): + assemble_request(treq(body=CONTENT_MISSING)) + + +def test_assemble_request_head(): + c = assemble_request_head(treq()) + assert b"GET" in c + assert b"qvalue" in c + assert b"content" not in c + + +def test_assemble_response(): + c = assemble_response(tresp()) == ( + b"HTTP/1.1 200 OK\r\n" + b"header-response: svalue\r\n" + b"Content-Length: 7\r\n" + b"\r\n" + b"message" + ) + + with raises(HttpException): + assemble_response(tresp(body=CONTENT_MISSING)) + + +def test_assemble_response_head(): + c = assemble_response_head(tresp()) + assert b"200" in c + assert b"svalue" in c + assert b"message" not in c + + +def test_assemble_request_line(): + assert _assemble_request_line(treq()) == b"GET /path HTTP/1.1" + + authority_request = treq(method=b"CONNECT", form_in="authority") + assert _assemble_request_line(authority_request) == b"CONNECT address:22 HTTP/1.1" + + absolute_request = treq(form_in="absolute") + assert _assemble_request_line(absolute_request) == b"GET http://address:22/path HTTP/1.1" + + with raises(RuntimeError): + _assemble_request_line(treq(), "invalid_form") + + +def test_assemble_request_headers(): + # https://github.com/mitmproxy/mitmproxy/issues/186 + r = treq(body=b"") + r.headers[b"Transfer-Encoding"] = b"chunked" + c = _assemble_request_headers(r) + assert b"Content-Length" in c + assert b"Transfer-Encoding" not in c + + assert b"Host" in _assemble_request_headers(treq(headers=Headers())) + + assert b"Proxy-Connection" not in _assemble_request_headers( + treq(headers=Headers(Proxy_Connection="42")) + ) + + +def test_assemble_response_headers(): + # https://github.com/mitmproxy/mitmproxy/issues/186 + r = tresp(body=b"") + r.headers["Transfer-Encoding"] = b"chunked" + c = _assemble_response_headers(r) + assert b"Content-Length" in c + assert b"Transfer-Encoding" not in c + + assert b"Proxy-Connection" not in _assemble_response_headers( + tresp(headers=Headers(Proxy_Connection=b"42")) + ) diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index bdcba5cbb..e69de29bb 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -1,466 +0,0 @@ -from io import BytesIO -import textwrap -from http.http1.protocol import _parse_authority_form -from netlib.exceptions import HttpSyntaxException, HttpReadDisconnect, HttpException - -from netlib import http, tcp, tutils -from netlib.http import semantics, Headers -from netlib.http.http1 import HTTP1Protocol, read_message_body, read_request, \ - read_message_body_chunked, expected_http_body_size -from ... import tservers - - -class NoContentLengthHTTPHandler(tcp.BaseHandler): - def handle(self): - self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") - self.wfile.flush() - - -def mock_protocol(data=''): - rfile = BytesIO(data) - wfile = BytesIO() - return HTTP1Protocol(rfile=rfile, wfile=wfile) - - -def match_http_string(data): - return textwrap.dedent(data).strip().replace('\n', '\r\n') - - -def test_stripped_chunked_encoding_no_content(): - """ - https://github.com/mitmproxy/mitmproxy/issues/186 - """ - - r = tutils.treq(content="") - r.headers["Transfer-Encoding"] = "chunked" - assert "Content-Length" in mock_protocol()._assemble_request_headers(r) - - r = tutils.tresp(content="") - r.headers["Transfer-Encoding"] = "chunked" - assert "Content-Length" in mock_protocol()._assemble_response_headers(r) - - -def test_read_chunked(): - req = tutils.treq(None) - req.headers["Transfer-Encoding"] = "chunked" - - data = b"1\r\na\r\n0\r\n" - with tutils.raises(HttpSyntaxException): - read_message_body(BytesIO(data), req) - - data = b"1\r\na\r\n0\r\n\r\n" - assert read_message_body(BytesIO(data), req) == b"a" - - data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n" - assert read_message_body(BytesIO(data), req) == b"ab" - - data = b"\r\n" - with tutils.raises("closed prematurely"): - read_message_body(BytesIO(data), req) - - data = b"1\r\nfoo" - with tutils.raises("malformed chunked body"): - read_message_body(BytesIO(data), req) - - data = b"foo\r\nfoo" - with tutils.raises(HttpSyntaxException): - read_message_body(BytesIO(data), req) - - data = b"5\r\naaaaa\r\n0\r\n\r\n" - with tutils.raises("too large"): - read_message_body(BytesIO(data), req, limit=2) - - -def test_connection_close(): - headers = Headers() - assert HTTP1Protocol.connection_close((1, 0), headers) - assert not HTTP1Protocol.connection_close((1, 1), headers) - - headers["connection"] = "keep-alive" - assert not HTTP1Protocol.connection_close((1, 1), headers) - - headers["connection"] = "close" - assert HTTP1Protocol.connection_close((1, 1), headers) - - -def test_read_http_body_request(): - headers = Headers() - data = "testing" - assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "" - - -def test_read_http_body_response(): - headers = Headers() - data = "testing" - assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing" - - -def test_read_http_body(): - # test default case - headers = Headers() - headers["content-length"] = "7" - data = "testing" - assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing" - - # test content length: invalid header - headers["content-length"] = "foo" - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, None, "GET", 200, False - ) - - # test content length: invalid header #2 - headers["content-length"] = "-1" - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, None, "GET", 200, False - ) - - # test content length: content length > actual content - headers["content-length"] = "5" - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, 4, "GET", 200, False - ) - - # test content length: content length < actual content - data = "testing" - assert len(mock_protocol(data).read_http_body(headers, None, "GET", 200, False)) == 5 - - # test no content length: limit > actual content - headers = Headers() - data = "testing" - assert len(mock_protocol(data).read_http_body(headers, 100, "GET", 200, False)) == 7 - - # test no content length: limit < actual content - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, 4, "GET", 200, False - ) - - # test chunked - headers = Headers() - headers["transfer-encoding"] = "chunked" - data = "5\r\naaaaa\r\n0\r\n\r\n" - assert mock_protocol(data).read_http_body(headers, 100, "GET", 200, False) == "aaaaa" - - -def test_expected_http_body_size(): - # gibber in the content-length field - headers = Headers(content_length="foo") - with tutils.raises(HttpSyntaxException): - expected_http_body_size(headers, False, "GET", 200) is None - # negative number in the content-length field - headers = Headers(content_length="-7") - with tutils.raises(HttpSyntaxException): - expected_http_body_size(headers, False, "GET", 200) is None - # explicit length - headers = Headers(content_length="5") - assert expected_http_body_size(headers, False, "GET", 200) == 5 - # no length - headers = Headers() - assert expected_http_body_size(headers, False, "GET", 200) == -1 - # no length request - headers = Headers() - assert expected_http_body_size(headers, True, "GET", None) == 0 - # expect header - headers = Headers(content_length="5", expect="100-continue") - assert expected_http_body_size(headers, True, "GET", None) == 0 - - -def test_parse_init_connect(): - assert _parse_authority_form(b"CONNECT host.com:443 HTTP/1.0") - tutils.raises(ValueError,_parse_authority_form, b"\0host.com:443") - tutils.raises(ValueError,_parse_authority_form, b"host.com:444444") - tutils.raises(ValueError,_parse_authority_form, b"CONNECT host.com443 HTTP/1.0") - tutils.raises(ValueError,_parse_authority_form, b"CONNECT host.com:foo HTTP/1.0") - - -def test_parse_init_proxy(): - u = b"GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = HTTP1Protocol._parse_absolute_form(u) - assert m == "GET" - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not HTTP1Protocol._parse_absolute_form(u) - - with tutils.raises(ValueError): - assert not HTTP1Protocol._parse_absolute_form("invalid") - with tutils.raises(ValueError): - assert not HTTP1Protocol._parse_absolute_form("GET invalid HTTP/1.1") - with tutils.raises(ValueError): - assert not HTTP1Protocol._parse_absolute_form("GET http://foo.com:8888/test foo/1.1") - - -def test_parse_init_http(): - u = "GET /test HTTP/1.1" - m, u, httpversion = HTTP1Protocol._parse_init_http(u) - assert m == "GET" - assert u == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET /test HTTP/1.1" - assert not HTTP1Protocol._parse_init_http(u) - - assert not HTTP1Protocol._parse_init_http("invalid") - assert not HTTP1Protocol._parse_init_http("GET invalid HTTP/1.1") - assert not HTTP1Protocol._parse_init_http("GET /test foo/1.1") - assert not HTTP1Protocol._parse_init_http("GET /test\xc0 HTTP/1.1") - - -class TestReadHeaders: - - def _read(self, data, verbatim=False): - if not verbatim: - data = textwrap.dedent(data) - data = data.strip() - return mock_protocol(data).read_headers() - - def test_read_simple(self): - data = """ - Header: one - Header2: two - \r\n - """ - headers = self._read(data) - assert headers.fields == [["Header", "one"], ["Header2", "two"]] - - def test_read_multi(self): - data = """ - Header: one - Header: two - \r\n - """ - headers = self._read(data) - assert headers.fields == [["Header", "one"], ["Header", "two"]] - - def test_read_continued(self): - data = """ - Header: one - \ttwo - Header2: three - \r\n - """ - headers = self._read(data) - assert headers.fields == [["Header", "one\r\n two"], ["Header2", "three"]] - - def test_read_continued_err(self): - data = "\tfoo: bar\r\n" - assert self._read(data, True) is None - - def test_read_err(self): - data = """ - foo - """ - assert self._read(data) is None - - -class TestReadRequest(object): - - def tst(self, data, **kwargs): - return mock_protocol(data).read_request(**kwargs) - - def test_invalid(self): - tutils.raises( - "bad http request", - self.tst, - "xxx" - ) - tutils.raises( - "bad http request line", - self.tst, - "get /\xff HTTP/1.1" - ) - tutils.raises( - "invalid headers", - self.tst, - "get / HTTP/1.1\r\nfoo" - ) - tutils.raises( - HttpReadDisconnect, - self.tst, - "\r\n" - ) - - def test_asterisk_form_in(self): - v = self.tst("OPTIONS * HTTP/1.1") - assert v.form_in == "relative" - assert v.method == "OPTIONS" - - def test_absolute_form_in(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "GET oops-no-protocol.com HTTP/1.1" - ) - v = self.tst("GET http://address:22/ HTTP/1.1") - assert v.form_in == "absolute" - assert v.port == 22 - assert v.host == "address" - assert v.scheme == "http" - - def test_connect(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "CONNECT oops-no-port.com HTTP/1.1" - ) - v = self.tst("CONNECT foo.com:443 HTTP/1.1") - assert v.form_in == "authority" - assert v.method == "CONNECT" - assert v.port == 443 - assert v.host == "foo.com" - - def test_expect(self): - data = ( - b"GET / HTTP/1.1\r\n" - b"Content-Length: 3\r\n" - b"Expect: 100-continue\r\n" - b"\r\n" - b"foobar" - ) - - rfile = BytesIO(data) - r = read_request(rfile) - assert r.body == b"" - assert rfile.read(-1) == b"foobar" - - -class TestReadResponse(object): - def tst(self, data, method, body_size_limit, include_body=True): - data = textwrap.dedent(data) - return mock_protocol(data).read_response( - method, body_size_limit, include_body=include_body - ) - - def test_errors(self): - tutils.raises("server disconnect", self.tst, "", "GET", None) - tutils.raises("invalid server response", self.tst, "foo", "GET", None) - - def test_simple(self): - data = """ - HTTP/1.1 200 - """ - assert self.tst(data, "GET", None) == http.Response( - (1, 1), 200, '', Headers(), '' - ) - - def test_simple_message(self): - data = """ - HTTP/1.1 200 OK - """ - assert self.tst(data, "GET", None) == http.Response( - (1, 1), 200, 'OK', Headers(), '' - ) - - def test_invalid_http_version(self): - data = """ - HTTP/x 200 OK - """ - tutils.raises("invalid http version", self.tst, data, "GET", None) - - def test_invalid_status_code(self): - data = """ - HTTP/1.1 xx OK - """ - tutils.raises("invalid server response", self.tst, data, "GET", None) - - def test_valid_with_continue(self): - data = """ - HTTP/1.1 100 CONTINUE - - HTTP/1.1 200 OK - """ - assert self.tst(data, "GET", None) == http.Response( - (1, 1), 100, 'CONTINUE', Headers(), '' - ) - - def test_simple_body(self): - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert self.tst(data, "GET", None).body == 'foo' - assert self.tst(data, "HEAD", None).body == '' - - def test_invalid_headers(self): - data = """ - HTTP/1.1 200 OK - \tContent-Length: 3 - - foo - """ - tutils.raises("invalid headers", self.tst, data, "GET", None) - - def test_without_body(self): - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert self.tst(data, "GET", None, include_body=False).body is None - - -class TestReadResponseNoContentLength(tservers.ServerTestBase): - handler = NoContentLengthHTTPHandler - - def test_no_content_length(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - resp = HTTP1Protocol(c).read_response("GET", None) - assert resp.body == "bar\r\n\r\n" - - -class TestAssembleRequest(object): - def test_simple(self): - req = tutils.treq() - b = HTTP1Protocol().assemble_request(req) - assert b == match_http_string(""" - GET /path HTTP/1.1 - header: qvalue - Host: address:22 - Content-Length: 7 - - content""") - - def test_body_missing(self): - req = tutils.treq(content=semantics.CONTENT_MISSING) - tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req) - - def test_not_a_request(self): - tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo') - - -class TestAssembleResponse(object): - def test_simple(self): - resp = tutils.tresp() - b = HTTP1Protocol().assemble_response(resp) - assert b == match_http_string(""" - HTTP/1.1 200 OK - header_response: svalue - Content-Length: 7 - - message""") - - def test_body_missing(self): - resp = tutils.tresp(content=semantics.CONTENT_MISSING) - tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp) - - def test_not_a_request(self): - tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo') diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py new file mode 100644 index 000000000..5e6680afa --- /dev/null +++ b/test/http/http1/test_read.py @@ -0,0 +1,313 @@ +from __future__ import absolute_import, print_function, division +from io import BytesIO +import textwrap + +from mock import Mock + +from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect +from netlib.http import Headers +from netlib.http.http1.read import ( + read_request, read_response, read_request_head, + read_response_head, read_body, connection_close, expected_http_body_size, _get_first_line, + _read_request_line, _parse_authority_form, _read_response_line, _check_http_version, + _read_headers, _read_chunked +) +from netlib.tutils import treq, tresp, raises + + +def test_read_request(): + rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip") + r = read_request(rfile) + assert r.method == b"GET" + assert r.body == b"" + assert r.timestamp_end + assert rfile.read() == b"skip" + + +def test_read_request_head(): + rfile = BytesIO( + b"GET / HTTP/1.1\r\n" + b"Content-Length: 4\r\n" + b"\r\n" + b"skip" + ) + rfile.reset_timestamps = Mock() + rfile.first_byte_timestamp = 42 + r = read_request_head(rfile) + assert r.method == b"GET" + assert r.headers["Content-Length"] == b"4" + assert r.body is None + assert rfile.reset_timestamps.called + assert r.timestamp_start == 42 + assert rfile.read() == b"skip" + + +def test_read_response(): + req = treq() + rfile = BytesIO(b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody") + r = read_response(rfile, req) + assert r.status_code == 418 + assert r.body == b"body" + assert r.timestamp_end + + +def test_read_response_head(): + rfile = BytesIO( + b"HTTP/1.1 418 I'm a teapot\r\n" + b"Content-Length: 4\r\n" + b"\r\n" + b"skip" + ) + rfile.reset_timestamps = Mock() + rfile.first_byte_timestamp = 42 + r = read_response_head(rfile) + assert r.status_code == 418 + assert r.headers["Content-Length"] == b"4" + assert r.body is None + assert rfile.reset_timestamps.called + assert r.timestamp_start == 42 + assert rfile.read() == b"skip" + + +class TestReadBody(object): + def test_chunked(self): + rfile = BytesIO(b"3\r\nfoo\r\n0\r\n\r\nbar") + body = b"".join(read_body(rfile, None)) + assert body == b"foo" + assert rfile.read() == b"bar" + + + def test_known_size(self): + rfile = BytesIO(b"foobar") + body = b"".join(read_body(rfile, 3)) + assert body == b"foo" + assert rfile.read() == b"bar" + + + def test_known_size_limit(self): + rfile = BytesIO(b"foobar") + with raises(HttpException): + b"".join(read_body(rfile, 3, 2)) + + def test_known_size_too_short(self): + rfile = BytesIO(b"foo") + with raises(HttpException): + b"".join(read_body(rfile, 6)) + + def test_unknown_size(self): + rfile = BytesIO(b"foobar") + body = b"".join(read_body(rfile, -1)) + assert body == b"foobar" + + + def test_unknown_size_limit(self): + rfile = BytesIO(b"foobar") + with raises(HttpException): + b"".join(read_body(rfile, -1, 3)) + + +def test_connection_close(): + headers = Headers() + assert connection_close((1, 0), headers) + assert not connection_close((1, 1), headers) + + headers["connection"] = "keep-alive" + assert not connection_close((1, 1), headers) + + headers["connection"] = "close" + assert connection_close((1, 1), headers) + + +def test_expected_http_body_size(): + # Expect: 100-continue + assert expected_http_body_size( + treq(headers=Headers(expect=b"100-continue", content_length=b"42")) + ) == 0 + + # http://tools.ietf.org/html/rfc7230#section-3.3 + assert expected_http_body_size( + treq(method=b"HEAD"), + tresp(headers=Headers(content_length=b"42")) + ) == 0 + assert expected_http_body_size( + treq(method=b"CONNECT"), + tresp() + ) == 0 + for code in (100, 204, 304): + assert expected_http_body_size( + treq(), + tresp(status_code=code) + ) == 0 + + # chunked + assert expected_http_body_size( + treq(headers=Headers(transfer_encoding=b"chunked")), + ) is None + + # explicit length + for l in (b"foo", b"-7"): + with raises(HttpSyntaxException): + expected_http_body_size( + treq(headers=Headers(content_length=l)) + ) + assert expected_http_body_size( + treq(headers=Headers(content_length=b"42")) + ) == 42 + + # no length + assert expected_http_body_size( + treq() + ) == 0 + assert expected_http_body_size( + treq(), tresp() + ) == -1 + + +def test_get_first_line(): + rfile = BytesIO(b"foo\r\nbar") + assert _get_first_line(rfile) == b"foo" + + rfile = BytesIO(b"\r\nfoo\r\nbar") + assert _get_first_line(rfile) == b"foo" + + with raises(HttpReadDisconnect): + rfile = BytesIO(b"") + _get_first_line(rfile) + + with raises(HttpSyntaxException): + rfile = BytesIO(b"GET /\xff HTTP/1.1") + _get_first_line(rfile) + + +def test_read_request_line(): + def t(b): + return _read_request_line(BytesIO(b)) + + assert (t(b"GET / HTTP/1.1") == + ("relative", b"GET", None, None, None, b"/", b"HTTP/1.1")) + assert (t(b"OPTIONS * HTTP/1.1") == + ("relative", b"OPTIONS", None, None, None, b"*", b"HTTP/1.1")) + assert (t(b"CONNECT foo:42 HTTP/1.1") == + ("authority", b"CONNECT", None, b"foo", 42, None, b"HTTP/1.1")) + assert (t(b"GET http://foo:42/bar HTTP/1.1") == + ("absolute", b"GET", b"http", b"foo", 42, b"/bar", b"HTTP/1.1")) + + with raises(HttpSyntaxException): + t(b"GET / WTF/1.1") + with raises(HttpSyntaxException): + t(b"this is not http") + + +def test_parse_authority_form(): + assert _parse_authority_form(b"foo:42") == (b"foo", 42) + with raises(HttpSyntaxException): + _parse_authority_form(b"foo") + with raises(HttpSyntaxException): + _parse_authority_form(b"foo:bar") + with raises(HttpSyntaxException): + _parse_authority_form(b"foo:99999999") + with raises(HttpSyntaxException): + _parse_authority_form(b"f\x00oo:80") + + +def test_read_response_line(): + def t(b): + return _read_response_line(BytesIO(b)) + + assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK") + assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"") + with raises(HttpSyntaxException): + assert t(b"HTTP/1.1") + + with raises(HttpSyntaxException): + t(b"HTTP/1.1 OK OK") + with raises(HttpSyntaxException): + t(b"WTF/1.1 200 OK") + + +def test_check_http_version(): + _check_http_version(b"HTTP/0.9") + _check_http_version(b"HTTP/1.0") + _check_http_version(b"HTTP/1.1") + _check_http_version(b"HTTP/2.0") + with raises(HttpSyntaxException): + _check_http_version(b"WTF/1.0") + with raises(HttpSyntaxException): + _check_http_version(b"HTTP/1.10") + with raises(HttpSyntaxException): + _check_http_version(b"HTTP/1.b") + + +class TestReadHeaders(object): + @staticmethod + def _read(data): + return _read_headers(BytesIO(data)) + + def test_read_simple(self): + data = ( + b"Header: one\r\n" + b"Header2: two\r\n" + b"\r\n" + ) + headers = self._read(data) + assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]] + + def test_read_multi(self): + data = ( + b"Header: one\r\n" + b"Header: two\r\n" + b"\r\n" + ) + headers = self._read(data) + assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]] + + def test_read_continued(self): + data = ( + b"Header: one\r\n" + b"\ttwo\r\n" + b"Header2: three\r\n" + b"\r\n" + ) + headers = self._read(data) + assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]] + + def test_read_continued_err(self): + data = b"\tfoo: bar\r\n" + with raises(HttpSyntaxException): + self._read(data) + + def test_read_err(self): + data = b"foo" + with raises(HttpSyntaxException): + self._read(data) + + +def test_read_chunked(): + req = treq(body=None) + req.headers["Transfer-Encoding"] = "chunked" + + data = b"1\r\na\r\n0\r\n" + with raises(HttpSyntaxException): + b"".join(_read_chunked(BytesIO(data))) + + data = b"1\r\na\r\n0\r\n\r\n" + assert b"".join(_read_chunked(BytesIO(data))) == b"a" + + data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n" + assert b"".join(_read_chunked(BytesIO(data))) == b"ab" + + data = b"\r\n" + with raises("closed prematurely"): + b"".join(_read_chunked(BytesIO(data))) + + data = b"1\r\nfoo" + with raises("malformed chunked body"): + b"".join(_read_chunked(BytesIO(data))) + + data = b"foo\r\nfoo" + with raises(HttpSyntaxException): + b"".join(_read_chunked(BytesIO(data))) + + data = b"5\r\naaaaa\r\n0\r\n\r\n" + with raises("too large"): + b"".join(_read_chunked(BytesIO(data), limit=2)) diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py index efdb55e27..4c89b0239 100644 --- a/test/http/http2/test_frames.py +++ b/test/http/http2/test_frames.py @@ -39,7 +39,7 @@ def test_too_large_frames(): flags=Frame.FLAG_END_STREAM, stream_id=0x1234567, payload='foobar' * 3000) - tutils.raises(FrameSizeError, f.to_bytes) + tutils.raises(HttpSyntaxException, f.to_bytes) def test_data_frame_to_bytes(): diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 2b7d7958b..789b6e633 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -2,21 +2,21 @@ import OpenSSL import mock from netlib import tcp, http, tutils -from netlib.http import http2, Headers -from netlib.http.http2 import HTTP2Protocol +from netlib.http import Headers +from netlib.http.http2.connections import HTTP2Protocol, TCPHandler from netlib.http.http2.frame import * from ... import tservers class TestTCPHandlerWrapper: def test_wrapped(self): - h = http2.TCPHandler(rfile='foo', wfile='bar') + h = TCPHandler(rfile='foo', wfile='bar') p = HTTP2Protocol(h) assert p.tcp_handler.rfile == 'foo' assert p.tcp_handler.wfile == 'bar' def test_direct(self): p = HTTP2Protocol(rfile='foo', wfile='bar') - assert isinstance(p.tcp_handler, http2.TCPHandler) + assert isinstance(p.tcp_handler, TCPHandler) assert p.tcp_handler.rfile == 'foo' assert p.tcp_handler.wfile == 'bar' @@ -32,8 +32,8 @@ class EchoHandler(tcp.BaseHandler): class TestProtocol: - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") def test_perform_connection_preface(self, mock_client_method, mock_server_method): protocol = HTTP2Protocol(is_server=False) protocol.connection_preface_performed = True @@ -46,8 +46,8 @@ class TestProtocol: assert mock_client_method.called assert not mock_server_method.called - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): protocol = HTTP2Protocol(is_server=True) protocol.connection_preface_performed = True diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py deleted file mode 100644 index 49588d0ac..000000000 --- a/test/http/test_exceptions.py +++ /dev/null @@ -1,6 +0,0 @@ -from netlib.http.exceptions import * - -class TestHttpError: - def test_simple(self): - e = HttpError(404, "Not found") - assert str(e) diff --git a/test/http/test_semantics.py b/test/http/test_models.py similarity index 74% rename from test/http/test_semantics.py rename to test/http/test_models.py index 44d3c85ef..0f4dcc3be 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_models.py @@ -1,32 +1,11 @@ import mock -from netlib import http -from netlib import odict from netlib import tutils from netlib import utils -from netlib.http import semantics -from netlib.http.semantics import CONTENT_MISSING +from netlib.odict import ODict, ODictCaseless +from netlib.http import Request, Response, Headers, CONTENT_MISSING, HDR_FORM_URLENCODED, \ + HDR_FORM_MULTIPART -class TestProtocolMixin(object): - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") - def test_assemble_request(self, mock_request_method, mock_response_method): - p = semantics.ProtocolMixin() - p.assemble(tutils.treq()) - assert mock_request_method.called - assert not mock_response_method.called - - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") - def test_assemble_response(self, mock_request_method, mock_response_method): - p = semantics.ProtocolMixin() - p.assemble(tutils.tresp()) - assert not mock_request_method.called - assert mock_response_method.called - - def test_assemble_foo(self): - p = semantics.ProtocolMixin() - tutils.raises(ValueError, p.assemble, 'foo') class TestRequest(object): def test_repr(self): @@ -34,7 +13,7 @@ class TestRequest(object): assert repr(r) def test_headers(self): - tutils.raises(AssertionError, semantics.Request, + tutils.raises(AssertionError, Request, 'form_in', 'method', 'scheme', @@ -45,7 +24,7 @@ class TestRequest(object): 'foobar', ) - req = semantics.Request( + req = Request( 'form_in', 'method', 'scheme', @@ -54,7 +33,7 @@ class TestRequest(object): 'path', (1, 1), ) - assert isinstance(req.headers, http.Headers) + assert isinstance(req.headers, Headers) def test_equal(self): a = tutils.treq() @@ -66,13 +45,6 @@ class TestRequest(object): assert not 'foo' == a assert not 'foo' == b - def test_legacy_first_line(self): - req = tutils.treq() - - assert req.legacy_first_line('relative') == "GET /path HTTP/1.1" - assert req.legacy_first_line('authority') == "GET address:22 HTTP/1.1" - assert req.legacy_first_line('absolute') == "GET http://address:22/path HTTP/1.1" - tutils.raises(http.HttpError, req.legacy_first_line, 'foobar') def test_anticache(self): req = tutils.treq() @@ -103,44 +75,44 @@ class TestRequest(object): def test_get_form(self): req = tutils.treq() - assert req.get_form() == odict.ODict() + assert req.get_form() == ODict() - @mock.patch("netlib.http.semantics.Request.get_form_multipart") - @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") + @mock.patch("netlib.http.Request.get_form_multipart") + @mock.patch("netlib.http.Request.get_form_urlencoded") def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart): req = tutils.treq() - assert req.get_form() == odict.ODict() + assert req.get_form() == ODict() req = tutils.treq() req.body = "foobar" - req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED + req.headers["Content-Type"] = HDR_FORM_URLENCODED req.get_form() assert req.get_form_urlencoded.called assert not req.get_form_multipart.called - @mock.patch("netlib.http.semantics.Request.get_form_multipart") - @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") + @mock.patch("netlib.http.Request.get_form_multipart") + @mock.patch("netlib.http.Request.get_form_urlencoded") def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart): req = tutils.treq() req.body = "foobar" - req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART + req.headers["Content-Type"] = HDR_FORM_MULTIPART req.get_form() assert not req.get_form_urlencoded.called assert req.get_form_multipart.called def test_get_form_urlencoded(self): - req = tutils.treq("foobar") - assert req.get_form_urlencoded() == odict.ODict() + req = tutils.treq(body="foobar") + assert req.get_form_urlencoded() == ODict() - req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED - assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body)) + req.headers["Content-Type"] = HDR_FORM_URLENCODED + assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body)) def test_get_form_multipart(self): - req = tutils.treq("foobar") - assert req.get_form_multipart() == odict.ODict() + req = tutils.treq(body="foobar") + assert req.get_form_multipart() == ODict() - req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART - assert req.get_form_multipart() == odict.ODict( + req.headers["Content-Type"] = HDR_FORM_MULTIPART + assert req.get_form_multipart() == ODict( utils.multipartdecode( req.headers, req.body @@ -149,8 +121,8 @@ class TestRequest(object): def test_set_form_urlencoded(self): req = tutils.treq() - req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')])) - assert req.headers["Content-Type"] == semantics.HDR_FORM_URLENCODED + req.set_form_urlencoded(ODict([('foo', 'bar'), ('rab', 'oof')])) + assert req.headers["Content-Type"] == HDR_FORM_URLENCODED assert req.body def test_get_path_components(self): @@ -172,7 +144,7 @@ class TestRequest(object): def test_set_query(self): req = tutils.treq() - req.set_query(odict.ODict([])) + req.set_query(ODict([])) def test_pretty_host(self): r = tutils.treq() @@ -203,21 +175,21 @@ class TestRequest(object): assert req.pretty_url(False) == "http://address:22/path" def test_get_cookies_none(self): - headers = http.Headers() + headers = Headers() r = tutils.treq() r.headers = headers assert len(r.get_cookies()) == 0 def test_get_cookies_single(self): r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=cookievalue") + r.headers = Headers(cookie="cookiename=cookievalue") result = r.get_cookies() assert len(result) == 1 assert result['cookiename'] == ['cookievalue'] def test_get_cookies_double(self): r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") + r.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") result = r.get_cookies() assert len(result) == 2 assert result['cookiename'] == ['cookievalue'] @@ -225,7 +197,7 @@ class TestRequest(object): def test_get_cookies_withequalsign(self): r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") + r.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") result = r.get_cookies() assert len(result) == 2 assert result['cookiename'] == ['coo=kievalue'] @@ -233,14 +205,14 @@ class TestRequest(object): def test_set_cookies(self): r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=cookievalue") + r.headers = Headers(cookie="cookiename=cookievalue") result = r.get_cookies() result["cookiename"] = ["foo"] r.set_cookies(result) assert r.get_cookies()["cookiename"] == ["foo"] def test_set_url(self): - r = tutils.treq_absolute() + r = tutils.treq(form_in="absolute") r.url = "https://otheraddress:42/ORLY" assert r.scheme == "https" assert r.host == "otheraddress" @@ -332,24 +304,19 @@ class TestRequest(object): # "Host: address\r\n" # "Content-Length: 0\r\n\r\n") -class TestEmptyRequest(object): - def test_init(self): - req = semantics.EmptyRequest() - assert req - class TestResponse(object): def test_headers(self): - tutils.raises(AssertionError, semantics.Response, + tutils.raises(AssertionError, Response, (1, 1), 200, headers='foobar', ) - resp = semantics.Response( + resp = Response( (1, 1), 200, ) - assert isinstance(resp.headers, http.Headers) + assert isinstance(resp.headers, Headers) def test_equal(self): a = tutils.tresp() @@ -366,24 +333,24 @@ class TestResponse(object): assert "unknown content type" in repr(r) r.headers["content-type"] = "foo" assert "foo" in repr(r) - assert repr(tutils.tresp(content=CONTENT_MISSING)) + assert repr(tutils.tresp(body=CONTENT_MISSING)) def test_get_cookies_none(self): resp = tutils.tresp() - resp.headers = http.Headers() + resp.headers = Headers() assert not resp.get_cookies() def test_get_cookies_simple(self): resp = tutils.tresp() - resp.headers = http.Headers(set_cookie="cookiename=cookievalue") + resp.headers = Headers(set_cookie="cookiename=cookievalue") result = resp.get_cookies() assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", odict.ODict()] + assert result["cookiename"][0] == ["cookievalue", ODict()] def test_get_cookies_with_parameters(self): resp = tutils.tresp() - resp.headers = http.Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") + resp.headers = Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") result = resp.get_cookies() assert len(result) == 1 assert "cookiename" in result @@ -397,7 +364,7 @@ class TestResponse(object): def test_get_cookies_no_value(self): resp = tutils.tresp() - resp.headers = http.Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") + resp.headers = Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") result = resp.get_cookies() assert len(result) == 1 assert "cookiename" in result @@ -406,31 +373,31 @@ class TestResponse(object): def test_get_cookies_twocookies(self): resp = tutils.tresp() - resp.headers = http.Headers([ + resp.headers = Headers([ ["Set-Cookie", "cookiename=cookievalue"], ["Set-Cookie", "othercookie=othervalue"] ]) result = resp.get_cookies() assert len(result) == 2 assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", odict.ODict()] + assert result["cookiename"][0] == ["cookievalue", ODict()] assert "othercookie" in result - assert result["othercookie"][0] == ["othervalue", odict.ODict()] + assert result["othercookie"][0] == ["othervalue", ODict()] def test_set_cookies(self): resp = tutils.tresp() v = resp.get_cookies() - v.add("foo", ["bar", odict.ODictCaseless()]) + v.add("foo", ["bar", ODictCaseless()]) resp.set_cookies(v) v = resp.get_cookies() assert len(v) == 1 - assert v["foo"] == [["bar", odict.ODictCaseless()]] + assert v["foo"] == [["bar", ODictCaseless()]] class TestHeaders(object): def _2host(self): - return semantics.Headers( + return Headers( [ ["Host", "example.com"], ["host", "example.org"] @@ -438,25 +405,25 @@ class TestHeaders(object): ) def test_init(self): - headers = semantics.Headers() + headers = Headers() assert len(headers) == 0 - headers = semantics.Headers([["Host", "example.com"]]) + headers = Headers([["Host", "example.com"]]) assert len(headers) == 1 assert headers["Host"] == "example.com" - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") assert len(headers) == 1 assert headers["Host"] == "example.com" - headers = semantics.Headers( + headers = Headers( [["Host", "invalid"]], Host="example.com" ) assert len(headers) == 1 assert headers["Host"] == "example.com" - headers = semantics.Headers( + headers = Headers( [["Host", "invalid"], ["Accept", "text/plain"]], Host="example.com" ) @@ -465,7 +432,7 @@ class TestHeaders(object): assert headers["Accept"] == "text/plain" def test_getitem(self): - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") assert headers["Host"] == "example.com" assert headers["host"] == "example.com" tutils.raises(KeyError, headers.__getitem__, "Accept") @@ -474,17 +441,17 @@ class TestHeaders(object): assert headers["Host"] == "example.com, example.org" def test_str(self): - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") assert bytes(headers) == "Host: example.com\r\n" - headers = semantics.Headers([ + headers = Headers([ ["Host", "example.com"], ["Accept", "text/plain"] ]) assert str(headers) == "Host: example.com\r\nAccept: text/plain\r\n" def test_setitem(self): - headers = semantics.Headers() + headers = Headers() headers["Host"] = "example.com" assert "Host" in headers assert "host" in headers @@ -507,7 +474,7 @@ class TestHeaders(object): assert "Host" in headers def test_delitem(self): - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") assert len(headers) == 1 del headers["host"] assert len(headers) == 0 @@ -523,7 +490,7 @@ class TestHeaders(object): assert len(headers) == 0 def test_keys(self): - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") assert len(headers.keys()) == 1 assert headers.keys()[0] == "Host" @@ -532,13 +499,13 @@ class TestHeaders(object): assert headers.keys()[0] == "Host" def test_eq_ne(self): - headers1 = semantics.Headers(Host="example.com") - headers2 = semantics.Headers(host="example.com") + headers1 = Headers(Host="example.com") + headers2 = Headers(host="example.com") assert not (headers1 == headers2) assert headers1 != headers2 - headers1 = semantics.Headers(Host="example.com") - headers2 = semantics.Headers(Host="example.com") + headers1 = Headers(Host="example.com") + headers2 = Headers(Host="example.com") assert headers1 == headers2 assert not (headers1 != headers2) @@ -550,7 +517,7 @@ class TestHeaders(object): assert headers.get_all("accept") == [] def test_set_all(self): - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") headers.set_all("Accept", ["text/plain"]) assert len(headers) == 2 assert "accept" in headers @@ -565,9 +532,9 @@ class TestHeaders(object): def test_state(self): headers = self._2host() assert len(headers.get_state()) == 2 - assert headers == semantics.Headers.from_state(headers.get_state()) + assert headers == Headers.from_state(headers.get_state()) - headers2 = semantics.Headers() + headers2 = Headers() assert headers != headers2 headers2.load_state(headers.get_state()) assert headers == headers2 diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 57cfd1662..3fdeb6839 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -1,11 +1,13 @@ import os from nose.tools import raises +from netlib.http.http1 import read_response, read_request from netlib import tcp, tutils, websockets, http from netlib.http import status_codes -from netlib.http.exceptions import * -from netlib.http.http1 import HTTP1Protocol +from netlib.tutils import treq + +from netlib.exceptions import * from .. import tservers @@ -34,9 +36,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler): frame.to_file(self.wfile) def handshake(self): - http1_protocol = HTTP1Protocol(self) - req = http1_protocol.read_request() + req = read_request(self.rfile) key = self.protocol.check_client_handshake(req.headers) preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) @@ -61,8 +62,6 @@ class WebSocketsClient(tcp.TCPClient): def connect(self): super(WebSocketsClient, self).connect() - http1_protocol = HTTP1Protocol(self) - preamble = 'GET / HTTP/1.1' self.wfile.write(preamble + "\r\n") headers = self.protocol.client_handshake_headers() @@ -70,7 +69,7 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(str(headers) + "\r\n") self.wfile.flush() - resp = http1_protocol.read_response("GET", None) + resp = read_response(self.rfile, treq(method="GET")) server_nonce = self.protocol.check_server_handshake(resp.headers) if not server_nonce == self.protocol.create_server_nonce( @@ -158,9 +157,8 @@ class TestWebSockets(tservers.ServerTestBase): class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - http1_protocol = HTTP1Protocol(self) - client_hs = http1_protocol.read_request() + client_hs = read_request(self.rfile) self.protocol.check_client_handshake(client_hs.headers) preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)