diff --git a/netlib/tcp.py b/netlib/tcp.py index d0ca09f35..56cc0dead 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,4 +1,4 @@ -import select, socket, threading, traceback, sys, time +import select, socket, threading, sys, time, traceback from OpenSSL import SSL import certutils @@ -84,13 +84,14 @@ class _FileLike: def reset_timestamps(self): self.first_byte_timestamp = None + class Writer(_FileLike): def flush(self): - try: - if hasattr(self.o, "flush"): + if hasattr(self.o, "flush"): + try: self.o.flush() - except socket.error, v: - raise NetLibDisconnect(str(v)) + except socket.error, v: + raise NetLibDisconnect(str(v)) def write(self, v): if v: diff --git a/netlib/test.py b/netlib/test.py new file mode 100644 index 000000000..2f72f9799 --- /dev/null +++ b/netlib/test.py @@ -0,0 +1,67 @@ +import threading, Queue, cStringIO +import tcp + +class ServerThread(threading.Thread): + def __init__(self, server): + self.server = server + threading.Thread.__init__(self) + + def run(self): + self.server.serve_forever() + + def shutdown(self): + self.server.shutdown() + + +class ServerTestBase: + @classmethod + def setupAll(cls): + cls.q = Queue.Queue() + s = cls.makeserver() + cls.port = s.port + cls.server = ServerThread(s) + cls.server.start() + + @classmethod + def teardownAll(cls): + cls.server.shutdown() + + + @property + def last_handler(self): + return self.server.server.last_handler + + +class TServer(tcp.TCPServer): + def __init__(self, ssl, q, handler_klass, addr=("127.0.0.1", 0)): + """ + ssl: A {cert, key, v3_only} dict. + """ + tcp.TCPServer.__init__(self, addr) + self.ssl, self.q = ssl, q + self.handler_klass = handler_klass + self.last_handler = None + + def handle_connection(self, request, client_address): + h = self.handler_klass(request, client_address, self) + self.last_handler = h + if self.ssl: + if self.ssl["v3_only"]: + method = tcp.SSLv3_METHOD + options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 + else: + method = tcp.SSLv23_METHOD + options = None + h.convert_to_ssl( + self.ssl["cert"], + self.ssl["key"], + method = method, + options = options, + ) + h.handle() + h.finish() + + def handle_error(self, request, client_address): + s = cStringIO.StringIO() + tcp.TCPServer.handle_error(self, request, client_address, s) + self.q.put(s.getvalue()) diff --git a/test/test_certutils.py b/test/test_certutils.py index 582fb9c45..334a6be46 100644 --- a/test/test_certutils.py +++ b/test/test_certutils.py @@ -30,6 +30,7 @@ class TestCertStore: ca = os.path.join(d, "ca") assert certutils.dummy_ca(ca) c = certutils.CertStore() + assert not c.get_cert("../foo.com", []) assert not c.get_cert("foo.com", []) assert c.get_cert("foo.com", [], ca) assert c.get_cert("foo.com", [], ca) diff --git a/test/test_tcp.py b/test/test_tcp.py index 0417aa21a..ce06ad663 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,38 +1,7 @@ import cStringIO, threading, Queue, time -from netlib import tcp, certutils +from netlib import tcp, certutils, test import tutils -class ServerThread(threading.Thread): - def __init__(self, server): - self.server = server - threading.Thread.__init__(self) - - def run(self): - self.server.serve_forever() - - def shutdown(self): - self.server.shutdown() - - -class ServerTestBase: - @classmethod - def setupAll(cls): - cls.q = Queue.Queue() - s = cls.makeserver() - cls.port = s.port - cls.server = ServerThread(s) - cls.server.start() - - @classmethod - def teardownAll(cls): - cls.server.shutdown() - - - @property - def last_handler(self): - return self.server.server.last_handler - - class SNIHandler(tcp.BaseHandler): sni = None def handle_sni(self, connection): @@ -88,43 +57,10 @@ class TimeoutHandler(tcp.BaseHandler): self.timeout = True -class TServer(tcp.TCPServer): - def __init__(self, addr, ssl, q, handler_klass, v3_only=False): - tcp.TCPServer.__init__(self, addr) - self.ssl, self.q = ssl, q - self.v3_only = v3_only - self.handler_klass = handler_klass - self.last_handler = None - - def handle_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) - self.last_handler = h - if self.ssl: - if self.v3_only: - method = tcp.SSLv3_METHOD - options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1 - else: - method = tcp.SSLv23_METHOD - options = None - h.convert_to_ssl( - tutils.test_data.path("data/server.crt"), - tutils.test_data.path("data/server.key"), - method = method, - options = options, - ) - h.handle() - h.finish() - - def handle_error(self, request, client_address): - s = cStringIO.StringIO() - tcp.TCPServer.handle_error(self, request, client_address, s) - self.q.put(s.getvalue()) - - -class TestServer(ServerTestBase): +class TestServer(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) + return test.TServer(False, cls.q, EchoHandler) def test_echo(self): testval = "echo!\n" @@ -135,10 +71,10 @@ class TestServer(ServerTestBase): assert c.rfile.readline() == testval -class TestDisconnect(ServerTestBase): +class TestDisconnect(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler) + return test.TServer(False, cls.q, EchoHandler) def test_echo(self): testval = "echo!\n" @@ -149,10 +85,18 @@ class TestDisconnect(ServerTestBase): assert c.rfile.readline() == testval -class TestServerSSL(ServerTestBase): +class TestServerSSL(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + EchoHandler + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -167,10 +111,19 @@ class TestServerSSL(ServerTestBase): assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1") -class TestSSLv3Only(ServerTestBase): +class TestSSLv3Only(test.ServerTestBase): + v3_only = True @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler, True) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = True + ), + cls.q, + EchoHandler, + ) def test_failure(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -178,10 +131,18 @@ class TestSSLv3Only(ServerTestBase): tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD) -class TestSSLClientCert(ServerTestBase): +class TestSSLClientCert(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, CertHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + CertHandler + ) def test_clientcert(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -199,10 +160,18 @@ class TestSSLClientCert(ServerTestBase): ) -class TestSNI(ServerTestBase): +class TestSNI(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, SNIHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + SNIHandler + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -211,10 +180,18 @@ class TestSNI(ServerTestBase): assert c.rfile.readline() == "foo.com" -class TestSSLDisconnect(ServerTestBase): +class TestSSLDisconnect(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + DisconnectHandler + ) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -227,10 +204,10 @@ class TestSSLDisconnect(ServerTestBase): tutils.raises(Queue.Empty, self.q.get_nowait) -class TestDisconnect(ServerTestBase): +class TestSSLDisconnect(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler) + return test.TServer(False, cls.q, DisconnectHandler) def test_echo(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -242,10 +219,10 @@ class TestDisconnect(ServerTestBase): c.close() -class TestServerTimeOut(ServerTestBase): +class TestServerTimeOut(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, TimeoutHandler) + return test.TServer(False, cls.q, TimeoutHandler) def test_timeout(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -254,10 +231,10 @@ class TestServerTimeOut(ServerTestBase): assert self.last_handler.timeout -class TestTimeOut(ServerTestBase): +class TestTimeOut(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), False, cls.q, HangHandler) + return test.TServer(False, cls.q, HangHandler) def test_timeout(self): c = tcp.TCPClient("127.0.0.1", self.port) @@ -266,10 +243,18 @@ class TestTimeOut(ServerTestBase): tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) -class TestSSLTimeOut(ServerTestBase): +class TestSSLTimeOut(test.ServerTestBase): @classmethod def makeserver(cls): - return TServer(("127.0.0.1", 0), True, cls.q, HangHandler) + return test.TServer( + dict( + cert = tutils.test_data.path("data/server.crt"), + key = tutils.test_data.path("data/server.key"), + v3_only = False + ), + cls.q, + HangHandler + ) def test_timeout_client(self): c = tcp.TCPClient("127.0.0.1", self.port)