New SNI handling mechanism.
This commit is contained in:
parent
b077189dd5
commit
ba674ad551
|
@ -80,8 +80,7 @@ class ServerConnection(tcp.TCPClient):
|
|||
|
||||
def terminate(self):
|
||||
try:
|
||||
if not self.wfile.closed:
|
||||
self.wfile.flush()
|
||||
self.wfile.flush()
|
||||
self.connection.close()
|
||||
except IOError:
|
||||
pass
|
||||
|
@ -110,6 +109,27 @@ class RequestReplayThread(threading.Thread):
|
|||
self.channel.ask(err)
|
||||
|
||||
|
||||
class HandleSNI:
|
||||
def __init__(self, handler, client_conn, host, port, cert, key):
|
||||
self.handler, self.client_conn, self.host, self.port = handler, client_conn, host, port
|
||||
self.cert, self.key = cert, key
|
||||
|
||||
def __call__(self, connection):
|
||||
try:
|
||||
sn = connection.get_servername()
|
||||
if sn:
|
||||
self.handler.get_server_connection(self.client_conn, "https", self.host, self.port, sn)
|
||||
new_context = SSL.Context(SSL.TLSv1_METHOD)
|
||||
new_context.use_privatekey_file(self.key)
|
||||
new_context.use_certificate_file(self.cert)
|
||||
connection.set_context(new_context)
|
||||
self.handler.sni = sn.decode("utf8").encode("idna")
|
||||
# An unhandled exception in this method will core dump PyOpenSSL, so
|
||||
# make dang sure it doesn't happen.
|
||||
except Exception, e:
|
||||
pass
|
||||
|
||||
|
||||
class ProxyHandler(tcp.BaseHandler):
|
||||
def __init__(self, config, connection, client_address, server, channel, server_version):
|
||||
self.channel, self.server_version = channel, server_version
|
||||
|
@ -266,18 +286,15 @@ class ProxyHandler(tcp.BaseHandler):
|
|||
l = Log(msg)
|
||||
self.channel.tell(l)
|
||||
|
||||
def find_cert(self, host, port, sni):
|
||||
def find_cert(self, cc, host, port, sni):
|
||||
if self.config.certfile:
|
||||
return self.config.certfile
|
||||
else:
|
||||
sans = []
|
||||
if not self.config.no_upstream_cert:
|
||||
try:
|
||||
cert = certutils.get_remote_cert(host, port, sni)
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(502, "Unable to get remote cert: %s"%str(v))
|
||||
sans = cert.altnames
|
||||
host = cert.cn.decode("utf8").encode("idna")
|
||||
conn = self.get_server_connection(cc, "https", host, port, sni)
|
||||
sans = conn.cert.altnames
|
||||
host = conn.cert.cn.decode("utf8").encode("idna")
|
||||
ret = self.config.certstore.get_cert(host, sans, self.config.cacert)
|
||||
if not ret:
|
||||
raise ProxyError(502, "mitmproxy: Unable to generate dummy cert.")
|
||||
|
@ -292,11 +309,6 @@ class ProxyHandler(tcp.BaseHandler):
|
|||
line = fp.readline()
|
||||
return line
|
||||
|
||||
def handle_sni(self, conn):
|
||||
sn = conn.get_servername()
|
||||
if sn:
|
||||
self.sni = sn.decode("utf8").encode("idna")
|
||||
|
||||
def read_request_transparent(self, client_conn):
|
||||
orig = self.config.transparent_proxy["resolver"].original_addr(self.connection)
|
||||
if not orig:
|
||||
|
@ -304,9 +316,13 @@ class ProxyHandler(tcp.BaseHandler):
|
|||
host, port = orig
|
||||
if not self.ssl_established and (port in self.config.transparent_proxy["sslports"]):
|
||||
scheme = "https"
|
||||
certfile = self.find_cert(host, port, None)
|
||||
dummycert = self.find_cert(client_conn, host, port, host)
|
||||
try:
|
||||
self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
|
||||
sni = HandleSNI(
|
||||
self, client_conn, host, port,
|
||||
dummycert, self.config.certfile or self.config.cacert
|
||||
)
|
||||
self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni)
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(400, str(v))
|
||||
else:
|
||||
|
@ -346,9 +362,14 @@ class ProxyHandler(tcp.BaseHandler):
|
|||
'\r\n'
|
||||
)
|
||||
self.wfile.flush()
|
||||
certfile = self.find_cert(host, port, None)
|
||||
certfile = self.find_cert(client_conn, host, port, host)
|
||||
|
||||
sni = HandleSNI(
|
||||
self, client_conn, host, port,
|
||||
dummycert, self.config.certfile or self.config.cacert
|
||||
)
|
||||
try:
|
||||
self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert)
|
||||
self.convert_to_ssl(certfile, self.config.certfile or self.config.cacert, handle_sni=sni)
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(400, str(v))
|
||||
self.proxy_connect_state = (host, port, httpversion)
|
||||
|
|
|
@ -13,11 +13,7 @@ from libmproxy import flow, proxy
|
|||
for a 200 response.
|
||||
"""
|
||||
|
||||
class SanityMixin:
|
||||
def test_http(self):
|
||||
assert self.pathod("304").status_code == 304
|
||||
assert self.master.state.view
|
||||
|
||||
class CommonMixin:
|
||||
def test_large(self):
|
||||
assert len(self.pathod("200:b@50k").content) == 1024*50
|
||||
|
||||
|
@ -40,19 +36,23 @@ class SanityMixin:
|
|||
self.master.replay_request(l, block=True)
|
||||
assert l.error
|
||||
|
||||
def test_http(self):
|
||||
f = self.pathod("304")
|
||||
assert f.status_code == 304
|
||||
|
||||
class TestHTTP(tservers.HTTPProxTest, SanityMixin):
|
||||
def test_app(self):
|
||||
p = self.pathoc()
|
||||
ret = p.request("get:'http://testapp/'")
|
||||
assert ret[1] == 200
|
||||
assert ret[4] == "testapp"
|
||||
l = self.master.state.view[0]
|
||||
assert l.request.client_conn.address
|
||||
assert "host" in l.request.headers
|
||||
assert l.response.code == 304
|
||||
|
||||
|
||||
|
||||
class TestHTTP(tservers.HTTPProxTest, CommonMixin):
|
||||
def test_app_err(self):
|
||||
p = self.pathoc()
|
||||
ret = p.request("get:'http://errapp/'")
|
||||
assert ret[1] == 500
|
||||
assert "ValueError" in ret[4]
|
||||
assert ret.status_code == 500
|
||||
assert "ValueError" in ret.content
|
||||
|
||||
def test_invalid_http(self):
|
||||
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
|
||||
|
@ -71,16 +71,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin):
|
|||
def test_upstream_ssl_error(self):
|
||||
p = self.pathoc()
|
||||
ret = p.request("get:'https://localhost:%s/'"%self.server.port)
|
||||
assert ret[1] == 400
|
||||
|
||||
def test_http(self):
|
||||
f = self.pathod("304")
|
||||
assert f.status_code == 304
|
||||
|
||||
l = self.master.state.view[0]
|
||||
assert l.request.client_conn.address
|
||||
assert "host" in l.request.headers
|
||||
assert l.response.code == 304
|
||||
assert ret.status_code == 400
|
||||
|
||||
def test_connection_close(self):
|
||||
# Add a body, so we have a content-length header, which combined with
|
||||
|
@ -116,7 +107,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin):
|
|||
# within our read loop.
|
||||
with mock.patch("libmproxy.proxy.ProxyHandler.read_request") as m:
|
||||
m.side_effect = IOError("error!")
|
||||
tutils.raises("empty reply", self.pathod, "304")
|
||||
tutils.raises("server disconnect", self.pathod, "304")
|
||||
|
||||
def test_get_connection_switching(self):
|
||||
def switched(l):
|
||||
|
@ -132,30 +123,39 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin):
|
|||
def test_get_connection_err(self):
|
||||
p = self.pathoc()
|
||||
ret = p.request("get:'http://localhost:0'")
|
||||
assert ret[1] == 502
|
||||
assert ret.status_code == 502
|
||||
|
||||
|
||||
class TestHTTPS(tservers.HTTPProxTest, SanityMixin):
|
||||
class TestHTTPS(tservers.HTTPProxTest, CommonMixin):
|
||||
ssl = True
|
||||
clientcerts = True
|
||||
def test_clientcert(self):
|
||||
f = self.pathod("304")
|
||||
assert self.server.last_log()["request"]["clientcert"]["keyinfo"]
|
||||
|
||||
def test_sni(self):
|
||||
pass
|
||||
|
||||
class TestHTTPSCertfile(tservers.HTTPProxTest, SanityMixin):
|
||||
|
||||
class TestHTTPSCertfile(tservers.HTTPProxTest, CommonMixin):
|
||||
ssl = True
|
||||
certfile = True
|
||||
def test_certfile(self):
|
||||
assert self.pathod("304")
|
||||
|
||||
|
||||
class TestReverse(tservers.ReverseProxTest, SanityMixin):
|
||||
class TestReverse(tservers.ReverseProxTest, CommonMixin):
|
||||
reverse = True
|
||||
|
||||
|
||||
class TestTransparent(tservers.TransparentProxTest, SanityMixin):
|
||||
class TestTransparent(tservers.TransparentProxTest, CommonMixin):
|
||||
transparent = True
|
||||
ssl = False
|
||||
|
||||
|
||||
class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin):
|
||||
transparent = True
|
||||
ssl = True
|
||||
|
||||
|
||||
class TestProxy(tservers.HTTPProxTest):
|
||||
|
@ -232,7 +232,7 @@ class TestKillRequest(tservers.HTTPProxTest):
|
|||
masterclass = MasterKillRequest
|
||||
def test_kill(self):
|
||||
p = self.pathoc()
|
||||
tutils.raises("empty reply", self.pathod, "200")
|
||||
tutils.raises("server disconnect", self.pathod, "200")
|
||||
# Nothing should have hit the server
|
||||
assert not self.server.last_log()
|
||||
|
||||
|
@ -246,7 +246,7 @@ class TestKillResponse(tservers.HTTPProxTest):
|
|||
masterclass = MasterKillResponse
|
||||
def test_kill(self):
|
||||
p = self.pathoc()
|
||||
tutils.raises("empty reply", self.pathod, "200")
|
||||
tutils.raises("server disconnect", self.pathod, "200")
|
||||
# The server should have seen a request
|
||||
assert self.server.last_log()
|
||||
|
||||
|
|
|
@ -126,20 +126,21 @@ class HTTPProxTest(ProxTestBase):
|
|||
"""
|
||||
Returns a connected Pathoc instance.
|
||||
"""
|
||||
p = libpathod.pathoc.Pathoc("localhost", self.proxy.port)
|
||||
p = libpathod.pathoc.Pathoc("localhost", self.proxy.port, ssl=self.ssl)
|
||||
p.connect(connect_to)
|
||||
return p
|
||||
|
||||
def pathod(self, spec):
|
||||
"""
|
||||
Constructs a pathod request, with the appropriate base and proxy.
|
||||
Constructs a pathod GET request, with the appropriate base and proxy.
|
||||
"""
|
||||
return hurl.get(
|
||||
self.server.urlbase + "/p/" + spec,
|
||||
proxy=self.proxies,
|
||||
validate_cert=False,
|
||||
#debug=hurl.utils.stdout_debug
|
||||
)
|
||||
if self.ssl:
|
||||
p = self.pathoc(("127.0.0.1", self.server.port))
|
||||
q = "get:'/p/%s'"%spec
|
||||
else:
|
||||
p = self.pathoc()
|
||||
q = "get:'%s/p/%s'"%(self.server.urlbase, spec)
|
||||
return p.request(q)
|
||||
|
||||
|
||||
class TResolver:
|
||||
|
@ -155,9 +156,13 @@ class TransparentProxTest(ProxTestBase):
|
|||
@classmethod
|
||||
def get_proxy_config(cls):
|
||||
d = ProxTestBase.get_proxy_config()
|
||||
if cls.ssl:
|
||||
ports = [cls.server.port, cls.server2.port]
|
||||
else:
|
||||
ports = []
|
||||
d["transparent_proxy"] = dict(
|
||||
resolver = TResolver(cls.server.port),
|
||||
sslports = []
|
||||
sslports = ports
|
||||
)
|
||||
return d
|
||||
|
||||
|
@ -166,12 +171,20 @@ class TransparentProxTest(ProxTestBase):
|
|||
Constructs a pathod request, with the appropriate base and proxy.
|
||||
"""
|
||||
r = hurl.get(
|
||||
"http://127.0.0.1:%s"%self.proxy.port + "/p/" + spec,
|
||||
"%s://127.0.0.1:%s"%(self.scheme, self.proxy.port) + "/p/" + spec,
|
||||
validate_cert=False,
|
||||
#debug=hurl.utils.stdout_debug
|
||||
)
|
||||
return r
|
||||
|
||||
def pathoc(self, connect= None):
|
||||
"""
|
||||
Returns a connected Pathoc instance.
|
||||
"""
|
||||
p = libpathod.pathoc.Pathoc("localhost", self.proxy.port)
|
||||
p.connect(connect_to)
|
||||
return p
|
||||
|
||||
|
||||
class ReverseProxTest(ProxTestBase):
|
||||
ssl = None
|
||||
|
|
Loading…
Reference in New Issue