diff --git a/netlib/tcp.py b/netlib/tcp.py index 276d3162c..c8ffefdf0 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -53,11 +53,13 @@ class TCPClient: self.connection, self.rfile, self.wfile = None, None, None self.cert = None - def convert_to_ssl(self, clientcert=None): + def convert_to_ssl(self, clientcert=None, sni=None): context = SSL.Context(SSL.SSLv23_METHOD) if clientcert: context.use_certificate_file(self.clientcert) self.connection = SSL.Connection(context, self.connection) + if sni: + self.connection.set_tlsext_host_name(sni) self.connection.set_connect_state() self.connection.do_handshake() self.cert = self.connection.get_peer_certificate() @@ -92,10 +94,12 @@ class BaseHandler: def convert_to_ssl(self, cert, key): ctx = SSL.Context(SSL.SSLv23_METHOD) + ctx.set_tlsext_servername_callback(self.handle_sni) ctx.use_privatekey_file(key) ctx.use_certificate_file(cert) self.connection = SSL.Connection(ctx, self.connection) self.connection.set_accept_state() + # SNI callback happens during do_handshake() self.connection.do_handshake() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) @@ -111,6 +115,23 @@ class BaseHandler: except IOError: # pragma: no cover pass + def handle_sni(self, connection): + """ + Called if the client has given a server name indication. + + Server name can be retrieved like this: + + connection.get_servername() + + And you can specify the connection keys as follows: + + new_context = Context(TLSv1_METHOD) + new_context.use_privatekey(key) + new_context.use_certificate(cert) + connection.set_context(new_context) + """ + pass + def handle(self): # pragma: no cover raise NotImplementedError diff --git a/test/test_tcp.py b/test/test_tcp.py index a81632e79..a2ee5e368 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -25,7 +25,21 @@ class ServerTestBase: cls.server.shutdown() +class SNIHandler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + + def handle(self): + self.wfile.write(self.sni) + self.wfile.flush() + + class EchoHandler(tcp.BaseHandler): + sni = None + def handle_sni(self, connection): + self.sni = connection.get_servername() + def handle(self): v = self.rfile.readline() if v.startswith("echo"): @@ -90,13 +104,28 @@ class TestServerSSL(ServerTestBase): def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) c.connect() - c.convert_to_ssl() + c.convert_to_ssl(sni="foo.com") testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval +class TestSNI(ServerTestBase): + @classmethod + def makeserver(cls): + cls.q = Queue.Queue() + s = TServer(("127.0.0.1", 0), True, cls.q, SNIHandler) + cls.port = s.port + return s + + def test_echo(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.convert_to_ssl(sni="foo.com") + assert c.rfile.readline() == "foo.com" + + class TestSSLDisconnect(ServerTestBase): @classmethod def makeserver(cls):