Extract TCP test utilities into netlib.test

This commit is contained in:
Aldo Cortesi 2013-01-25 15:54:41 +13:00
parent 7248a22d5e
commit 2eb6651e51
4 changed files with 146 additions and 92 deletions

View File

@ -1,4 +1,4 @@
import select, socket, threading, traceback, sys, time
import select, socket, threading, sys, time, traceback
from OpenSSL import SSL
import certutils
@ -84,13 +84,14 @@ class _FileLike:
def reset_timestamps(self):
self.first_byte_timestamp = None
class Writer(_FileLike):
def flush(self):
try:
if hasattr(self.o, "flush"):
if hasattr(self.o, "flush"):
try:
self.o.flush()
except socket.error, v:
raise NetLibDisconnect(str(v))
except socket.error, v:
raise NetLibDisconnect(str(v))
def write(self, v):
if v:

67
netlib/test.py Normal file
View File

@ -0,0 +1,67 @@
import threading, Queue, cStringIO
import tcp
class ServerThread(threading.Thread):
def __init__(self, server):
self.server = server
threading.Thread.__init__(self)
def run(self):
self.server.serve_forever()
def shutdown(self):
self.server.shutdown()
class ServerTestBase:
@classmethod
def setupAll(cls):
cls.q = Queue.Queue()
s = cls.makeserver()
cls.port = s.port
cls.server = ServerThread(s)
cls.server.start()
@classmethod
def teardownAll(cls):
cls.server.shutdown()
@property
def last_handler(self):
return self.server.server.last_handler
class TServer(tcp.TCPServer):
def __init__(self, ssl, q, handler_klass, addr=("127.0.0.1", 0)):
"""
ssl: A {cert, key, v3_only} dict.
"""
tcp.TCPServer.__init__(self, addr)
self.ssl, self.q = ssl, q
self.handler_klass = handler_klass
self.last_handler = None
def handle_connection(self, request, client_address):
h = self.handler_klass(request, client_address, self)
self.last_handler = h
if self.ssl:
if self.ssl["v3_only"]:
method = tcp.SSLv3_METHOD
options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1
else:
method = tcp.SSLv23_METHOD
options = None
h.convert_to_ssl(
self.ssl["cert"],
self.ssl["key"],
method = method,
options = options,
)
h.handle()
h.finish()
def handle_error(self, request, client_address):
s = cStringIO.StringIO()
tcp.TCPServer.handle_error(self, request, client_address, s)
self.q.put(s.getvalue())

View File

@ -30,6 +30,7 @@ class TestCertStore:
ca = os.path.join(d, "ca")
assert certutils.dummy_ca(ca)
c = certutils.CertStore()
assert not c.get_cert("../foo.com", [])
assert not c.get_cert("foo.com", [])
assert c.get_cert("foo.com", [], ca)
assert c.get_cert("foo.com", [], ca)

View File

