New SNI handling mechanism.

This commit is contained in:
Aldo Cortesi 2013-03-01 09:05:39 +13:00
parent b077189dd5
commit ba674ad551
3 changed files with 93 additions and 59 deletions

View File

@ -80,8 +80,7 @@ class ServerConnection(tcp.TCPClient):
def terminate(self):
try:
if not self.wfile.closed:
self.wfile.flush()
self.wfile.flush()
self.connection.close()
except IOError:
pass
@ -110,6 +109,27 @@ class RequestReplayThread(threading.Thread):
self.channel.ask(err)
class HandleSNI:
def __init__(self, handler, client_conn, host, port, cert, key):
self.handler, self.client_conn, self.host, self.port = handler, client_conn, host, port
self.cert, self.key = cert, key
def __call__(self, connection):
try:
sn = connection.get_servername()
if sn:
self.handler.get_server_connection(self.client_conn, "https", self.host, self.port, sn)
new_context = SSL.Context(SSL.TLSv1_METHOD)
new_context.use_privatekey_file(self.key)
new_context.use_certificate_file(self.cert)
connection.set_context(new_context)
self.handler.sni = sn.decode("utf8").encode("idna")
# An unhandled exception in this method will core dump PyOpenSSL, so
# make dang sure it doesn't happen.
except Exception, e:
pass
class ProxyHandler(tcp.BaseHandler):
def __init__(self, config, connection, client_address, server, channel, server_version):
self.channel, self.server_version = channel, server_version
@ -266,18 +286,15 @@ class ProxyHandler(tcp.BaseHandler):
l = Log(msg)
self.channel.tell(l)
def find_cert(self, host, port, sni):
def find_cert(self, cc, host, port, sni):
if self.config.certfile:
return self.config.certfile
else:
sans = []
if not self.config.no_upstream_cert:
try:
cert = certutils.get_remote_cert(host, port, sni)
except tcp.NetLibError, v:
raise ProxyError(502, "Unable to get remote cert: %s"%str(v))
sans = cert.altnames
host = cert.cn.decode("utf8").encode("idna")
conn = self.get_server_connection(cc, "https", host, port, sni)
sans = conn.cert.altnames
host = conn.cert.cn.decode("utf8").encode("idna")
ret = self.config.certstore.get_cert(host, sans, self.config.cacert)
if not ret:
raise ProxyError(502, "mitmproxy: Unable to generate dummy cert.")
@ -292,11 +309,6 @@ class ProxyHandler(tcp.BaseHandler):
line = fp.readline()
return line
def handle_sni(self, conn):
sn = conn.get_servername()
if sn:
self.sni = sn.decode("utf8").encode("idna")
def read_request_transparent(self, client_conn):
orig = self.config.transparent_proxy["resolver"].original_addr(self.connection)
if not orig:
@ -304,9 +316,13 @@ class ProxyHandler(tcp.BaseHandler):
host, port = orig
if not self.ssl_established and (port in self.config.transparent_proxy["sslports"]):
scheme = "https"
certfile = self.find_cert(host, port, None)
dummycert = self.find_cert(client_conn, host, port, host)
try:
self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
sni = HandleSNI(
self, client_conn, host, port,
dummycert, self.config.certfile or self.config.cacert
)
self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni)
except tcp.NetLibError, v:
raise ProxyError(400, str(v))
else:
@ -346,9 +362,14 @@ class ProxyHandler(tcp.BaseHandler):
'\r\n'
)
self.wfile.flush()
certfile = self.find_cert(host, port, None)
certfile = self.find_cert(client_conn, host, port, host)
sni = HandleSNI(
self, client_conn, host, port,
dummycert, self.config.certfile or self.config.cacert
)
try:
self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert, handle_sni=sni)
except tcp.NetLibError, v:
raise ProxyError(400, str(v))
self.proxy_connect_state = (host, port, httpversion)

View File

