From a7058e2a3c59cc2b13aaea3d7c767a3ca4a4bc40 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 29 Aug 2015 20:53:25 +0200 Subject: [PATCH] fix bugs, fix tests --- libmproxy/console/statusbar.py | 11 +++--- libmproxy/protocol2/http.py | 54 +++++++++++++++++++--------- test/test_proxy.py | 9 ++--- test/test_server.py | 16 +++++---- test/tservers.py | 64 +++++++--------------------------- 5 files changed, 69 insertions(+), 85 deletions(-) diff --git a/libmproxy/console/statusbar.py b/libmproxy/console/statusbar.py index 7eb2131be..ea2dbfa8a 100644 --- a/libmproxy/console/statusbar.py +++ b/libmproxy/console/statusbar.py @@ -199,11 +199,12 @@ class StatusBar(urwid.WidgetWrap): r.append("[%s]" % (":".join(opts))) if self.master.server.config.mode in ["reverse", "upstream"]: - dst = self.master.server.config.mode.dst - scheme = "https" if dst[0] else "http" - if dst[1] != dst[0]: - scheme += "2https" if dst[1] else "http" - r.append("[dest:%s]" % utils.unparse_url(scheme, *dst[2:])) + dst = self.master.server.config.upstream_server + r.append("[dest:%s]" % netlib.utils.unparse_url( + dst.scheme, + dst.address.host, + dst.address.port + )) if self.master.scripts: r.append("[") r.append(("heading_key", "s")) diff --git a/libmproxy/protocol2/http.py b/libmproxy/protocol2/http.py index 0fde9fb10..a3f329260 100644 --- a/libmproxy/protocol2/http.py +++ b/libmproxy/protocol2/http.py @@ -40,6 +40,7 @@ class _HttpLayer(Layer): def send_response(self, response): raise NotImplementedError() + class _StreamingHttpLayer(_HttpLayer): supports_streaming = True @@ -58,7 +59,6 @@ class _StreamingHttpLayer(_HttpLayer): class Http1Layer(_StreamingHttpLayer): - def __init__(self, ctx, mode): super(Http1Layer, self).__init__(ctx) self.mode = mode @@ -105,12 +105,12 @@ class Http1Layer(_StreamingHttpLayer): def send_response_headers(self, response): h = self.client_protocol._assemble_response_first_line(response) - self.client_conn.wfile.write(h+"\r\n") + self.client_conn.wfile.write(h + "\r\n") h = self.client_protocol._assemble_response_headers( response, preserve_transfer_encoding=True ) - self.client_conn.send(h+"\r\n") + self.client_conn.send(h + "\r\n") def send_response_body(self, response, chunks): if self.client_protocol.has_chunked_encoding(response.headers): @@ -142,8 +142,10 @@ class Http2Layer(_HttpLayer): def __init__(self, ctx, mode): super(Http2Layer, self).__init__(ctx) self.mode = mode - self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame) - self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) + self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, + unhandled_frame_cb=self.handle_unexpected_frame) + self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, + unhandled_frame_cb=self.handle_unexpected_frame) def read_request(self): request = HTTPRequest.from_protocol( @@ -172,17 +174,20 @@ class Http2Layer(_HttpLayer): def connect(self): self.ctx.connect() - self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) + self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, + unhandled_frame_cb=self.handle_unexpected_frame) self.server_protocol.perform_connection_preface() def reconnect(self): self.ctx.reconnect() - self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) + self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, + unhandled_frame_cb=self.handle_unexpected_frame) self.server_protocol.perform_connection_preface() def set_server(self, *args, **kwargs): self.ctx.set_server(*args, **kwargs) - self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) + self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, + unhandled_frame_cb=self.handle_unexpected_frame) self.server_protocol.perform_connection_preface() def __call__(self): @@ -264,7 +269,10 @@ class UpstreamConnectLayer(Layer): def __init__(self, ctx, connect_request): super(UpstreamConnectLayer, self).__init__(ctx) self.connect_request = connect_request - self.server_conn = ConnectServerConnection((connect_request.host, connect_request.port), self.ctx) + self.server_conn = ConnectServerConnection( + (connect_request.host, connect_request.port), + self.ctx + ) def __call__(self): layer = self.ctx.next_layer(self) @@ -280,6 +288,9 @@ class UpstreamConnectLayer(Layer): def reconnect(self): self.ctx.reconnect() self.send_request(self.connect_request) + resp = self.read_response("CONNECT") + if resp.code != 200: + raise ProtocolException("Reconnect: Upstream server refuses CONNECT request") def set_server(self, address, server_tls=None, sni=None, depth=1): if depth == 1: @@ -290,7 +301,7 @@ class UpstreamConnectLayer(Layer): self.connect_request.port = address.port self.server_conn.address = address else: - self.ctx.set_server(address, server_tls, sni, depth-1) + self.ctx.set_server(address, server_tls, sni, depth - 1) class HttpLayer(Layer): @@ -413,10 +424,10 @@ class HttpLayer(Layer): # First send the headers and then transfer the response incrementally self.send_response_headers(flow.response) chunks = self.read_response_body( - flow.response.headers, - flow.request.method, - flow.response.code, - max_chunk_size=4096 + flow.response.headers, + flow.request.method, + flow.response.code, + max_chunk_size=4096 ) if callable(flow.response.stream): chunks = flow.response.stream(chunks) @@ -521,7 +532,8 @@ class HttpLayer(Layer): # If there's not TlsLayer below which could catch the exception, # TLS will not be established. if tls and not self.server_conn.tls_established: - raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.") + raise ProtocolException( + "Cannot upgrade to SSL, no TLS layer on the protocol stack.") else: if not self.server_conn: self.connect() @@ -542,7 +554,8 @@ class HttpLayer(Layer): def validate_request(self, request): if request.form_in == "absolute" and request.scheme != "http": - self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme)) + self.send_response( + make_error_response(400, "Invalid request scheme: %s" % request.scheme)) raise HttpException("Invalid request scheme: %s" % request.scheme) expected_request_forms = { @@ -570,7 +583,11 @@ class HttpLayer(Layer): self.send_response(make_error_response( 407, "Proxy Authentication Required", - odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()]) + odict.ODictCaseless( + [ + [k, v] for k, v in + self.config.authenticator.auth_challenge_headers().items() + ]) )) raise InvalidCredentials("Proxy Authentication Required") @@ -614,6 +631,9 @@ class RequestReplayThread(threading.Thread): if r.scheme == "https": connect_request = make_connect_request((r.host, r.port)) server.send(protocol.assemble(connect_request)) + resp = protocol.read_response("CONNECT") + if resp.code != 200: + raise HttpError(502, "Upstream server refuses CONNECT request") server.establish_ssl( self.config.clientcerts, sni=self.flow.server_conn.sni diff --git a/test/test_proxy.py b/test/test_proxy.py index 9c01ab639..fac4a4f4d 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -1,9 +1,8 @@ -import argparse from libmproxy import cmdline from libmproxy.proxy import ProxyConfig, process_proxy_options from libmproxy.proxy.connection import ServerConnection from libmproxy.proxy.primitives import ProxyError -from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler +from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler2 import tutils from libpathod import test from netlib import http, tcp @@ -175,8 +174,10 @@ class TestDummyServer: class TestConnectionHandler: def test_fatal_error(self): config = mock.Mock() - config.mode.get_upstream_server.side_effect = RuntimeError - c = ConnectionHandler( + root_layer = mock.Mock() + root_layer.side_effect = RuntimeError + config.mode.return_value = root_layer + c = ConnectionHandler2( config, mock.MagicMock(), ("127.0.0.1", diff --git a/test/test_server.py b/test/test_server.py index 1216a349a..7b66c5822 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -68,7 +68,7 @@ class CommonMixin: # SSL with the upstream proxy. rt = self.master.replay_request(l, block=True) assert not rt - if isinstance(self, tservers.HTTPUpstreamProxTest) and not self.ssl: + if isinstance(self, tservers.HTTPUpstreamProxTest): assert l.response.code == 502 else: assert l.error @@ -506,7 +506,7 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin, TcpMixin): p = pathoc.Pathoc(("localhost", self.proxy.port), fp=None) p.connect() r = p.request("get:/") - assert r.status_code == 400 + assert r.status_code == 502 class TestProxy(tservers.HTTPProxTest): @@ -724,9 +724,9 @@ class TestStreamRequest(tservers.HTTPProxTest): assert resp.headers["Transfer-Encoding"][0] == 'chunked' assert resp.status_code == 200 - chunks = list( - content for _, content, _ in protocol.read_http_body_chunked( - resp.headers, None, "GET", 200, False)) + chunks = list(protocol.read_http_body_chunked( + resp.headers, None, "GET", 200, False + )) assert chunks == ["this", "isatest", ""] connection.close() @@ -959,6 +959,9 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest): p = self.pathoc() req = p.request("get:'/p/418:b\"content\"'") + assert req.content == "content" + assert req.status_code == 418 + assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request # CONNECT, failing request, assert self.chain[0].tmaster.state.flow_count() == 4 @@ -967,8 +970,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest): assert self.chain[1].tmaster.state.flow_count() == 2 # (doesn't store (repeated) CONNECTs from chain[0] # as it is a regular proxy) - assert req.content == "content" - assert req.status_code == 418 + assert not self.chain[1].tmaster.state.flows[0].response # killed assert self.chain[1].tmaster.state.flows[1].response diff --git a/test/tservers.py b/test/tservers.py index 43ebf2bb2..dfd3f6277 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -181,22 +181,24 @@ class TResolver: def original_addr(self, sock): return ("127.0.0.1", self.port) - class TransparentProxTest(ProxTestBase): ssl = None resolver = TResolver @classmethod - @mock.patch("libmproxy.platform.resolver") - def setupAll(cls, _): + def setupAll(cls): super(TransparentProxTest, cls).setupAll() - if cls.ssl: - ports = [cls.server.port, cls.server2.port] - else: - ports = [] - cls.config.mode = TransparentProxyMode( - cls.resolver(cls.server.port), - ports) + + cls._resolver = mock.patch( + "libmproxy.platform.resolver", + new=lambda: cls.resolver(cls.server.port) + ) + cls._resolver.start() + + @classmethod + def teardownAll(cls): + cls._resolver.stop() + super(TransparentProxTest, cls).teardownAll() @classmethod def get_proxy_config(cls): @@ -270,48 +272,6 @@ class SocksModeTest(HTTPProxTest): d["mode"] = "socks5" return d -class SpoofModeTest(ProxTestBase): - ssl = None - - @classmethod - def get_proxy_config(cls): - d = ProxTestBase.get_proxy_config() - d["upstream_server"] = None - d["mode"] = "spoof" - return d - - def pathoc(self, sni=None): - """ - Returns a connected Pathoc instance. - """ - p = libpathod.pathoc.Pathoc( - ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None - ) - p.connect() - return p - - -class SSLSpoofModeTest(ProxTestBase): - ssl = True - - @classmethod - def get_proxy_config(cls): - d = ProxTestBase.get_proxy_config() - d["upstream_server"] = None - d["mode"] = "sslspoof" - d["spoofed_ssl_port"] = 443 - return d - - def pathoc(self, sni=None): - """ - Returns a connected Pathoc instance. - """ - p = libpathod.pathoc.Pathoc( - ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None - ) - p.connect() - return p - class ChainProxTest(ProxTestBase): """