Add testing.AsyncSSLTestCase

Allow subclasses of AsyncHTTPTestCase to provide their own http client
and server implementations.
This commit is contained in:
Alek Storm 2012-06-07 23:23:04 -04:00
parent 0b432be307
commit df0998650a
3 changed files with 78 additions and 71 deletions

View File

@ -60,10 +60,6 @@ class EchoPostHandler(RequestHandler):
class HTTPClientCommonTestCase(AsyncHTTPTestCase, LogTrapTestCase):
def get_http_client(self):
"""Returns AsyncHTTPClient instance. May be overridden in subclass."""
return AsyncHTTPClient(io_loop=self.io_loop)
def get_app(self):
return Application([
url("/hello", HelloWorldHandler),
@ -74,11 +70,6 @@ class HTTPClientCommonTestCase(AsyncHTTPTestCase, LogTrapTestCase):
url("/echopost", EchoPostHandler),
], gzip=True)
def setUp(self):
super(HTTPClientCommonTestCase, self).setUp()
# replace the client defined in the parent class
self.http_client = self.get_http_client()
def test_hello_world(self):
response = self.fetch("/hello")
self.assertEqual(response.code, 200)

View File

@ -8,7 +8,7 @@ from tornado.httpserver import HTTPServer
from tornado.httputil import HTTPHeaders
from tornado.iostream import IOStream
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, AsyncTestCase
from tornado.testing import AsyncHTTPTestCase, AsyncSSLTestCase, AsyncTestCase, LogTrapTestCase
from tornado.util import b, bytes_type
from tornado.web import Application, RequestHandler
import os
@ -45,38 +45,11 @@ class HelloWorldRequestHandler(RequestHandler):
self.finish("Got %d bytes in POST" % len(self.request.body))
class BaseSSLTest(AsyncHTTPTestCase, LogTrapTestCase):
def get_ssl_version(self):
raise NotImplementedError()
def setUp(self):
super(BaseSSLTest, self).setUp()
# Replace the client defined in the parent class.
# Some versions of libcurl have deadlock bugs with ssl,
# so always run these tests with SimpleAsyncHTTPClient.
self.http_client = SimpleAsyncHTTPClient(io_loop=self.io_loop,
force_instance=True)
class BaseSSLTest(AsyncSSLTestCase, LogTrapTestCase):
def get_app(self):
return Application([('/', HelloWorldRequestHandler,
dict(protocol="https"))])
def get_httpserver_options(self):
# Testing keys were generated with:
# openssl req -new -keyout tornado/test/test.key -out tornado/test/test.crt -nodes -days 3650 -x509
test_dir = os.path.dirname(__file__)
return dict(ssl_options=dict(
certfile=os.path.join(test_dir, 'test.crt'),
keyfile=os.path.join(test_dir, 'test.key'),
ssl_version=self.get_ssl_version()))
def fetch(self, path, **kwargs):
self.http_client.fetch(self.get_url(path).replace('http', 'https'),
self.stop,
validate_cert=False,
**kwargs)
return self.wait()
class SSLTestMixin(object):
def test_ssl(self):
@ -119,38 +92,36 @@ class TLSv1Test(BaseSSLTest, SSLTestMixin):
def get_ssl_version(self):
return ssl.PROTOCOL_TLSv1
if hasattr(ssl, 'PROTOCOL_SSLv2'):
class SSLv2Test(BaseSSLTest):
def get_ssl_version(self):
return ssl.PROTOCOL_SSLv2
def test_sslv2_fail(self):
# This is really more of a client test, but run it here since
# we've got all the other ssl version tests here.
# Clients should have SSLv2 disabled by default.
try:
# The server simply closes the connection when it gets
# an SSLv2 ClientHello packet.
# request_timeout is needed here because on some platforms
# (cygwin, but not native windows python), the close is not
# detected promptly.
response = self.fetch('/', request_timeout=1)
except ssl.SSLError:
# In some python/ssl builds the PROTOCOL_SSLv2 constant
# exists but SSLv2 support is still compiled out, which
# would result in an SSLError here (details vary depending
# on python version). The important thing is that
# SSLv2 request's don't succeed, so we can just ignore
# the errors here.
return
self.assertEqual(response.code, 599)
class SSLv2Test(BaseSSLTest):
def get_ssl_version(self):
return ssl.PROTOCOL_SSLv2
def test_sslv2_fail(self):
# This is really more of a client test, but run it here since
# we've got all the other ssl version tests here.
# Clients should have SSLv2 disabled by default.
try:
# The server simply closes the connection when it gets
# an SSLv2 ClientHello packet.
# request_timeout is needed here because on some platforms
# (cygwin, but not native windows python), the close is not
# detected promptly.
response = self.fetch('/', request_timeout=1)
except ssl.SSLError:
# In some python/ssl builds the PROTOCOL_SSLv2 constant
# exists but SSLv2 support is still compiled out, which
# would result in an SSLError here (details vary depending
# on python version). The important thing is that
# SSLv2 request's don't succeed, so we can just ignore
# the errors here.
return
self.assertEqual(response.code, 599)
if ssl is None:
del BaseSSLTest
del SSLv23Test
del SSLv3Test
del TLSv1Test
elif getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0):
if getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0):
# In pre-1.0 versions of openssl, SSLv23 clients always send SSLv2
# ClientHello messages, which are rejected by SSLv3 and TLSv1
# servers. Note that while the OPENSSL_VERSION_INFO was formally
@ -158,6 +129,8 @@ elif getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0):
# python 2.7
del SSLv3Test
del TLSv1Test
if not hasattr(ssl, 'PROTOCOL_SSLv2'):
del SSLv2Test
class MultipartTestHandler(RequestHandler):

