From ba674ad5514c5f30315fc688a07fdac634d94dfc Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 1 Mar 2013 09:05:39 +1300 Subject: [PATCH] New SNI handling mechanism. --- libmproxy/proxy.py | 57 ++++++++++++++++++++++++++++------------- test/test_server.py | 62 ++++++++++++++++++++++----------------------- test/tservers.py | 33 ++++++++++++++++-------- 3 files changed, 93 insertions(+), 59 deletions(-) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 7c229064d..c9ceb8de5 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -80,8 +80,7 @@ class ServerConnection(tcp.TCPClient): def terminate(self): try: - if not self.wfile.closed: - self.wfile.flush() + self.wfile.flush() self.connection.close() except IOError: pass @@ -110,6 +109,27 @@ class RequestReplayThread(threading.Thread): self.channel.ask(err) +class HandleSNI: + def __init__(self, handler, client_conn, host, port, cert, key): + self.handler, self.client_conn, self.host, self.port = handler, client_conn, host, port + self.cert, self.key = cert, key + + def __call__(self, connection): + try: + sn = connection.get_servername() + if sn: + self.handler.get_server_connection(self.client_conn, "https", self.host, self.port, sn) + new_context = SSL.Context(SSL.TLSv1_METHOD) + new_context.use_privatekey_file(self.key) + new_context.use_certificate_file(self.cert) + connection.set_context(new_context) + self.handler.sni = sn.decode("utf8").encode("idna") + # An unhandled exception in this method will core dump PyOpenSSL, so + # make dang sure it doesn't happen. + except Exception, e: + pass + + class ProxyHandler(tcp.BaseHandler): def __init__(self, config, connection, client_address, server, channel, server_version): self.channel, self.server_version = channel, server_version @@ -266,18 +286,15 @@ class ProxyHandler(tcp.BaseHandler): l = Log(msg) self.channel.tell(l) - def find_cert(self, host, port, sni): + def find_cert(self, cc, host, port, sni): if self.config.certfile: return self.config.certfile else: sans = [] if not self.config.no_upstream_cert: - try: - cert = certutils.get_remote_cert(host, port, sni) - except tcp.NetLibError, v: - raise ProxyError(502, "Unable to get remote cert: %s"%str(v)) - sans = cert.altnames - host = cert.cn.decode("utf8").encode("idna") + conn = self.get_server_connection(cc, "https", host, port, sni) + sans = conn.cert.altnames + host = conn.cert.cn.decode("utf8").encode("idna") ret = self.config.certstore.get_cert(host, sans, self.config.cacert) if not ret: raise ProxyError(502, "mitmproxy: Unable to generate dummy cert.") @@ -292,11 +309,6 @@ class ProxyHandler(tcp.BaseHandler): line = fp.readline() return line - def handle_sni(self, conn): - sn = conn.get_servername() - if sn: - self.sni = sn.decode("utf8").encode("idna") - def read_request_transparent(self, client_conn): orig = self.config.transparent_proxy["resolver"].original_addr(self.connection) if not orig: @@ -304,9 +316,13 @@ class ProxyHandler(tcp.BaseHandler): host, port = orig if not self.ssl_established and (port in self.config.transparent_proxy["sslports"]): scheme = "https" - certfile = self.find_cert(host, port, None) + dummycert = self.find_cert(client_conn, host, port, host) try: - self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert) + sni = HandleSNI( + self, client_conn, host, port, + dummycert, self.config.certfile or self.config.cacert + ) + self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni) except tcp.NetLibError, v: raise ProxyError(400, str(v)) else: @@ -346,9 +362,14 @@ class ProxyHandler(tcp.BaseHandler): '\r\n' ) self.wfile.flush() - certfile = self.find_cert(host, port, None) + certfile = self.find_cert(client_conn, host, port, host) + + sni = HandleSNI( + self, client_conn, host, port, + dummycert, self.config.certfile or self.config.cacert + ) try: - self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert) + self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert, handle_sni=sni) except tcp.NetLibError, v: raise ProxyError(400, str(v)) self.proxy_connect_state = (host, port, httpversion) diff --git a/test/test_server.py b/test/test_server.py index 034fab418..466c0f946 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -13,11 +13,7 @@ from libmproxy import flow, proxy for a 200 response. """ -class SanityMixin: - def test_http(self): - assert self.pathod("304").status_code == 304 - assert self.master.state.view - +class CommonMixin: def test_large(self): assert len(self.pathod("200:b@50k").content) == 1024*50 @@ -40,19 +36,23 @@ class SanityMixin: self.master.replay_request(l, block=True) assert l.error + def test_http(self): + f = self.pathod("304") + assert f.status_code == 304 -class TestHTTP(tservers.HTTPProxTest, SanityMixin): - def test_app(self): - p = self.pathoc() - ret = p.request("get:'http://testapp/'") - assert ret[1] == 200 - assert ret[4] == "testapp" + l = self.master.state.view[0] + assert l.request.client_conn.address + assert "host" in l.request.headers + assert l.response.code == 304 + + +class TestHTTP(tservers.HTTPProxTest, CommonMixin): def test_app_err(self): p = self.pathoc() ret = p.request("get:'http://errapp/'") - assert ret[1] == 500 - assert "ValueError" in ret[4] + assert ret.status_code == 500 + assert "ValueError" in ret.content def test_invalid_http(self): t = tcp.TCPClient("127.0.0.1", self.proxy.port) @@ -71,16 +71,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): def test_upstream_ssl_error(self): p = self.pathoc() ret = p.request("get:'https://localhost:%s/'"%self.server.port) - assert ret[1] == 400 - - def test_http(self): - f = self.pathod("304") - assert f.status_code == 304 - - l = self.master.state.view[0] - assert l.request.client_conn.address - assert "host" in l.request.headers - assert l.response.code == 304 + assert ret.status_code == 400 def test_connection_close(self): # Add a body, so we have a content-length header, which combined with @@ -116,7 +107,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): # within our read loop. with mock.patch("libmproxy.proxy.ProxyHandler.read_request") as m: m.side_effect = IOError("error!") - tutils.raises("empty reply", self.pathod, "304") + tutils.raises("server disconnect", self.pathod, "304") def test_get_connection_switching(self): def switched(l): @@ -132,30 +123,39 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin): def test_get_connection_err(self): p = self.pathoc() ret = p.request("get:'http://localhost:0'") - assert ret[1] == 502 + assert ret.status_code == 502 -class TestHTTPS(tservers.HTTPProxTest, SanityMixin): +class TestHTTPS(tservers.HTTPProxTest, CommonMixin): ssl = True clientcerts = True def test_clientcert(self): f = self.pathod("304") assert self.server.last_log()["request"]["clientcert"]["keyinfo"] + def test_sni(self): + pass -class TestHTTPSCertfile(tservers.HTTPProxTest, SanityMixin): + +class TestHTTPSCertfile(tservers.HTTPProxTest, CommonMixin): ssl = True certfile = True def test_certfile(self): assert self.pathod("304") -class TestReverse(tservers.ReverseProxTest, SanityMixin): +class TestReverse(tservers.ReverseProxTest, CommonMixin): reverse = True -class TestTransparent(tservers.TransparentProxTest, SanityMixin): +class TestTransparent(tservers.TransparentProxTest, CommonMixin): transparent = True + ssl = False + + +class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin): + transparent = True + ssl = True class TestProxy(tservers.HTTPProxTest): @@ -232,7 +232,7 @@ class TestKillRequest(tservers.HTTPProxTest): masterclass = MasterKillRequest def test_kill(self): p = self.pathoc() - tutils.raises("empty reply", self.pathod, "200") + tutils.raises("server disconnect", self.pathod, "200") # Nothing should have hit the server assert not self.server.last_log() @@ -246,7 +246,7 @@ class TestKillResponse(tservers.HTTPProxTest): masterclass = MasterKillResponse def test_kill(self): p = self.pathoc() - tutils.raises("empty reply", self.pathod, "200") + tutils.raises("server disconnect", self.pathod, "200") # The server should have seen a request assert self.server.last_log() diff --git a/test/tservers.py b/test/tservers.py index 998ad6c60..c8bc7100e 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -126,20 +126,21 @@ class HTTPProxTest(ProxTestBase): """ Returns a connected Pathoc instance. """ - p = libpathod.pathoc.Pathoc("localhost", self.proxy.port) + p = libpathod.pathoc.Pathoc("localhost", self.proxy.port, ssl=self.ssl) p.connect(connect_to) return p def pathod(self, spec): """ - Constructs a pathod request, with the appropriate base and proxy. + Constructs a pathod GET request, with the appropriate base and proxy. """ - return hurl.get( - self.server.urlbase + "/p/" + spec, - proxy=self.proxies, - validate_cert=False, - #debug=hurl.utils.stdout_debug - ) + if self.ssl: + p = self.pathoc(("127.0.0.1", self.server.port)) + q = "get:'/p/%s'"%spec + else: + p = self.pathoc() + q = "get:'%s/p/%s'"%(self.server.urlbase, spec) + return p.request(q) class TResolver: @@ -155,9 +156,13 @@ class TransparentProxTest(ProxTestBase): @classmethod def get_proxy_config(cls): d = ProxTestBase.get_proxy_config() + if cls.ssl: + ports = [cls.server.port, cls.server2.port] + else: + ports = [] d["transparent_proxy"] = dict( resolver = TResolver(cls.server.port), - sslports = [] + sslports = ports ) return d @@ -166,12 +171,20 @@ class TransparentProxTest(ProxTestBase): Constructs a pathod request, with the appropriate base and proxy. """ r = hurl.get( - "http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec, + "%s://127.0.0.1:%s"%(self.scheme, self.proxy.port) + "/p/" + spec, validate_cert=False, #debug=hurl.utils.stdout_debug ) return r + def pathoc(self, connect= None): + """ + Returns a connected Pathoc instance. + """ + p = libpathod.pathoc.Pathoc("localhost", self.proxy.port) + p.connect(connect_to) + return p + class ReverseProxTest(ProxTestBase): ssl = None