Extract TCP test utilities into netlib.test
This commit is contained in:
parent
7248a22d5e
commit
2eb6651e51
|
@ -1,4 +1,4 @@
|
|||
import select, socket, threading, traceback, sys, time
|
||||
import select, socket, threading, sys, time, traceback
|
||||
from OpenSSL import SSL
|
||||
import certutils
|
||||
|
||||
|
@ -84,13 +84,14 @@ class _FileLike:
|
|||
def reset_timestamps(self):
|
||||
self.first_byte_timestamp = None
|
||||
|
||||
|
||||
class Writer(_FileLike):
|
||||
def flush(self):
|
||||
try:
|
||||
if hasattr(self.o, "flush"):
|
||||
if hasattr(self.o, "flush"):
|
||||
try:
|
||||
self.o.flush()
|
||||
except socket.error, v:
|
||||
raise NetLibDisconnect(str(v))
|
||||
except socket.error, v:
|
||||
raise NetLibDisconnect(str(v))
|
||||
|
||||
def write(self, v):
|
||||
if v:
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
import threading, Queue, cStringIO
|
||||
import tcp
|
||||
|
||||
class ServerThread(threading.Thread):
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
def run(self):
|
||||
self.server.serve_forever()
|
||||
|
||||
def shutdown(self):
|
||||
self.server.shutdown()
|
||||
|
||||
|
||||
class ServerTestBase:
|
||||
@classmethod
|
||||
def setupAll(cls):
|
||||
cls.q = Queue.Queue()
|
||||
s = cls.makeserver()
|
||||
cls.port = s.port
|
||||
cls.server = ServerThread(s)
|
||||
cls.server.start()
|
||||
|
||||
@classmethod
|
||||
def teardownAll(cls):
|
||||
cls.server.shutdown()
|
||||
|
||||
|
||||
@property
|
||||
def last_handler(self):
|
||||
return self.server.server.last_handler
|
||||
|
||||
|
||||
class TServer(tcp.TCPServer):
|
||||
def __init__(self, ssl, q, handler_klass, addr=("127.0.0.1", 0)):
|
||||
"""
|
||||
ssl: A {cert, key, v3_only} dict.
|
||||
"""
|
||||
tcp.TCPServer.__init__(self, addr)
|
||||
self.ssl, self.q = ssl, q
|
||||
self.handler_klass = handler_klass
|
||||
self.last_handler = None
|
||||
|
||||
def handle_connection(self, request, client_address):
|
||||
h = self.handler_klass(request, client_address, self)
|
||||
self.last_handler = h
|
||||
if self.ssl:
|
||||
if self.ssl["v3_only"]:
|
||||
method = tcp.SSLv3_METHOD
|
||||
options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1
|
||||
else:
|
||||
method = tcp.SSLv23_METHOD
|
||||
options = None
|
||||
h.convert_to_ssl(
|
||||
self.ssl["cert"],
|
||||
self.ssl["key"],
|
||||
method = method,
|
||||
options = options,
|
||||
)
|
||||
h.handle()
|
||||
h.finish()
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
s = cStringIO.StringIO()
|
||||
tcp.TCPServer.handle_error(self, request, client_address, s)
|
||||
self.q.put(s.getvalue())
|
|
@ -30,6 +30,7 @@ class TestCertStore:
|
|||
ca = os.path.join(d, "ca")
|
||||
assert certutils.dummy_ca(ca)
|
||||
c = certutils.CertStore()
|
||||
assert not c.get_cert("../foo.com", [])
|
||||
assert not c.get_cert("foo.com", [])
|
||||
assert c.get_cert("foo.com", [], ca)
|
||||
assert c.get_cert("foo.com", [], ca)
|
||||
|
|
159
test/test_tcp.py
159
test/test_tcp.py
|
@ -1,38 +1,7 @@
|
|||
import cStringIO, threading, Queue, time
|
||||
from netlib import tcp, certutils
|
||||
from netlib import tcp, certutils, test
|
||||
import tutils
|
||||
|
||||
class ServerThread(threading.Thread):
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
def run(self):
|
||||
self.server.serve_forever()
|
||||
|
||||
def shutdown(self):
|
||||
self.server.shutdown()
|
||||
|
||||
|
||||
class ServerTestBase:
|
||||
@classmethod
|
||||
def setupAll(cls):
|
||||
cls.q = Queue.Queue()
|
||||
s = cls.makeserver()
|
||||
cls.port = s.port
|
||||
cls.server = ServerThread(s)
|
||||
cls.server.start()
|
||||
|
||||
@classmethod
|
||||
def teardownAll(cls):
|
||||
cls.server.shutdown()
|
||||
|
||||
|
||||
@property
|
||||
def last_handler(self):
|
||||
return self.server.server.last_handler
|
||||
|
||||
|
||||
class SNIHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
def handle_sni(self, connection):
|
||||
|
@ -88,43 +57,10 @@ class TimeoutHandler(tcp.BaseHandler):
|
|||
self.timeout = True
|
||||
|
||||
|
||||
class TServer(tcp.TCPServer):
|
||||
def __init__(self, addr, ssl, q, handler_klass, v3_only=False):
|
||||
tcp.TCPServer.__init__(self, addr)
|
||||
self.ssl, self.q = ssl, q
|
||||
self.v3_only = v3_only
|
||||
self.handler_klass = handler_klass
|
||||
self.last_handler = None
|
||||
|
||||
def handle_connection(self, request, client_address):
|
||||
h = self.handler_klass(request, client_address, self)
|
||||
self.last_handler = h
|
||||
if self.ssl:
|
||||
if self.v3_only:
|
||||
method = tcp.SSLv3_METHOD
|
||||
options = tcp.OP_NO_SSLv2|tcp.OP_NO_TLSv1
|
||||
else:
|
||||
method = tcp.SSLv23_METHOD
|
||||
options = None
|
||||
h.convert_to_ssl(
|
||||
tutils.test_data.path("data/server.crt"),
|
||||
tutils.test_data.path("data/server.key"),
|
||||
method = method,
|
||||
options = options,
|
||||
)
|
||||
h.handle()
|
||||
h.finish()
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
s = cStringIO.StringIO()
|
||||
tcp.TCPServer.handle_error(self, request, client_address, s)
|
||||
self.q.put(s.getvalue())
|
||||
|
||||
|
||||
class TestServer(ServerTestBase):
|
||||
class TestServer(test.ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
|
||||
return test.TServer(False, cls.q, EchoHandler)
|
||||
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
|
@ -135,10 +71,10 @@ class TestServer(ServerTestBase):
|
|||
assert c.rfile.readline() == testval
|
||||
|
||||
|
||||
class TestDisconnect(ServerTestBase):
|
||||
class TestDisconnect(test.ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), False, cls.q, EchoHandler)
|
||||
return test.TServer(False, cls.q, EchoHandler)
|
||||
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
|
@ -149,10 +85,18 @@ class TestDisconnect(ServerTestBase):
|
|||
assert c.rfile.readline() == testval
|
||||
|
||||
|
||||
class TestServerSSL(ServerTestBase):
|
||||
class TestServerSSL(test.ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler)
|
||||
return test.TServer(
|
||||
dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
v3_only = False
|
||||
),
|
||||
cls.q,
|
||||
EchoHandler
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
|
@ -167,10 +111,19 @@ class TestServerSSL(ServerTestBase):
|
|||
assert certutils.get_remote_cert("127.0.0.1", self.port, None).digest("sha1")
|
||||
|
||||
|
||||
class TestSSLv3Only(ServerTestBase):
|
||||
class TestSSLv3Only(test.ServerTestBase):
|
||||
v3_only = True
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), True, cls.q, EchoHandler, True)
|
||||
return test.TServer(
|
||||
dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
v3_only = True
|
||||
),
|
||||
cls.q,
|
||||
EchoHandler,
|
||||
)
|
||||
|
||||
def test_failure(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
|
@ -178,10 +131,18 @@ class TestSSLv3Only(ServerTestBase):
|
|||
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD)
|
||||
|
||||
|
||||
class TestSSLClientCert(ServerTestBase):
|
||||
class TestSSLClientCert(test.ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), True, cls.q, CertHandler)
|
||||
return test.TServer(
|
||||
dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
v3_only = False
|
||||
),
|
||||
cls.q,
|
||||
CertHandler
|
||||
)
|
||||
|
||||
def test_clientcert(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
|
@ -199,10 +160,18 @@ class TestSSLClientCert(ServerTestBase):
|
|||
)
|
||||
|
||||
|
||||
class TestSNI(ServerTestBase):
|
||||
class TestSNI(test.ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), True, cls.q, SNIHandler)
|
||||
return test.TServer(
|
||||
dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
v3_only = False
|
||||
),
|
||||
cls.q,
|
||||
SNIHandler
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
|
@ -211,10 +180,18 @@ class TestSNI(ServerTestBase):
|
|||
assert c.rfile.readline() == "foo.com"
|
||||
|
||||
|
||||
class TestSSLDisconnect(ServerTestBase):
|
||||
class TestSSLDisconnect(test.ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), True, cls.q, DisconnectHandler)
|
||||
return test.TServer(
|
||||
dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
v3_only = False
|
||||
),
|
||||
cls.q,
|
||||
DisconnectHandler
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
|
@ -227,10 +204,10 @@ class TestSSLDisconnect(ServerTestBase):
|
|||
tutils.raises(Queue.Empty, self.q.get_nowait)
|
||||
|
||||
|
||||
class TestDisconnect(ServerTestBase):
|
||||
class TestSSLDisconnect(test.ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), False, cls.q, DisconnectHandler)
|
||||
return test.TServer(False, cls.q, DisconnectHandler)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
|
@ -242,10 +219,10 @@ class TestDisconnect(ServerTestBase):
|
|||
c.close()
|
||||
|
||||
|
||||
class TestServerTimeOut(ServerTestBase):
|
||||
class TestServerTimeOut(test.ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), False, cls.q, TimeoutHandler)
|
||||
return test.TServer(False, cls.q, TimeoutHandler)
|
||||
|
||||
def test_timeout(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
|
@ -254,10 +231,10 @@ class TestServerTimeOut(ServerTestBase):
|
|||
assert self.last_handler.timeout
|
||||
|
||||
|
||||
class TestTimeOut(ServerTestBase):
|
||||
class TestTimeOut(test.ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), False, cls.q, HangHandler)
|
||||
return test.TServer(False, cls.q, HangHandler)
|
||||
|
||||
def test_timeout(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
|
@ -266,10 +243,18 @@ class TestTimeOut(ServerTestBase):
|
|||
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
|
||||
|
||||
|
||||
class TestSSLTimeOut(ServerTestBase):
|
||||
class TestSSLTimeOut(test.ServerTestBase):
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(("127.0.0.1", 0), True, cls.q, HangHandler)
|
||||
return test.TServer(
|
||||
dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
v3_only = False
|
||||
),
|
||||
cls.q,
|
||||
HangHandler
|
||||
)
|
||||
|
||||
def test_timeout_client(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
|
|
Loading…
Reference in New Issue