From 17f09aa0afe9695505b746c370e1c5b889c19058 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 28 Jan 2014 17:28:20 +0100 Subject: [PATCH] unify ipv4/ipv6 address handling --- .travis.yml | 4 ++-- libmproxy/protocol.py | 7 +++---- libmproxy/proxy.py | 49 +++++++++++++++++++------------------------ test/test_server.py | 4 ++-- 4 files changed, 28 insertions(+), 36 deletions(-) diff --git a/.travis.yml b/.travis.yml index eb1fc9a3b..d6a6c149c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,9 +5,9 @@ python: install: - "pip install coveralls --use-mirrors" - "pip install nose-cov --use-mirrors" - - "pip install --upgrade git+https://github.com/mitmproxy/netlib.git" + - "pip install --upgrade git+https://github.com/mitmproxy/netlib.git@tcp_proxy" - "pip install -r requirements.txt --use-mirrors" - - "pip install --upgrade git+https://github.com/mitmproxy/pathod.git" + - "pip install --upgrade git+https://github.com/mitmproxy/pathod.git@tcp_proxy" # command to run tests, e.g. python setup.py test script: - "nosetests --with-cov --cov-report term-missing" diff --git a/libmproxy/protocol.py b/libmproxy/protocol.py index dcc8b75e5..279ff0159 100644 --- a/libmproxy/protocol.py +++ b/libmproxy/protocol.py @@ -115,7 +115,7 @@ class HTTPResponse(HTTPMessage): class HTTPRequest(HTTPMessage): def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content, - timestamp_start, timestamp_end, form_out=None, ip=None): + timestamp_start, timestamp_end, form_out=None): self.form_in = form_in self.method = method self.scheme = scheme @@ -129,7 +129,6 @@ class HTTPRequest(HTTPMessage): self.timestamp_end = timestamp_end self.form_out = form_out or self.form_in - self.ip = ip # resolved ip address assert isinstance(headers, ODictCaseless) #FIXME: Compatibility Fix @@ -352,7 +351,7 @@ class HTTPHandler(ProtocolHandler): if request.form_in == "authority": directly_addressed_at_mitmproxy = (self.c.mode == "regular") and not self.c.config.forward_proxy if directly_addressed_at_mitmproxy: - self.c.establish_server_connection(request.host, request.port) + self.c.establish_server_connection((request.host, request.port)) self.c.client_conn.wfile.write( 'HTTP/1.1 200 Connection established\r\n' + ('Proxy-agent: %s\r\n' % self.c.server_version) + @@ -369,7 +368,7 @@ class HTTPHandler(ProtocolHandler): request.form_out = "origin" if ((not self.c.server_conn) or (self.c.server_conn.address != (request.host, request.port))): - self.c.establish_server_connection(request.host, request.port) + self.c.establish_server_connection((request.host, request.port)) else: raise http.HttpError(400, "Invalid Request") diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 96363a675..9e3e317b0 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -40,18 +40,13 @@ class ProxyConfig: class ClientConnection(tcp.BaseHandler): - def __init__(self, client_connection, host, port): - tcp.BaseHandler.__init__(self, client_connection) - self.host, self.port = host, port + def __init__(self, client_connection, address): + tcp.BaseHandler.__init__(self, client_connection, address) self.timestamp_start = utils.timestamp() self.timestamp_end = None self.timestamp_ssl_setup = None - @property - def address(self): - return self.host, self.port - def convert_to_ssl(self, *args, **kwargs): tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs) self.timestamp_ssl_setup = utils.timestamp() @@ -62,21 +57,19 @@ class ClientConnection(tcp.BaseHandler): class ServerConnection(tcp.TCPClient): - def __init__(self, host, port): - tcp.TCPClient.__init__(self, host, port) + def __init__(self, address): + tcp.TCPClient.__init__(self, address) + self.peername = None self.timestamp_start = None self.timestamp_end = None self.timestamp_tcp_setup = None self.timestamp_ssl_setup = None - @property - def address(self): - return self.host, self.port - def connect(self): self.timestamp_start = utils.timestamp() tcp.TCPClient.connect(self) + self.peername = self.connection.getpeername() self.timestamp_tcp_setup = utils.timestamp() def establish_ssl(self, clientcerts, sni): @@ -125,7 +118,7 @@ class RequestReplayThread(threading.Thread): class ConnectionHandler: def __init__(self, config, client_connection, client_address, server, channel, server_version): self.config = config - self.client_conn = ClientConnection(client_connection, *client_address) + self.client_conn = ClientConnection(client_connection, client_address) self.server_conn = None self.channel, self.server_version = channel, server_version @@ -140,9 +133,9 @@ class ConnectionHandler: self.mode = "transparent" def del_server_connection(self): - if self.server_conn and self.server_conn.connection: + if self.server_conn and self.server_conn.connection: self.server_conn.finish() - self.log("serverdisconnect", ["%s:%s" % (self.server_conn.host, self.server_conn.port)]) + self.log("serverdisconnect", ["%s:%s" % self.server_conn.address]) self.channel.tell("serverdisconnect", self) self.server_conn = None self.sni = None @@ -169,7 +162,7 @@ class ConnectionHandler: self.determine_conntype() if server_address: - self.establish_server_connection(*server_address) + self.establish_server_connection(server_address) self._handle_ssl() while not self.close: @@ -191,7 +184,7 @@ class ConnectionHandler: Check if we can already identify SSL connections. """ if self.config.transparent_proxy: - client_ssl = server_ssl = (self.server_conn.port in self.config.transparent_proxy["sslports"]) + client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"]) elif self.config.reverse_proxy: client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https") # TODO: Make protocol generic (as with transparent proxies) @@ -205,18 +198,18 @@ class ConnectionHandler: #TODO: Add ruleset to select correct protocol depending on mode/target port etc. self.conntype = "http" - def establish_server_connection(self, host, port): + def establish_server_connection(self, address): """ Establishes a new server connection to the given server If there is already an existing server connection, it will be killed. """ self.del_server_connection() - self.server_conn = ServerConnection(host, port) + self.server_conn = ServerConnection(address) try: self.server_conn.connect() except tcp.NetLibError, v: raise ProxyError(502, v) - self.log("serverconnect", ["%s:%s" % (host, port)]) + self.log("serverconnect", ["%s:%s" % address]) self.channel.tell("serverconnect", self) def establish_ssl(self, client=False, server=False): @@ -227,7 +220,7 @@ class ConnectionHandler: A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening """ # TODO: Implement SSL pass-through handling and change conntype - if self.server_conn.host == "ycombinator.com": + if self.server_conn.address.host == "ycombinator.com": self.conntype = "tcp" if server: @@ -244,14 +237,14 @@ class ConnectionHandler: def server_reconnect(self, no_ssl=False): self.log("server reconnect") had_ssl, sni = self.server_conn.ssl_established, self.sni - self.establish_server_connection(*self.server_conn.address) + self.establish_server_connection(self.server_conn.address) if had_ssl and not no_ssl: self.sni = sni self.establish_ssl(server=True) def log(self, msg, subs=()): msg = [ - "%s:%s: %s" % (self.client_conn.host, self.client_conn.port, msg) + "%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg) ] for i in subs: msg.append(" -> " + i) @@ -263,7 +256,7 @@ class ConnectionHandler: with open(self.config.certfile, "rb") as f: return certutils.SSLCert.from_pem(f.read()) else: - host = self.server_conn.host + host = self.server_conn.address.host sans = [] if not self.config.no_upstream_cert or not self.server_conn.ssl_established: upstream_cert = self.server_conn.cert @@ -307,14 +300,14 @@ class ProxyServer(tcp.TCPServer): allow_reuse_address = True bound = True - def __init__(self, config, port, address='', server_version=version.NAMEVERSION): + def __init__(self, config, port, host='', server_version=version.NAMEVERSION): """ Raises ProxyServerError if there's a startup problem. """ - self.config, self.port, self.address = config, port, address + self.config = config self.server_version = server_version try: - tcp.TCPServer.__init__(self, (address, port)) + tcp.TCPServer.__init__(self, (host, port)) except socket.error, v: raise ProxyServerError('Error starting proxy server: ' + v.strerror) self.channel = None diff --git a/test/test_server.py b/test/test_server.py index 646460ab7..ba152dc20 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -46,7 +46,7 @@ class CommonMixin: assert l.response.code == 304 def test_invalid_http(self): - t = tcp.TCPClient("127.0.0.1", self.proxy.port) + t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port)) t.connect() t.wfile.write("invalid\r\n\r\n") t.wfile.flush() @@ -70,7 +70,7 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin): assert "ValueError" in ret.content def test_invalid_connect(self): - t = tcp.TCPClient("127.0.0.1", self.proxy.port) + t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port)) t.connect() t.wfile.write("CONNECT invalid\n\n") t.wfile.flush()