Add SNI.
This commit is contained in:
parent
ea457fac2e
commit
ccf2603ddc
|
@ -53,11 +53,13 @@ class TCPClient:
|
|||
self.connection, self.rfile, self.wfile = None, None, None
|
||||
self.cert = None
|
||||
|
||||
def convert_to_ssl(self, clientcert=None):
|
||||
def convert_to_ssl(self, clientcert=None, sni=None):
|
||||
context = SSL.Context(SSL.SSLv23_METHOD)
|
||||
if clientcert:
|
||||
context.use_certificate_file(self.clientcert)
|
||||
self.connection = SSL.Connection(context, self.connection)
|
||||
if sni:
|
||||
self.connection.set_tlsext_host_name(sni)
|
||||
self.connection.set_connect_state()
|
||||
self.connection.do_handshake()
|
||||
self.cert = self.connection.get_peer_certificate()
|
||||
|
@ -92,10 +94,12 @@ class BaseHandler:
|
|||
|
||||
def convert_to_ssl(self, cert, key):
|
||||
ctx = SSL.Context(SSL.SSLv23_METHOD)
|
||||
ctx.set_tlsext_servername_callback(self.handle_sni)
|
||||
ctx.use_privatekey_file(key)
|
||||
ctx.use_certificate_file(cert)
|
||||
self.connection = SSL.Connection(ctx, self.connection)
|
||||
self.connection.set_accept_state()
|
||||
# SNI callback happens during do_handshake()
|
||||
self.connection.do_handshake()
|
||||
self.rfile = FileLike(self.connection)
|
||||
self.wfile = FileLike(self.connection)
|
||||
|
@ -111,6 +115,23 @@ class BaseHandler:
|
|||
except IOError: # pragma: no cover
|
||||
pass
|
||||
|
||||
def handle_sni(self, connection):
|
||||
"""
|
||||
Called if the client has given a server name indication.
|
||||
|
||||
Server name can be retrieved like this:
|
||||
|
||||
connection.get_servername()
|
||||
|
||||
And you can specify the connection keys as follows:
|
||||
|
||||
new_context = Context(TLSv1_METHOD)
|
||||
new_context.use_privatekey(key)
|
||||
new_context.use_certificate(cert)
|
||||
connection.set_context(new_context)
|
||||
"""
|
||||
pass
|
||||
|
||||
def handle(self): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -25,7 +25,21 @@ class ServerTestBase:
|
|||
cls.server.shutdown()
|
||||
|
||||
|
||||
class SNIHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(self.sni)
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class EchoHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def handle(self):
|
||||
v = self.rfile.readline()
|
||||
if v.startswith("echo"):
|
||||
|
@ -90,13 +104,28 @@ class TestServerSSL(ServerTestBase):
|
|||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
c.convert_to_ssl(sni="foo.com")
|
||||
testval = "echo!\n"
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
|
||||
class TestSNI(ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
cls.q = Queue.Queue()
|
||||
s = TServer(("127.0.0.1", 0), True, cls.q, SNIHandler)
|
||||
cls.port = s.port
|
||||
return s
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c.connect()
|
||||
c.convert_to_ssl(sni="foo.com")
|
||||
assert c.rfile.readline() == "foo.com"
|
||||
|
||||
|
||||
class TestSSLDisconnect(ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
|
|
Loading…
Reference in New Issue