@ -1,38 +1,7 @@
import cStringIO, threading, Queue, time
from netlib import tcp, certutils
from netlib import tcp, certutils, test
import tutils
class ServerThread(threading.Thread):
def __init__(self, server):
self.server = server
threading.Thread.__init__(self)
def run(self):
self.server.serve_forever()
def shutdown(self):
self.server.shutdown()
class ServerTestBase:
@classmethod
def setupAll(cls):
cls.q = Queue.Queue()
s = cls.makeserver()
cls.port = s.port
cls.server = ServerThread(s)
cls.server.start()
@classmethod
def teardownAll(cls):
cls.server.shutdown()
@property
def last_handler(self):
return self.server.server.last_handler
class SNIHandler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
@ -88,43 +57,10 @@ class TimeoutHandler(tcp.BaseHandler):
self.timeout = True
class TServer(tcp.TCPServer):
def __init__(self, addr, ssl, q, handler_klass, v3_only=False):
tcp.TCPServer.__init__(self, addr)
self.ssl, self.q = ssl, q
self.v3_only = v3_only
self.handler_klass = handler_klass
self.last_handler = None
def handle_connection(self, request, client_address):
h = self.handler_klass(request, client_address, self)
self.last_handler = h
if self.ssl:
if self.v3_only:
method = tcp.SSLv3_METHOD
options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1
else:
method = tcp.SSLv23_METHOD
options = None
h.convert_to_ssl(
tutils.test_data.path("data/server.crt"),
tutils.test_data.path("data/server.key"),
method = method,
options = options,
)
h.handle()
h.finish()
def handle_error(self, request, client_address):
s = cStringIO.StringIO()
tcp.TCPServer.handle_error(self, request, client_address, s)
self.q.put(s.getvalue())
class TestServer(ServerTestBase):
class TestServer(test.ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
return test.TServer(False, cls.q, EchoHandler)
def test_echo(self):
testval = "echo!\n"
@ -135,10 +71,10 @@ class TestServer(ServerTestBase):
assert c.rfile.readline() == testval
class TestDisconnect(ServerTestBase):
class TestDisconnect(test.ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
return test.TServer(False, cls.q, EchoHandler)
def test_echo(self):
testval = "echo!\n"
@ -149,10 +85,18 @@ class TestDisconnect(ServerTestBase):
assert c.rfile.readline() == testval
class TestServerSSL(ServerTestBase):
class TestServerSSL(test.ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler)
return test.TServer(
dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
v3_only = False
),
cls.q,
EchoHandler
)
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@ -167,10 +111,19 @@ class TestServerSSL(ServerTestBase):
assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1")
class TestSSLv3Only(ServerTestBase):
class TestSSLv3Only(test.ServerTestBase):
v3_only = True
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler, True)
return test.TServer(
dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
v3_only = True
),
cls.q,
EchoHandler,
)
def test_failure(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@ -178,10 +131,18 @@ class TestSSLv3Only(ServerTestBase):
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD)
class TestSSLClientCert(ServerTestBase):
class TestSSLClientCert(test.ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), True, cls.q, CertHandler)
return test.TServer(
dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
v3_only = False
),
cls.q,
CertHandler
)
def test_clientcert(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@ -199,10 +160,18 @@ class TestSSLClientCert(ServerTestBase):
)
class TestSNI(ServerTestBase):
class TestSNI(test.ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), True, cls.q, SNIHandler)
return test.TServer(
dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
v3_only = False
),
cls.q,
SNIHandler
)
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@ -211,10 +180,18 @@ class TestSNI(ServerTestBase):
assert c.rfile.readline() == "foo.com"
class TestSSLDisconnect(ServerTestBase):
class TestSSLDisconnect(test.ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler)
return test.TServer(
dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
v3_only = False
),
cls.q,
DisconnectHandler
)
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@ -227,10 +204,10 @@ class TestSSLDisconnect(ServerTestBase):
tutils.raises(Queue.Empty, self.q.get_nowait)
class TestDisconnect(ServerTestBase):
class TestSSLDisconnect(test.ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler)
return test.TServer(False, cls.q, DisconnectHandler)
def test_echo(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@ -242,10 +219,10 @@ class TestDisconnect(ServerTestBase):
c.close()
class TestServerTimeOut(ServerTestBase):
class TestServerTimeOut(test.ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), False, cls.q, TimeoutHandler)
return test.TServer(False, cls.q, TimeoutHandler)
def test_timeout(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@ -254,10 +231,10 @@ class TestServerTimeOut(ServerTestBase):
assert self.last_handler.timeout
class TestTimeOut(ServerTestBase):
class TestTimeOut(test.ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), False, cls.q, HangHandler)
return test.TServer(False, cls.q, HangHandler)
def test_timeout(self):
c = tcp.TCPClient("127.0.0.1", self.port)
@ -266,10 +243,18 @@ class TestTimeOut(ServerTestBase):
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
class TestSSLTimeOut(ServerTestBase):
class TestSSLTimeOut(test.ServerTestBase):
@classmethod
def makeserver(cls):
return TServer(("127.0.0.1", 0), True, cls.q, HangHandler)
return test.TServer(
dict(
cert = tutils.test_data.path("data/server.crt"),
key = tutils.test_data.path("data/server.key"),
v3_only = False
),
cls.q,
HangHandler
)
def test_timeout_client(self):
c = tcp.TCPClient("127.0.0.1", self.port)