diff --git a/netlib/tcp.py b/netlib/tcp.py index 007cf3a56..25e83e075 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -48,11 +48,10 @@ class FileLike: class TCPClient: - def __init__(self, ssl, host, port, clientcert, sni): - self.ssl, self.host, self.port, self.clientcert, self.sni = ssl, host, port, clientcert, sni + def __init__(self, ssl, host, port, clientcert): + self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert self.connection, self.rfile, self.wfile = None, None, None self.cert = None - self.connect() def connect(self): try: @@ -75,6 +74,9 @@ class TCPClient: class BaseHandler: + """ + The instantiator is expected to call the handle() and finish() methods. + """ rbufsize = -1 wbufsize = 0 def __init__(self, connection, client_address, server): @@ -84,8 +86,6 @@ class BaseHandler: self.client_address = client_address self.server = server - self.handle() - self.finish() def convert_to_ssl(self, cert, key): ctx = SSL.Context(SSL.SSLv23_METHOD) diff --git a/test/test_tcp.py b/test/test_tcp.py index 9aebb2f00..1bad9a040 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -46,7 +46,9 @@ class TServer(tcp.TCPServer): self.ssl, self.q = ssl, q def handle_connection(self, request, client_address): - THandler(request, client_address, self) + h = THandler(request, client_address, self) + h.handle() + h.finish() def handle_error(self, request, client_address): s = cStringIO.StringIO() @@ -64,7 +66,8 @@ class TestServer(ServerTestBase): def test_echo(self): testval = "echo!\n" - c = tcp.TCPClient(False, "127.0.0.1", self.port, None, None) + c = tcp.TCPClient(False, "127.0.0.1", self.port, None) + c.connect() c.wfile.write(testval) c.wfile.flush() assert c.rfile.readline() == testval @@ -79,7 +82,8 @@ class TestServerSSL(ServerTestBase): return s def test_echo(self): - c = tcp.TCPClient(True, "127.0.0.1", self.port, None, None) + c = tcp.TCPClient(True, "127.0.0.1", self.port, None) + c.connect() testval = "echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -88,7 +92,8 @@ class TestServerSSL(ServerTestBase): class TestTCPClient: def test_conerr(self): - tutils.raises(tcp.NetLibError, tcp.TCPClient, False, "127.0.0.1", 0, None, None) + c = tcp.TCPClient(True, "127.0.0.1", 0, None) + tutils.raises(tcp.NetLibError, c.connect) class TestFileLike: