diff --git a/netlib/tcp.py b/netlib/tcp.py index 3c5c89b7c..276d3162c 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -59,6 +59,7 @@ class TCPClient: context.use_certificate_file(self.clientcert) self.connection = SSL.Connection(context, self.connection) self.connection.set_connect_state() + self.connection.do_handshake() self.cert = self.connection.get_peer_certificate() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) @@ -95,6 +96,7 @@ class BaseHandler: ctx.use_certificate_file(cert) self.connection = SSL.Connection(ctx, self.connection) self.connection.set_accept_state() + self.connection.do_handshake() self.rfile = FileLike(self.connection) self.wfile = FileLike(self.connection) diff --git a/test/test_tcp.py b/test/test_tcp.py index 26286bc48..a81632e79 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -25,13 +25,8 @@ class ServerTestBase: cls.server.shutdown() -class THandler(tcp.BaseHandler): +class EchoHandler(tcp.BaseHandler): def handle(self): - if self.server.ssl: - self.convert_to_ssl( - tutils.test_data.path("data/server.crt"), - tutils.test_data.path("data/server.key"), - ) v = self.rfile.readline() if v.startswith("echo"): self.wfile.write(v) @@ -40,13 +35,24 @@ class THandler(tcp.BaseHandler): self.wfile.flush() +class DisconnectHandler(tcp.BaseHandler): + def handle(self): + self.finish() + + class TServer(tcp.TCPServer): - def __init__(self, addr, ssl, q): + def __init__(self, addr, ssl, q, handler): tcp.TCPServer.__init__(self, addr) self.ssl, self.q = ssl, q + self.handler = handler def handle_connection(self, request, client_address): - h = THandler(request, client_address, self) + h = self.handler(request, client_address, self) + if self.ssl: + h.convert_to_ssl( + tutils.test_data.path("data/server.crt"), + tutils.test_data.path("data/server.key"), + ) h.handle() h.finish() @@ -60,7 +66,7 @@ class TestServer(ServerTestBase): @classmethod def makeserver(cls): cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), False, cls.q) + s = TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) cls.port = s.port return s @@ -77,7 +83,7 @@ class TestServerSSL(ServerTestBase): @classmethod def makeserver(cls): cls.q = Queue.Queue() - s = TServer(("127.0.0.1", 0), True, cls.q) + s = TServer(("127.0.0.1", 0), True, cls.q, EchoHandler) cls.port = s.port return s @@ -91,6 +97,22 @@ class TestServerSSL(ServerTestBase): assert c.rfile.readline() == testval +class TestSSLDisconnect(ServerTestBase): + @classmethod + def makeserver(cls): + cls.q = Queue.Queue() + s = TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler) + cls.port = s.port + return s + + def test_echo(self): + c = tcp.TCPClient("127.0.0.1", self.port) + c.connect() + c.convert_to_ssl() + # Excercise SSL.ZeroReturnError + c.rfile.read(10) + + class TestTCPClient: def test_conerr(self): c = tcp.TCPClient("127.0.0.1", 0)