diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index dbe91e7eb..8b2f6aabc 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -104,10 +104,16 @@ def read_chunked(fp, limit): return content -def read_http_body(rfile, client_conn, headers, all, limit): - if 'transfer-encoding' in headers: - if not ",".join(headers["transfer-encoding"]).lower() == "chunked": - raise IOError('Invalid transfer-encoding') +def has_chunked_encoding(headers): + for i in headers["transfer-encoding"]: + for j in i.split(","): + if j.lower() == "chunked": + return True + return False + + +def read_http_body(rfile, headers, all, limit): + if has_chunked_encoding(headers): content = read_chunked(rfile, limit) elif "content-length" in headers: try: @@ -121,7 +127,6 @@ def read_http_body(rfile, client_conn, headers, all, limit): content = rfile.read(l) elif all: content = rfile.read(limit if limit else None) - client_conn.close = True else: content = "" return content @@ -185,10 +190,9 @@ def parse_init_http(line): return method, url, httpversion -def should_connection_close(httpversion, headers): +def request_connection_close(httpversion, headers): """ - Checks the HTTP version and headers to see if this connection should be - closed. + Checks the request to see if the client connection should be closed. """ if "connection" in headers: for value in ",".join(headers['connection']).split(","): @@ -203,7 +207,18 @@ def should_connection_close(httpversion, headers): return True -def read_http_body_request(rfile, wfile, client_conn, headers, httpversion, limit): +def response_connection_close(httpversion, headers): + """ + Checks the response to see if the client connection should be closed. + """ + if request_connection_close(httpversion, headers): + return True + elif not has_chunked_encoding(headers) and "content-length" in headers: + return True + return False + + +def read_http_body_request(rfile, wfile, headers, httpversion, limit): if "expect" in headers: # FIXME: Should be forwarded upstream expect = ",".join(headers['expect']) @@ -212,7 +227,7 @@ def read_http_body_request(rfile, wfile, client_conn, headers, httpversion, limi wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION) wfile.write('\r\n') del headers['expect'] - return read_http_body(rfile, client_conn, headers, False, limit) + return read_http_body(rfile, headers, False, limit) class FileLike: @@ -335,7 +350,7 @@ class ServerConnection: if request.method == "HEAD" or code == 204 or code == 304: content = "" else: - content = read_http_body(self.rfile, self, headers, True, self.config.body_size_limit) + content = read_http_body(self.rfile, headers, True, self.config.body_size_limit) return flow.Response(request, httpversion, code, msg, headers, content, self.cert) def terminate(self): @@ -413,7 +428,13 @@ class ProxyHandler(SocketServer.StreamRequestHandler): if response is None: return self.send_response(response) - if should_connection_close(request.httpversion, request.headers): + if request_connection_close(request.httpversion, request.headers): + return + # We could keep the client connection when the server + # connection needs to go away. However, we want to mimic + # behaviour as closely as possible to the client, so we + # disconnect. + if response_connection_close(response.httpversion, response.headers): return except IOError, v: cc.connection_error = v @@ -467,7 +488,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler): method, path, httpversion = parse_init_http(line) headers = read_headers(self.rfile) content = read_http_body_request( - self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit + self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) return flow.Request(client_conn, httpversion, host, port, "http", method, path, headers, content) else: @@ -495,14 +516,14 @@ class ProxyHandler(SocketServer.StreamRequestHandler): method, path, httpversion = parse_init_http(line) headers = read_headers(self.rfile) content = read_http_body_request( - self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit + self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content) else: method, scheme, host, port, path, httpversion = parse_init_proxy(line) headers = read_headers(self.rfile) content = read_http_body_request( - self.rfile, self.wfile, client_conn, headers, httpversion, self.config.body_size_limit + self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit ) return flow.Request(client_conn, httpversion, host, port, scheme, method, path, headers, content) diff --git a/test/test_proxy.py b/test/test_proxy.py index 9fd030084..9d7239dd7 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -4,7 +4,12 @@ import libpry from libmproxy import proxy, flow import tutils -class Dummy: pass + +def test_has_chunked_encoding(): + h = flow.ODictCaseless() + assert not proxy.has_chunked_encoding(h) + h["transfer-encoding"] = ["chunked"] + assert proxy.has_chunked_encoding(h) def test_read_chunked(): @@ -24,36 +29,35 @@ def test_read_chunked(): tutils.raises(proxy.ProxyError, proxy.read_chunked, s, None) -def test_should_connection_close(): +def test_request_connection_close(): h = flow.ODictCaseless() - assert proxy.should_connection_close((1, 0), h) - assert not proxy.should_connection_close((1, 1), h) + assert proxy.request_connection_close((1, 0), h) + assert not proxy.request_connection_close((1, 1), h) h["connection"] = ["keep-alive"] - assert not proxy.should_connection_close((1, 1), h) + assert not proxy.request_connection_close((1, 1), h) def test_read_http_body(): - d = Dummy() h = flow.ODict() s = cStringIO.StringIO("testing") - assert proxy.read_http_body(s, d, h, False, None) == "" + assert proxy.read_http_body(s, h, False, None) == "" h["content-length"] = ["foo"] s = cStringIO.StringIO("testing") - tutils.raises(proxy.ProxyError, proxy.read_http_body, s, d, h, False, None) + tutils.raises(proxy.ProxyError, proxy.read_http_body, s, h, False, None) h["content-length"] = [5] s = cStringIO.StringIO("testing") - assert len(proxy.read_http_body(s, d, h, False, None)) == 5 + assert len(proxy.read_http_body(s, h, False, None)) == 5 s = cStringIO.StringIO("testing") - tutils.raises(proxy.ProxyError, proxy.read_http_body, s, d, h, False, 4) + tutils.raises(proxy.ProxyError, proxy.read_http_body, s, h, False, 4) h = flow.ODict() s = cStringIO.StringIO("testing") - assert len(proxy.read_http_body(s, d, h, True, 4)) == 4 + assert len(proxy.read_http_body(s, h, True, 4)) == 4 s = cStringIO.StringIO("testing") - assert len(proxy.read_http_body(s, d, h, True, 100)) == 7 + assert len(proxy.read_http_body(s, h, True, 100)) == 7 class TestFileLike: