testing: Type-annotate the module

This commit is contained in:
Ben Darnell 2018-08-11 16:40:07 -04:00
parent ceca6f3120
commit 90a7cedac4
2 changed files with 89 additions and 47 deletions

View File

@ -19,6 +19,9 @@ disallow_untyped_defs = True
[mypy-tornado.gen]
disallow_untyped_defs = True
[mypy-tornado.testing]
disallow_untyped_defs = True
# It's generally too tedious to require type annotations in tests, but
# we do want to type check them as much as type inference allows.
[mypy-tornado.test.util_test]
@ -35,3 +38,6 @@ check_untyped_defs = True
[mypy-tornado.test.gen_test]
check_untyped_defs = True
[mypy-tornado.test.testing_test]
check_untyped_defs = True

View File

@ -21,7 +21,7 @@ import sys
import unittest
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httpclient import AsyncHTTPClient, HTTPResponse
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop, TimeoutError
from tornado import netutil
@ -29,12 +29,21 @@ from tornado.platform.asyncio import AsyncIOMainLoop
from tornado.process import Subprocess
from tornado.log import app_log
from tornado.util import raise_exc_info, basestring_type
from tornado.web import Application
import typing
from typing import Tuple, Any, Callable, Type, Dict, Union, Coroutine, Optional
from types import TracebackType
if typing.TYPE_CHECKING:
_ExcInfoTuple = Tuple[Optional[Type[BaseException]], Optional[BaseException],
Optional[TracebackType]]
_NON_OWNED_IOLOOPS = AsyncIOMainLoop
def bind_unused_port(reuse_port=False):
def bind_unused_port(reuse_port: bool=False) -> Tuple[socket.socket, int]:
"""Binds a server socket to an available port on localhost.
Returns a tuple (socket, port).
@ -49,17 +58,20 @@ def bind_unused_port(reuse_port=False):
return sock, port
def get_async_test_timeout():
def get_async_test_timeout() -> float:
"""Get the global timeout setting for async tests.
Returns a float, the timeout in seconds.
.. versionadded:: 3.1
"""
try:
return float(os.environ.get('ASYNC_TEST_TIMEOUT'))
except (ValueError, TypeError):
return 5
env = os.environ.get('ASYNC_TEST_TIMEOUT')
if env is not None:
try:
return float(env)
except ValueError:
pass
return 5
class _TestMethodWrapper(object):
@ -71,10 +83,10 @@ class _TestMethodWrapper(object):
necessarily errors, but we alert anyway since there is no good
reason to return a value from a test).
"""
def __init__(self, orig_method):
def __init__(self, orig_method: Callable) -> None:
self.orig_method = orig_method
def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> None:
result = self.orig_method(*args, **kwargs)
if isinstance(result, Generator) or inspect.iscoroutine(result):
raise TypeError("Generator and coroutine test methods should be"
@ -83,7 +95,7 @@ class _TestMethodWrapper(object):
raise ValueError("Return value from test method ignored: %r" %
result)
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
"""Proxy all unknown attributes to the original method.
This is important for some of the decorators in the `unittest`
@ -138,12 +150,12 @@ class AsyncTestCase(unittest.TestCase):
# Test contents of response
self.assertIn("FriendFeed", response.body)
"""
def __init__(self, methodName='runTest'):
def __init__(self, methodName: str='runTest') -> None:
super(AsyncTestCase, self).__init__(methodName)
self.__stopped = False
self.__running = False
self.__failure = None
self.__stop_args = None
self.__failure = None # type: Optional[_ExcInfoTuple]
self.__stop_args = None # type: Any
self.__timeout = None
# It's easy to forget the @gen_test decorator, but if you do
@ -152,12 +164,15 @@ class AsyncTestCase(unittest.TestCase):
# make sure it's not an undecorated generator.
setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName)))
def setUp(self):
# Not used in this class itself, but used by @gen_test
self._test_generator = None # type: Optional[Union[Generator, Coroutine]]
def setUp(self) -> None:
super(AsyncTestCase, self).setUp()
self.io_loop = self.get_new_ioloop()
self.io_loop.make_current()
def tearDown(self):
def tearDown(self) -> None:
# Clean up Subprocess, so it can be used again with a new ioloop.
Subprocess.uninitialize()
self.io_loop.clear_current()
@ -174,7 +189,7 @@ class AsyncTestCase(unittest.TestCase):
# unittest machinery understands.
self.__rethrow()
def get_new_ioloop(self):
def get_new_ioloop(self) -> IOLoop:
"""Returns the `.IOLoop` to use for this test.
By default, a new `.IOLoop` is created for each test.
@ -187,7 +202,7 @@ class AsyncTestCase(unittest.TestCase):
"""
return IOLoop()
def _handle_exception(self, typ, value, tb):
def _handle_exception(self, typ: Type[Exception], value: Exception, tb: TracebackType) -> bool:
if self.__failure is None:
self.__failure = (typ, value, tb)
else:
@ -196,21 +211,22 @@ class AsyncTestCase(unittest.TestCase):
self.stop()
return True
def __rethrow(self):
def __rethrow(self) -> None:
if self.__failure is not None:
failure = self.__failure
self.__failure = None
raise_exc_info(failure)
def run(self, result=None):
super(AsyncTestCase, self).run(result)
def run(self, result: unittest.TestResult=None) -> unittest.TestCase:
ret = super(AsyncTestCase, self).run(result)
# As a last resort, if an exception escaped super.run() and wasn't
# re-raised in tearDown, raise it here. This will cause the
# unittest run to fail messily, but that's better than silently
# ignoring an error.
self.__rethrow()
return ret
def stop(self, _arg=None, **kwargs):
def stop(self, _arg: Any=None, **kwargs: Any) -> None:
"""Stops the `.IOLoop`, causing one pending (or future) call to `wait()`
to return.
@ -228,7 +244,7 @@ class AsyncTestCase(unittest.TestCase):
self.__running = False
self.__stopped = True
def wait(self, condition=None, timeout=None):
def wait(self, condition: Callable[..., bool]=None, timeout: float=None) -> None:
"""Runs the `.IOLoop` until stop is called or timeout has passed.
In the event of a timeout, an exception will be thrown. The
@ -251,7 +267,7 @@ class AsyncTestCase(unittest.TestCase):
if not self.__stopped:
if timeout:
def timeout_func():
def timeout_func() -> None:
try:
raise self.failureException(
'Async operation timed out after %s seconds' %
@ -310,7 +326,7 @@ class AsyncHTTPTestCase(AsyncTestCase):
to do other asynchronous operations in tests, you'll probably need to use
``stop()`` and ``wait()`` yourself.
"""
def setUp(self):
def setUp(self) -> None:
super(AsyncHTTPTestCase, self).setUp()
sock, port = bind_unused_port()
self.__port = port
@ -320,19 +336,19 @@ class AsyncHTTPTestCase(AsyncTestCase):
self.http_server = self.get_http_server()
self.http_server.add_sockets([sock])
def get_http_client(self):
def get_http_client(self) -> AsyncHTTPClient:
return AsyncHTTPClient()
def get_http_server(self):
def get_http_server(self) -> HTTPServer:
return HTTPServer(self._app, **self.get_httpserver_options())
def get_app(self):
def get_app(self) -> Application:
"""Should be overridden by subclasses to return a
`tornado.web.Application` or other `.HTTPServer` callback.
"""
raise NotImplementedError()
def fetch(self, path, raise_error=False, **kwargs):
def fetch(self, path: str, raise_error: bool=False, **kwargs: Any) -> HTTPResponse:
"""Convenience method to synchronously fetch a URL.
The given path will be appended to the local server's host and
@ -374,28 +390,28 @@ class AsyncHTTPTestCase(AsyncTestCase):
lambda: self.http_client.fetch(url, raise_error=raise_error, **kwargs),
timeout=get_async_test_timeout())
def get_httpserver_options(self):
def get_httpserver_options(self) -> Dict[str, Any]:
"""May be overridden by subclasses to return additional
keyword arguments for the server.
"""
return {}
def get_http_port(self):
def get_http_port(self) -> int:
"""Returns the port used by the server.
A new port is chosen for each test.
"""
return self.__port
def get_protocol(self):
def get_protocol(self) -> str:
return 'http'
def get_url(self, path):
def get_url(self, path: str) -> str:
"""Returns an absolute url for the given path on the test server."""
return '%s://127.0.0.1:%s%s' % (self.get_protocol(),
self.get_http_port(), path)
def tearDown(self):
def tearDown(self) -> None:
self.http_server.stop()
self.io_loop.run_sync(self.http_server.close_all_connections,
timeout=get_async_test_timeout())
@ -408,14 +424,14 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
Interface is generally the same as `AsyncHTTPTestCase`.
"""
def get_http_client(self):
def get_http_client(self) -> AsyncHTTPClient:
return AsyncHTTPClient(force_instance=True,
defaults=dict(validate_cert=False))
def get_httpserver_options(self):
def get_httpserver_options(self) -> Dict[str, Any]:
return dict(ssl_options=self.get_ssl_options())
def get_ssl_options(self):
def get_ssl_options(self) -> Dict[str, Any]:
"""May be overridden by subclasses to select SSL options.
By default includes a self-signed testing certificate.
@ -428,11 +444,25 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
certfile=os.path.join(module_dir, 'test', 'test.crt'),
keyfile=os.path.join(module_dir, 'test', 'test.key'))
def get_protocol(self):
def get_protocol(self) -> str:
return 'https'
def gen_test(func=None, timeout=None):
@typing.overload
def gen_test(*, timeout: float=None) -> Callable[[Callable[..., Union[Generator, Coroutine]]],
Callable[..., None]]:
pass
@typing.overload # noqa: F811
def gen_test(func: Callable[..., Union[Generator, Coroutine]]) -> Callable[..., None]:
pass
def gen_test( # noqa: F811
func: Callable[..., Union[Generator, Coroutine]]=None, timeout: float=None,
) -> Union[Callable[..., None],
Callable[[Callable[..., Union[Generator, Coroutine]]], Callable[..., None]]]:
"""Testing equivalent of ``@gen.coroutine``, to be applied to test methods.
``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not
@ -471,7 +501,7 @@ def gen_test(func=None, timeout=None):
if timeout is None:
timeout = get_async_test_timeout()
def wrap(f):
def wrap(f: Callable[..., Union[Generator, Coroutine]]) -> Callable[..., None]:
# Stack up several decorators to allow us to access the generator
# object itself. In the innermost wrapper, we capture the generator
# and save it in an attribute of self. Next, we run the wrapped
@ -482,6 +512,8 @@ def gen_test(func=None, timeout=None):
# extensibility in the gen decorators or cancellation support.
@functools.wraps(f)
def pre_coroutine(self, *args, **kwargs):
# type: (AsyncTestCase, *Any, **Any) -> Union[Generator, Coroutine]
# Type comments used to avoid pypy3 bug.
result = f(self, *args, **kwargs)
if isinstance(result, Generator) or inspect.iscoroutine(result):
self._test_generator = result
@ -496,6 +528,7 @@ def gen_test(func=None, timeout=None):
@functools.wraps(coro)
def post_coroutine(self, *args, **kwargs):
# type: (AsyncTestCase, *Any, **Any) -> None
try:
return self.io_loop.run_sync(
functools.partial(coro, self, *args, **kwargs),
@ -507,8 +540,9 @@ def gen_test(func=None, timeout=None):
# point where the test is stopped. The only reason the generator
# would not be running would be if it were cancelled, which means
# a native coroutine, so we can rely on the cr_running attribute.
if getattr(self._test_generator, 'cr_running', True):
self._test_generator.throw(e)
if (self._test_generator is not None and
getattr(self._test_generator, 'cr_running', True)):
self._test_generator.throw(type(e), e)
# In case the test contains an overly broad except
# clause, we may get back here.
# Coroutine was stopped or didn't raise a useful stack trace,
@ -549,7 +583,8 @@ class ExpectLog(logging.Filter):
.. versionchanged:: 4.3
Added the ``logged_stack`` attribute.
"""
def __init__(self, logger, regex, required=True):
def __init__(self, logger: Union[logging.Logger, basestring_type], regex: str,
required: bool=True) -> None:
"""Constructs an ExpectLog context manager.
:param logger: Logger object (or name of logger) to watch. Pass
@ -567,7 +602,7 @@ class ExpectLog(logging.Filter):
self.matched = False
self.logged_stack = False
def filter(self, record):
def filter(self, record: logging.LogRecord) -> bool:
if record.exc_info:
self.logged_stack = True
message = record.getMessage()
@ -576,17 +611,18 @@ class ExpectLog(logging.Filter):
return False
return True
def __enter__(self):
def __enter__(self) -> logging.Filter:
self.logger.addFilter(self)
return self
def __exit__(self, typ, value, tb):
def __exit__(self, typ: Optional[Type[BaseException]], value: Optional[BaseException],
tb: Optional[TracebackType]) -> None:
self.logger.removeFilter(self)
if not typ and self.required and not self.matched:
raise Exception("did not get expected log message")
def main(**kwargs):
def main(**kwargs: Any) -> None:
"""A simple test runner.
This test runner is essentially equivalent to `unittest.main` from
@ -667,7 +703,7 @@ def main(**kwargs):
# test discovery, which is incompatible with auto2to3), so don't
# set module if we're not asking for a specific test.
if len(argv) > 1:
unittest.main(module=None, argv=argv, **kwargs)
unittest.main(module=None, argv=argv, **kwargs) # type: ignore
else:
unittest.main(defaultTest="all", argv=argv, **kwargs)