@ -13,11 +13,7 @@ from libmproxy import flow, proxy
for a 200 response.
"""
class SanityMixin:
def test_http(self):
assert self.pathod("304").status_code == 304
assert self.master.state.view
class CommonMixin:
def test_large(self):
assert len(self.pathod("200:b@50k").content) == 1024*50
@ -40,19 +36,23 @@ class SanityMixin:
self.master.replay_request(l, block=True)
assert l.error
def test_http(self):
f = self.pathod("304")
assert f.status_code == 304
class TestHTTP(tservers.HTTPProxTest, SanityMixin):
def test_app(self):
p = self.pathoc()
ret = p.request("get:'http://testapp/'")
assert ret[1] == 200
assert ret[4] == "testapp"
l = self.master.state.view[0]
assert l.request.client_conn.address
assert "host" in l.request.headers
assert l.response.code == 304
class TestHTTP(tservers.HTTPProxTest, CommonMixin):
def test_app_err(self):
p = self.pathoc()
ret = p.request("get:'http://errapp/'")
assert ret[1] == 500
assert "ValueError" in ret[4]
assert ret.status_code == 500
assert "ValueError" in ret.content
def test_invalid_http(self):
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
@ -71,16 +71,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin):
def test_upstream_ssl_error(self):
p = self.pathoc()
ret = p.request("get:'https://localhost:%s/'"%self.server.port)
assert ret[1] == 400
def test_http(self):
f = self.pathod("304")
assert f.status_code == 304
l = self.master.state.view[0]
assert l.request.client_conn.address
assert "host" in l.request.headers
assert l.response.code == 304
assert ret.status_code == 400
def test_connection_close(self):
# Add a body, so we have a content-length header, which combined with
@ -116,7 +107,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin):
# within our read loop.
with mock.patch("libmproxy.proxy.ProxyHandler.read_request") as m:
m.side_effect = IOError("error!")
tutils.raises("empty reply", self.pathod, "304")
tutils.raises("server disconnect", self.pathod, "304")
def test_get_connection_switching(self):
def switched(l):
@ -132,30 +123,39 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin):
def test_get_connection_err(self):
p = self.pathoc()
ret = p.request("get:'http://localhost:0'")
assert ret[1] == 502
assert ret.status_code == 502
class TestHTTPS(tservers.HTTPProxTest, SanityMixin):
class TestHTTPS(tservers.HTTPProxTest, CommonMixin):
ssl = True
clientcerts = True
def test_clientcert(self):
f = self.pathod("304")
assert self.server.last_log()["request"]["clientcert"]["keyinfo"]
def test_sni(self):
pass
class TestHTTPSCertfile(tservers.HTTPProxTest, SanityMixin):
class TestHTTPSCertfile(tservers.HTTPProxTest, CommonMixin):
ssl = True
certfile = True
def test_certfile(self):
assert self.pathod("304")
class TestReverse(tservers.ReverseProxTest, SanityMixin):
class TestReverse(tservers.ReverseProxTest, CommonMixin):
reverse = True
class TestTransparent(tservers.TransparentProxTest, SanityMixin):
class TestTransparent(tservers.TransparentProxTest, CommonMixin):
transparent = True
ssl = False
class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin):
transparent = True
ssl = True
class TestProxy(tservers.HTTPProxTest):
@ -232,7 +232,7 @@ class TestKillRequest(tservers.HTTPProxTest):
masterclass = MasterKillRequest
def test_kill(self):
p = self.pathoc()
tutils.raises("empty reply", self.pathod, "200")
tutils.raises("server disconnect", self.pathod, "200")
# Nothing should have hit the server
assert not self.server.last_log()
@ -246,7 +246,7 @@ class TestKillResponse(tservers.HTTPProxTest):
masterclass = MasterKillResponse
def test_kill(self):
p = self.pathoc()
tutils.raises("empty reply", self.pathod, "200")
tutils.raises("server disconnect", self.pathod, "200")
# The server should have seen a request
assert self.server.last_log()

View File

@ -126,20 +126,21 @@ class HTTPProxTest(ProxTestBase):
"""
Returns a connected Pathoc instance.
"""
p = libpathod.pathoc.Pathoc("localhost", self.proxy.port)
p = libpathod.pathoc.Pathoc("localhost", self.proxy.port, ssl=self.ssl)
p.connect(connect_to)
return p
def pathod(self, spec):
"""
Constructs a pathod request, with the appropriate base and proxy.
Constructs a pathod GET request, with the appropriate base and proxy.
"""
return hurl.get(
self.server.urlbase + "/p/" + spec,
proxy=self.proxies,
validate_cert=False,
#debug=hurl.utils.stdout_debug
)
if self.ssl:
p = self.pathoc(("127.0.0.1", self.server.port))
q = "get:'/p/%s'"%spec
else:
p = self.pathoc()
q = "get:'%s/p/%s'"%(self.server.urlbase, spec)
return p.request(q)
class TResolver:
@ -155,9 +156,13 @@ class TransparentProxTest(ProxTestBase):
@classmethod
def get_proxy_config(cls):
d = ProxTestBase.get_proxy_config()
if cls.ssl:
ports = [cls.server.port, cls.server2.port]
else:
ports = []
d["transparent_proxy"] = dict(
resolver = TResolver(cls.server.port),
sslports = []
sslports = ports
)
return d
@ -166,12 +171,20 @@ class TransparentProxTest(ProxTestBase):
Constructs a pathod request, with the appropriate base and proxy.
"""
r = hurl.get(
"http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec,
"%s://127.0.0.1:%s"%(self.scheme, self.proxy.port) + "/p/" + spec,
validate_cert=False,
#debug=hurl.utils.stdout_debug
)
return r
def pathoc(self, connect= None):
"""
Returns a connected Pathoc instance.
"""
p = libpathod.pathoc.Pathoc("localhost", self.proxy.port)
p.connect(connect_to)
return p
class ReverseProxTest(ProxTestBase):
ssl = None