From 91752990d5863526745e5c31cfb4b7459d11047e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 24 Jul 2012 11:39:49 +1200 Subject: [PATCH] Handle HTTP responses that have a body but no content-length or transfer encoding We check if the server sent a connection:close header, and read till the socket closes. Closes #2 --- netlib/http.py | 37 +++++++++++++++++++++++-------------- netlib/tcp.py | 11 ++++++++--- test/test_http.py | 23 ++++++++++++++++++++++- test/test_tcp.py | 6 ++++++ 4 files changed, 59 insertions(+), 18 deletions(-) diff --git a/netlib/http.py b/netlib/http.py index 9d6db0036..980d3f625 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -97,12 +97,21 @@ def read_chunked(code, fp, limit): return content -def has_chunked_encoding(headers): - for i in headers["transfer-encoding"]: +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: for j in i.split(","): - if j.lower() == "chunked": - return True - return False + toks.append(j.strip()) + return toks + + +def has_chunked_encoding(headers): + return "chunked" in [i.lower() for i in get_header_tokens(headers, "transfer-encoding")] def read_http_body(code, rfile, headers, all, limit): @@ -207,12 +216,11 @@ def request_connection_close(httpversion, headers): Checks the request to see if the client connection should be closed. """ if "connection" in headers: - for value in ",".join(headers['connection']).split(","): - value = value.strip() - if value == "close": - return True - elif value == "keep-alive": - return False + toks = get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False # HTTP 1.1 connections are assumed to be persistent if httpversion == (1, 1): return False @@ -243,10 +251,11 @@ def read_http_body_request(rfile, wfile, headers, httpversion, limit): return read_http_body(400, rfile, headers, False, limit) -def read_http_body_response(rfile, headers, all, limit): +def read_http_body_response(rfile, headers, limit): """ Read the HTTP body from a server response. """ + all = "close" in get_header_tokens(headers, "connection") return read_http_body(500, rfile, headers, all, limit) @@ -267,7 +276,7 @@ def read_response(rfile, method, body_size_limit): proto, code, msg = parts httpversion = parse_http_protocol(proto) if httpversion is None: - raise HttpError(502, "Invalid HTTP version: %s"%repr(httpversion)) + raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) try: code = int(code) except ValueError: @@ -278,5 +287,5 @@ def read_response(rfile, method, body_size_limit): if method == "HEAD" or code == 204 or code == 304: content = "" else: - content = read_http_body_response(rfile, headers, False, body_size_limit) + content = read_http_body_response(rfile, headers, body_size_limit) return httpversion, code, msg, headers, content diff --git a/netlib/tcp.py b/netlib/tcp.py index 66a26872f..7d3705dac 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -40,6 +40,7 @@ class NetLibTimeout(Exception): pass class FileLike: + BLOCKSIZE = 1024 * 32 def __init__(self, o): self.o = o @@ -51,11 +52,14 @@ class FileLike: self.o.flush() def read(self, length): + """ + If length is None, we read until connection closes. + """ result = '' start = time.time() - while length > 0: + while length == -1 or length > 0: try: - data = self.o.read(length) + data = self.o.read(self.BLOCKSIZE if length == -1 else length) except SSL.ZeroReturnError: break except SSL.WantReadError: @@ -73,7 +77,8 @@ class FileLike: if not data: break result += data - length -= len(data) + if length != -1: + length -= len(data) return result def write(self, v): diff --git a/test/test_http.py b/test/test_http.py index 0174a4aa5..0b83e65a1 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -64,7 +64,28 @@ def test_read_http_body_response(): h = odict.ODictCaseless() h["content-length"] = [7] s = cStringIO.StringIO("testing") - assert http.read_http_body_response(s, h, False, None) == "testing" + assert http.read_http_body_response(s, h, None) == "testing" + + + h = odict.ODictCaseless() + s = cStringIO.StringIO("testing") + assert not http.read_http_body_response(s, h, None) + + h = odict.ODictCaseless() + h["connection"] = ["close"] + s = cStringIO.StringIO("testing") + assert http.read_http_body_response(s, h, None) == "testing" + + +def test_get_header_tokens(): + h = odict.ODictCaseless() + assert http.get_header_tokens(h, "foo") == [] + h["foo"] = ["bar"] + assert http.get_header_tokens(h, "foo") == ["bar"] + h["foo"] = ["bar, voing"] + assert http.get_header_tokens(h, "foo") == ["bar", "voing"] + h["foo"] = ["bar, voing", "oink"] + assert http.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] def test_read_http_body_request(): diff --git a/test/test_tcp.py b/test/test_tcp.py index d6235b015..67c56a374 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -239,3 +239,9 @@ class TestFileLike: s = cStringIO.StringIO("foobar\nfoobar") s = tcp.FileLike(s) assert s.readline(3) == "foo" + + def test_limitless(self): + s = cStringIO.StringIO("f"*(50*1024)) + s = tcp.FileLike(s) + ret = s.read(-1) + assert len(ret) == 50 * 1024