diff --git a/netlib/http.py b/netlib/http.py index e160bd790..454edb3a9 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -95,14 +95,17 @@ def read_headers(fp): return odict.ODictCaseless(ret) -def read_chunked(code, fp, limit): +def read_chunked(fp, headers, limit, is_request): """ Read a chunked HTTP body. May raise HttpError. """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. content = "" total = 0 + code = 400 if is_request else 502 while 1: line = fp.readline(128) if line == "": @@ -151,35 +154,6 @@ def has_chunked_encoding(headers): return "chunked" in [i.lower() for i in get_header_tokens(headers, "transfer-encoding")] -def read_http_body(code, rfile, headers, all, limit): - """ - Read an HTTP body: - - code: The HTTP error code to be used when raising HttpError - rfile: A file descriptor to read from - headers: An ODictCaseless object - all: Should we read all data? - limit: Size limit. - """ - if has_chunked_encoding(headers): - content = read_chunked(code, rfile, limit) - elif "content-length" in headers: - try: - l = int(headers["content-length"][0]) - except ValueError: - # FIXME: Not strictly correct - this could be from the server, in which - # case we should send a 502. - raise HttpError(code, "Invalid content-length header: %s"%headers["content-length"]) - if limit is not None and l > limit: - raise HttpError(code, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) - content = rfile.read(l) - elif all: - content = rfile.read(limit if limit else -1) - else: - content = "" - return content - - def parse_http_protocol(s): """ Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or @@ -304,28 +278,6 @@ def connection_close(httpversion, headers): return True - -def read_http_body_request(rfile, wfile, headers, httpversion, limit): - """ - Read the HTTP body from a client request. - """ - if "expect" in headers: - # FIXME: Should be forwarded upstream - if "100-continue" in headers['expect'] and httpversion >= (1, 1): - wfile.write('HTTP/1.1 100 Continue\r\n') - wfile.write('\r\n') - del headers['expect'] - return read_http_body(400, rfile, headers, False, limit) - - -def read_http_body_response(rfile, headers, limit): - """ - Read the HTTP body from a server response. - """ - all = "close" in get_header_tokens(headers, "connection") - return read_http_body(500, rfile, headers, all, limit) - - def parse_response_line(line): parts = line.strip().split(" ", 2) if len(parts) == 2: # handle missing message gracefully @@ -359,10 +311,41 @@ def read_response(rfile, method, body_size_limit): headers = read_headers(rfile) if headers is None: raise HttpError(502, "Invalid headers.") - if code >= 100 and code <= 199: - return read_response(rfile, method, body_size_limit) - if method == "HEAD" or code == 204 or code == 304: + + # Parse response body according to http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + if method == "HEAD" or (code in [204, 304]) or 100 <= code <= 199: content = "" else: - content = read_http_body_response(rfile, headers, body_size_limit) + content = read_http_body(rfile, headers, body_size_limit, False) return httpversion, code, msg, headers, content + + +def read_http_body(rfile, headers, limit, is_request): + """ + Read an HTTP message body: + + rfile: A file descriptor to read from + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False otherwise + """ + if has_chunked_encoding(headers): + content = read_chunked(rfile, headers, limit, is_request) + elif "content-length" in headers: + try: + l = int(headers["content-length"][0]) + if l < 0: + raise ValueError() + except ValueError: + raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) + if limit is not None and l > limit: + raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) + content = rfile.read(l) + elif is_request: + content = "" + else: + content = rfile.read(limit if limit else -1) + not_done = rfile.read(1) + if not_done: + raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) + return content \ No newline at end of file diff --git a/netlib/test.py b/netlib/test.py index cd1a38471..85a567391 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -18,7 +18,7 @@ class ServerTestBase: handler = None addr = ("localhost", 0) use_ipv6 = False - + @classmethod def setupAll(cls): cls.q = Queue.Queue() diff --git a/test/test_http.py b/test/test_http.py index 4d89bf246..a03861151 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -1,5 +1,5 @@ import cStringIO, textwrap, binascii -from netlib import http, odict +from netlib import http, odict, tcp, test import tutils @@ -17,25 +17,25 @@ def test_has_chunked_encoding(): def test_read_chunked(): s = cStringIO.StringIO("1\r\na\r\n0\r\n") - tutils.raises("closed prematurely", http.read_chunked, 500, s, None) + tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert http.read_chunked(500, s, None) == "a" + assert http.read_chunked(s, None, None, True) == "a" s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert http.read_chunked(500, s, None) == "a" + assert http.read_chunked(s, None, None, True) == "a" s = cStringIO.StringIO("\r\n") - tutils.raises("closed prematurely", http.read_chunked, 500, s, None) + tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) s = cStringIO.StringIO("1\r\nfoo") - tutils.raises("malformed chunked body", http.read_chunked, 500, s, None) + tutils.raises("malformed chunked body", http.read_chunked, s, None, None, True) s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises(http.HttpError, http.read_chunked, 500, s, None) + tutils.raises(http.HttpError, http.read_chunked, s, None, None, True) s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", http.read_chunked, 500, s, 2) + tutils.raises("too large", http.read_chunked, s, None, 2, True) def test_connection_close(): @@ -49,23 +49,6 @@ def test_connection_close(): h["connection"] = ["close"] assert http.connection_close((1, 1), h) -def test_read_http_body_response(): - h = odict.ODictCaseless() - h["content-length"] = [7] - s = cStringIO.StringIO("testing") - assert http.read_http_body_response(s, h, None) == "testing" - - - h = odict.ODictCaseless() - s = cStringIO.StringIO("testing") - assert not http.read_http_body_response(s, h, None) - - h = odict.ODictCaseless() - h["connection"] = ["close"] - s = cStringIO.StringIO("testing") - assert http.read_http_body_response(s, h, None) == "testing" - - def test_get_header_tokens(): h = odict.ODictCaseless() assert http.get_header_tokens(h, "foo") == [] @@ -79,38 +62,54 @@ def test_get_header_tokens(): def test_read_http_body_request(): h = odict.ODictCaseless() - h["expect"] = ["100-continue"] r = cStringIO.StringIO("testing") - w = cStringIO.StringIO() - assert http.read_http_body_request(r, w, h, (1, 1), None) == "" - assert "100 Continue" in w.getvalue() + assert http.read_http_body(r, h, None, True) == "" +def test_read_http_body_response(): + h = odict.ODictCaseless() + s = cStringIO.StringIO("testing") + assert http.read_http_body(s, h, None, False) == "testing" def test_read_http_body(): + # test default case h = odict.ODictCaseless() + h["content-length"] = [7] s = cStringIO.StringIO("testing") - assert http.read_http_body(500, s, h, False, None) == "" + assert http.read_http_body(s, h, None, False) == "testing" + # test content length: invalid header h["content-length"] = ["foo"] s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, 500, s, h, False, None) + tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) + # test content length: invalid header #2 + h["content-length"] = [-1] + s = cStringIO.StringIO("testing") + tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) + + # test content length: content length > actual content h["content-length"] = [5] s = cStringIO.StringIO("testing") - assert len(http.read_http_body(500, s, h, False, None)) == 5 - s = cStringIO.StringIO("testing") - tutils.raises(http.HttpError, http.read_http_body, 500, s, h, False, 4) + tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) + # test content length: content length < actual content + s = cStringIO.StringIO("testing") + assert len(http.read_http_body(s, h, None, False)) == 5 + + # test no content length: limit > actual content h = odict.ODictCaseless() s = cStringIO.StringIO("testing") - assert len(http.read_http_body(500, s, h, True, 4)) == 4 - s = cStringIO.StringIO("testing") - assert len(http.read_http_body(500, s, h, True, 100)) == 7 + assert len(http.read_http_body(s, h, 100, False)) == 7 + # test no content length: limit < actual content + s = cStringIO.StringIO("testing") + tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) + + # test chunked h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - assert http.read_http_body(500, s, h, True, 100) == "aaaaa" + assert http.read_http_body(s, h, 100, False) == "aaaaa" def test_parse_http_protocol(): @@ -214,6 +213,21 @@ class TestReadHeaders: assert self._read(data) is None +class NoContentLengthHTTPHandler(tcp.BaseHandler): + def handle(self): + self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") + self.wfile.flush() + + +class TestReadResponseNoContentLength(test.ServerTestBase): + handler = NoContentLengthHTTPHandler + + def test_no_content_length(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None) + assert content == "bar\r\n\r\n" + def test_read_response(): def tst(data, method, limit): data = textwrap.dedent(data) @@ -244,7 +258,7 @@ def test_read_response(): HTTP/1.1 200 OK """ - assert tst(data, "GET", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), '') + assert tst(data, "GET", None) == ((1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '') data = """ HTTP/1.1 200 OK diff --git a/test/test_tcp.py b/test/test_tcp.py index a4e66516e..7f2c21c4e 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -133,7 +133,6 @@ class TestFinishFail(test.ServerTestBase): c.wfile.flush() c.rfile.read(4) - class TestDisconnect(test.ServerTestBase): handler = EchoHandler def test_echo(self):