View File

@ -24,6 +24,7 @@ from cStringIO import StringIO
try:
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.ioloop import IOLoop
except ImportError:
# These modules are not importable on app engine. Parts of this module
@ -31,10 +32,12 @@ except ImportError:
AsyncHTTPClient = None
HTTPServer = None
IOLoop = None
SimpleAsyncHTTPClient = None
from tornado.stack_context import StackContext, NullContext
from tornado.util import raise_exc_info
import contextlib
import logging
import os
import signal
import sys
import time
@ -232,12 +235,19 @@ class AsyncHTTPTestCase(AsyncTestCase):
super(AsyncHTTPTestCase, self).setUp()
self.__port = None
self.http_client = AsyncHTTPClient(io_loop=self.io_loop)
self.http_client = self.get_http_client()
self._app = self.get_app()
self.http_server = HTTPServer(self._app, io_loop=self.io_loop,
**self.get_httpserver_options())
self.http_server = self.get_http_server()
self.http_server.listen(self.get_http_port(), address="127.0.0.1")
def get_http_client(self):
return AsyncHTTPClient(io_loop=self.io_loop)
def get_http_server(self):
return HTTPServer(self._app, io_loop=self.io_loop,
**self.get_httpserver_options())
def get_app(self):
"""Should be overridden by subclasses to return a
tornado.web.Application or other HTTPServer callback.
@ -257,12 +267,12 @@ class AsyncHTTPTestCase(AsyncTestCase):
def get_httpserver_options(self):
"""May be overridden by subclasses to return additional
keyword arguments for HTTPServer.
keyword arguments for the server.
"""
return {}
def get_http_port(self):
"""Returns the port used by the HTTPServer.
"""Returns the port used by the server.
A new port is chosen for each test.
"""
@ -270,9 +280,13 @@ class AsyncHTTPTestCase(AsyncTestCase):
self.__port = get_unused_port()
return self.__port
def get_protocol(self):
return 'http'
def get_url(self, path):
"""Returns an absolute url for the given path on the test server."""
return 'http://localhost:%s%s' % (self.get_http_port(), path)
return '%s://localhost:%s%s' % (self.get_protocol(),
self.get_http_port(), path)
def tearDown(self):
self.http_server.stop()
@ -280,6 +294,35 @@ class AsyncHTTPTestCase(AsyncTestCase):
super(AsyncHTTPTestCase, self).tearDown()
class AsyncSSLTestCase(AsyncHTTPTestCase):
def get_ssl_version(self):
raise NotImplementedError()
def get_http_client(self):
# Some versions of libcurl have deadlock bugs with ssl,
# so always run these tests with SimpleAsyncHTTPClient.
return SimpleAsyncHTTPClient(io_loop=self.io_loop, force_instance=True)
def get_httpserver_options(self):
return dict(ssl_options=self.get_ssl_options())
def get_ssl_options(self):
# Testing keys were generated with:
# openssl req -new -keyout tornado/test/test.key -out tornado/test/test.crt -nodes -days 3650 -x509
module_dir = os.path.dirname(__file__)
return dict(
certfile=os.path.join(module_dir, 'test', 'test.crt'),
keyfile=os.path.join(module_dir, 'test', 'test.key'),
ssl_version=self.get_ssl_version())
def get_protocol(self):
return 'https'
def fetch(self, path, **kwargs):
return AsyncHTTPTestCase.fetch(self, path, validate_cert=False,
**kwargs)
class LogTrapTestCase(unittest.TestCase):
"""A test case that captures and discards all logging output
if the test passes.