diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 3d9ef722e..8605d7a14 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -8,6 +8,7 @@ import Cookie import cookielib import os import re +from libmproxy.protocol2.http import RequestReplayThread from netlib import odict, wsgi, tcp from netlib.http.semantics import CONTENT_MISSING @@ -934,7 +935,7 @@ class FlowMaster(controller.Master): f.response = None f.error = None self.process_new_request(f) - rt = http.RequestReplayThread( + rt = RequestReplayThread( self.server.config, f, self.masterq if run_scripthooks else False, diff --git a/libmproxy/protocol2/http.py b/libmproxy/protocol2/http.py index 32c0116be..792cf2661 100644 --- a/libmproxy/protocol2/http.py +++ b/libmproxy/protocol2/http.py @@ -212,10 +212,11 @@ class UpstreamConnectLayer(Layer): self.ctx.reconnect() self.send_to_server(self.connect_request) - def set_server(self, address, server_tls, sni, depth=1): + def set_server(self, address, server_tls=None, sni=None, depth=1): if depth == 1: if self.ctx.server_conn: self.ctx.reconnect() + address = Address.wrap(address) self.connect_request.host = address.host self.connect_request.port = address.port self.server_conn.address = address @@ -227,11 +228,16 @@ class HttpLayer(Layer): def __init__(self, ctx, mode): super(HttpLayer, self).__init__(ctx) self.mode = mode + self.__original_server_conn = None + "Contains the original destination in transparent mode, which needs to be restored" + "if an inline script modified the target server for a single http request" def __call__(self): + if self.mode == "transparent": + self.__original_server_conn = self.server_conn while True: try: - flow = HTTPFlow(self.client_conn, self.server_conn, live=True) + flow = HTTPFlow(self.client_conn, self.server_conn, live=self) try: request = self.read_from_client() @@ -288,7 +294,7 @@ class HttpLayer(Layer): flow.live = False def handle_regular_mode_connect(self, request): - self.set_server((request.host, request.port), False, None) + self.set_server((request.host, request.port)) self.send_to_client(make_connect_response(request.httpversion)) layer = self.ctx.next_layer(self) layer() @@ -433,11 +439,10 @@ class HttpLayer(Layer): if flow.request.form_in == "authority": flow.request.scheme = "http" # pseudo value else: - flow.request.host = self.ctx.server_conn.address.host - flow.request.port = self.ctx.server_conn.address.port - flow.request.scheme = "https" if self.server_conn.tls_established else "http" + flow.request.host = self.__original_server_conn.address.host + flow.request.port = self.__original_server_conn.address.port + flow.request.scheme = "https" if self.__original_server_conn.tls_established else "http" - # TODO: Expose .set_server functionality to inline scripts request_reply = self.channel.ask("request", flow) if request_reply is None or request_reply == KILL: raise Kill() diff --git a/libmproxy/protocol2/layer.py b/libmproxy/protocol2/layer.py index 7cb765912..f72320ff7 100644 --- a/libmproxy/protocol2/layer.py +++ b/libmproxy/protocol2/layer.py @@ -112,7 +112,7 @@ class ServerConnectionMixin(object): self.server_conn.address = address self.connect() - def set_server(self, address, server_tls, sni, depth=1): + def set_server(self, address, server_tls=None, sni=None, depth=1): if depth == 1: if self.server_conn: self._disconnect() diff --git a/libmproxy/protocol2/tls.py b/libmproxy/protocol2/tls.py index 433dd65d9..b1b80034f 100644 --- a/libmproxy/protocol2/tls.py +++ b/libmproxy/protocol2/tls.py @@ -110,9 +110,9 @@ class TlsLayer(Layer): if self._server_tls and not self.server_conn.tls_established: self._establish_tls_with_server() - def set_server(self, address, server_tls, sni, depth=1): + def set_server(self, address, server_tls=None, sni=None, depth=1): self.ctx.set_server(address, server_tls, sni, depth) - if server_tls is not None: + if depth == 1 and server_tls is not None: self._sni_from_server_change = sni self._server_tls = server_tls diff --git a/test/test_server.py b/test/test_server.py index 529024f53..e9c40d1a0 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -1,6 +1,7 @@ import socket import time from OpenSSL import SSL +from netlib.tcp import Address import netlib.tutils from netlib import tcp, http, socks @@ -655,63 +656,67 @@ class MasterRedirectRequest(tservers.TestMaster): redirect_port = None # Set by TestRedirectRequest def handle_request(self, f): - request = f.request - if request.path == "/p/201": - addr = f.live.c.server_conn.address - assert f.live.change_server( - ("127.0.0.1", self.redirect_port), ssl=False) - assert not f.live.change_server( - ("127.0.0.1", self.redirect_port), ssl=False) - tutils.raises( - "SSL handshake error", - f.live.change_server, - ("127.0.0.1", - self.redirect_port), - ssl=True) - assert f.live.change_server(addr, ssl=False) - request.url = "http://127.0.0.1:%s/p/201" % self.redirect_port - tservers.TestMaster.handle_request(self, f) + if f.request.path == "/p/201": + + # This part should have no impact, but it should not cause any exceptions. + addr = f.live.server_conn.address + addr2 = Address(("127.0.0.1", self.redirect_port)) + f.live.set_server(addr2) + f.live.connect() + f.live.set_server(addr) + f.live.connect() + + # This is the actual redirection. + f.request.port = self.redirect_port + super(MasterRedirectRequest, self).handle_request(f) def handle_response(self, f): f.response.content = str(f.client_conn.address.port) f.response.headers[ "server-conn-id"] = [str(f.server_conn.source_address.port)] - tservers.TestMaster.handle_response(self, f) + super(MasterRedirectRequest, self).handle_response(f) class TestRedirectRequest(tservers.HTTPProxTest): masterclass = MasterRedirectRequest + ssl = True def test_redirect(self): + """ + Imagine a single HTTPS connection with three requests: + + 1. First request should pass through unmodified + 2. Second request will be redirected to a different host by an inline script + 3. Third request should pass through unmodified + + This test verifies that the original destination is restored for the third request. + """ self.master.redirect_port = self.server2.port p = self.pathoc() self.server.clear_log() self.server2.clear_log() - r1 = p.request("get:'%s/p/200'" % self.server.urlbase) + r1 = p.request("get:'/p/200'") assert r1.status_code == 200 assert self.server.last_log() assert not self.server2.last_log() self.server.clear_log() self.server2.clear_log() - r2 = p.request("get:'%s/p/201'" % self.server.urlbase) + r2 = p.request("get:'/p/201'") assert r2.status_code == 201 assert not self.server.last_log() assert self.server2.last_log() self.server.clear_log() self.server2.clear_log() - r3 = p.request("get:'%s/p/202'" % self.server.urlbase) + r3 = p.request("get:'/p/202'") assert r3.status_code == 202 assert self.server.last_log() assert not self.server2.last_log() assert r1.content == r2.content == r3.content - assert r1.headers.get_first( - "server-conn-id") == r3.headers.get_first("server-conn-id") - # Make sure that we actually use the same connection in this test case class MasterStreamRequest(tservers.TestMaster):