From f6253a80fff2ed3a6f7846e866469c8776f1254d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 4 Feb 2014 02:56:59 +0100 Subject: [PATCH] add priorities for the destination server address --- libmproxy/protocol/http.py | 36 +++++----- libmproxy/proxy.py | 133 +++++++++++++++++++++++++++---------- 2 files changed, 114 insertions(+), 55 deletions(-) diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 5faf78e05..8c44461eb 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -5,7 +5,7 @@ from netlib import http, tcp, http_status, odict from netlib.odict import ODict, ODictCaseless from . import ProtocolHandler, ConnectionTypeChange, KILL from .. import encoding, utils, version, filt, controller, stateobject -from ..proxy import ProxyError +from ..proxy import ProxyError, AddressPriority from ..flow import Flow, Error @@ -816,7 +816,7 @@ class HTTPHandler(ProtocolHandler): raise v def handle_flow(self): - flow = HTTPFlow(self.c.client_conn, self.c.server_conn, None, None, None) + flow = HTTPFlow(self.c.client_conn, self.c.server_conn) try: flow.request = HTTPRequest.from_stream(self.c.client_conn.rfile, body_size_limit=self.c.config.body_size_limit) @@ -831,9 +831,10 @@ class HTTPHandler(ProtocolHandler): flow.response = request_reply else: self.process_request(flow.request) + self.c.establish_server_connection() flow.response = self.get_response_from_server(flow.request) - self.c.log("response", [flow.response._assemble_response_line()]) + self.c.log("response", [flow.response._assemble_first_line()]) response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse", flow.response if LEGACY else flow) if response_reply is None or response_reply == KILL: @@ -853,16 +854,6 @@ class HTTPHandler(ProtocolHandler): flow.server_conn = self.c.server_conn - """ - FIXME: Remove state test - d = flow._get_state() - print d - flow._load_state(d) - print flow._get_state() - copy = HTTPFlow._from_state(d) - print copy._get_state() - """ - return True except (HttpAuthenticationError, http.HttpError, ProxyError, tcp.NetLibError), e: self.handle_error(e, flow) @@ -887,8 +878,10 @@ class HTTPHandler(ProtocolHandler): if flow: flow.error = Error(err) - self.c.channel.ask("error" if LEGACY else "httperror", - flow.error if LEGACY else flow) + if not (LEGACY and not flow.request) and not (LEGACY and flow.request and flow.response): + # no flows without request or with both request and response in legacy mode + self.c.channel.ask("error" if LEGACY else "httperror", + flow.error if LEGACY else flow) else: pass # FIXME: Is there any use case for persisting errors that occur outside of flows? @@ -923,6 +916,7 @@ class HTTPHandler(ProtocolHandler): This isn't particular beautiful code, but it isolates this rare edge-case from the protocol-agnostic ConnectionHandler """ + self.c.log("Received CONNECT request. Upgrading to SSL...") self.c.mode = "transparent" self.c.determine_conntype() self.c.establish_ssl(server=True, client=True) @@ -933,7 +927,7 @@ class HTTPHandler(ProtocolHandler): def reconnect_http_proxy(): self.c.log("Hooked reconnect function") - self.c.log("Hook: Run original redirect") + self.c.log("Hook: Run original reconnect") original_reconnect_func(no_ssl=True) self.c.log("Hook: Write CONNECT request to upstream proxy", [upstream_request._assemble_first_line()]) self.c.server_conn.wfile.write(upstream_request._assemble()) @@ -948,6 +942,7 @@ class HTTPHandler(ProtocolHandler): self.c.server_reconnect = reconnect_http_proxy + self.c.log("Upgrade to SSL completed.") raise ConnectionTypeChange def process_request(self, request): @@ -958,9 +953,10 @@ class HTTPHandler(ProtocolHandler): # If we have a CONNECT request, we might need to intercept if request.form_in == "authority": - directly_addressed_at_mitmproxy = (self.c.mode == "regular") and not self.c.config.forward_proxy + directly_addressed_at_mitmproxy = (self.c.mode == "regular" and not self.c.config.forward_proxy) if directly_addressed_at_mitmproxy: - self.c.establish_server_connection((request.host, request.port)) + self.c.set_server_address((request.host, request.port), AddressPriority.FROM_PROTOCOL) + self.c.establish_server_connection() self.c.client_conn.wfile.write( 'HTTP/1.1 200 Connection established\r\n' + ('Proxy-agent: %s\r\n' % self.c.server_version) + @@ -977,9 +973,7 @@ class HTTPHandler(ProtocolHandler): raise http.HttpError(400, "Invalid Request") if not self.c.config.forward_proxy: request.form_out = "origin" - if ((not self.c.server_conn) or - (self.c.server_conn.address != (request.host, request.port))): - self.c.establish_server_connection((request.host, request.port)) + self.c.set_server_address((request.host, request.port), AddressPriority.FROM_PROTOCOL) else: raise http.HttpError(400, "Invalid Request") diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index e3e40c7ba..4842a81f1 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -7,6 +7,23 @@ import utils, version, platform, controller, stateobject TRANSPARENT_SSL_PORTS = [443, 8443] +class AddressPriority(object): + """ + Enum that signifies the priority of the given address when choosing the destination host. + Higher is better (None < i) + """ + FORCE = 5 + """forward mode""" + MANUALLY_CHANGED = 4 + """user changed the target address in the ui""" + FROM_SETTINGS = 3 + """reverse proxy mode""" + FROM_CONNECTION = 2 + """derived from transparent resolver""" + FROM_PROTOCOL = 1 + """derived from protocol (e.g. absolute-form http requests)""" + + class ProxyError(Exception): def __init__(self, code, msg, headers=None): self.code, self.msg, self.headers = code, msg, headers @@ -189,6 +206,7 @@ class ConnectionHandler: self.close = False self.conntype = None self.sni = None + self.server_address_priority = None self.mode = "regular" if self.config.reverse_proxy: @@ -196,14 +214,6 @@ class ConnectionHandler: if self.config.transparent_proxy: self.mode = "transparent" - def del_server_connection(self): - if self.server_conn and self.server_conn.connection: - self.server_conn.finish() - self.log("serverdisconnect", ["%s:%s" % (self.server_conn.address.host, self.server_conn.address.port)]) - self.channel.tell("serverdisconnect", self) - self.server_conn = None - self.sni = None - def handle(self): self.log("clientconnect") self.channel.ask("clientconnect", self) @@ -214,20 +224,23 @@ class ConnectionHandler: try: # Can we already identify the target server and connect to it? server_address = None + address_priority = None if self.config.forward_proxy: server_address = self.config.forward_proxy[1:] - else: - if self.config.reverse_proxy: - server_address = self.config.reverse_proxy[1:] - elif self.config.transparent_proxy: - server_address = self.config.transparent_proxy["resolver"].original_addr( - self.client_conn.connection) - if not server_address: - raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") - self.log("transparent to %s:%s" % server_address) + address_priority = AddressPriority.FORCE + elif self.config.reverse_proxy: + server_address = self.config.reverse_proxy[1:] + address_priority = AddressPriority.FROM_SETTINGS + elif self.config.transparent_proxy: + server_address = self.config.transparent_proxy["resolver"].original_addr( + self.client_conn.connection) + if not server_address: + raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") + address_priority = AddressPriority.FROM_CONNECTION + self.log("transparent to %s:%s" % server_address) if server_address: - self.establish_server_connection(server_address) + self.set_server_address(server_address, address_priority) self._handle_ssl() while not self.close: @@ -252,53 +265,95 @@ class ConnectionHandler: def _handle_ssl(self): """ + Helper function of .handle() Check if we can already identify SSL connections. + If so, connect to the server and establish an SSL connection """ + client_ssl = False + server_ssl = False + if self.config.transparent_proxy: client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"]) elif self.config.reverse_proxy: client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https") # TODO: Make protocol generic (as with transparent proxies) # TODO: Add SSL-terminating capatbility (SSL -> mitmproxy -> plain and vice versa) - self.establish_ssl(client=client_ssl, server=server_ssl) + if client_ssl or server_ssl: + self.establish_server_connection() + self.establish_ssl(client=client_ssl, server=server_ssl) - def finish(self): - self.client_conn.finish() + def del_server_connection(self): + """ + Deletes an existing server connection. + """ + if self.server_conn and self.server_conn.connection: + self.server_conn.finish() + self.log("serverdisconnect", ["%s:%s" % (self.server_conn.address.host, self.server_conn.address.port)]) + self.channel.tell("serverdisconnect", self) + self.server_conn = None + self.server_address_priority = None + self.sni = None def determine_conntype(self): #TODO: Add ruleset to select correct protocol depending on mode/target port etc. self.conntype = "http" - def establish_server_connection(self, address): + def set_server_address(self, address, priority): """ - Establishes a new server connection to the given server - If there is already an existing server connection, it will be killed. + Sets a new server address with the given priority + @type priority: AddressPriority """ - self.del_server_connection() - self.server_conn = ServerConnection(address) + address = tcp.Address.wrap(address) + self.log("Set server address: %s:%s" % (address.host, address.port)) + if self.server_conn and (self.server_address_priority > priority): + self.log("Server address priority too low (is: %s, got: %s)" % (self.server_address_priority, priority)) + return + + self.address_priority = priority + + if self.server_conn and (self.server_conn.address == address): + self.log("Addresses match, skip.") + return + + server_conn = ServerConnection(address) + if self.server_conn and self.server_conn.connection: + self.del_server_connection() + self.server_conn = server_conn + self.establish_server_connection() + else: + self.server_conn = server_conn + + def establish_server_connection(self): + """ + Establishes a new server connection. + If there is already an existing server connection, the function returns immediately. + """ + if self.server_conn.connection: + return + self.log("serverconnect", ["%s:%s" % self.server_conn.address()[:2]]) + self.channel.tell("serverconnect", self) try: self.server_conn.connect() except tcp.NetLibError, v: raise ProxyError(502, v) - self.log("serverconnect", ["%s:%s" % address[:2]]) - self.channel.tell("serverconnect", self) def establish_ssl(self, client=False, server=False): """ Establishes SSL on the existing connection(s) to the server or the client, as specified by the parameters. If the target server is on the pass-through list, - the conntype attribute will be changed and no the SSL connection won't be wrapped. + the conntype attribute will be changed and the SSL connection won't be wrapped. A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening """ # TODO: Implement SSL pass-through handling and change conntype passthrough = [ "echo.websocket.org", - "174.129.224.73" # echo.websocket.org, transparent mode + "174.129.224.73" # echo.websocket.org, transparent mode ] if self.server_conn.address.host in passthrough or self.sni in passthrough: self.conntype = "tcp" return + # Logging if client or server: subs = [] if client: @@ -319,13 +374,21 @@ class ConnectionHandler: handle_sni=self.handle_sni) def server_reconnect(self, no_ssl=False): - had_ssl, sni = self.server_conn.ssl_established, self.sni + address = self.server_conn.address + had_ssl = self.server_conn.ssl_established + priority = self.server_address_priority + sni = self.sni self.log("(server reconnect follows)") - self.establish_server_connection(self.server_conn.address()) + self.del_server_connection() + self.set_server_address(address, priority) + self.establish_server_connection() if had_ssl and not no_ssl: self.sni = sni self.establish_ssl(server=True) + def finish(self): + self.client_conn.finish() + def log(self, msg, subs=()): msg = [ "%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg) @@ -363,6 +426,7 @@ class ConnectionHandler: sn = connection.get_servername() if sn and sn != self.sni: self.sni = sn.decode("utf8").encode("idna") + self.log("SNI received: %s" % self.sni) self.server_reconnect() # reconnect to upstream server with SNI # Now, change client context to reflect changed certificate: new_context = SSL.Context(SSL.TLSv1_METHOD) @@ -372,11 +436,12 @@ class ConnectionHandler: connection.set_context(new_context) # An unhandled exception in this method will core dump PyOpenSSL, so # make dang sure it doesn't happen. - except Exception, e: # pragma: no cover + except Exception, e: # pragma: no cover pass -class ProxyServerError(Exception): pass +class ProxyServerError(Exception): + pass class ProxyServer(tcp.TCPServer):