testing: Type-annotate the module
This commit is contained in:
parent
ceca6f3120
commit
90a7cedac4
|
@ -19,6 +19,9 @@ disallow_untyped_defs = True
|
|||
[mypy-tornado.gen]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tornado.testing]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
# It's generally too tedious to require type annotations in tests, but
|
||||
# we do want to type check them as much as type inference allows.
|
||||
[mypy-tornado.test.util_test]
|
||||
|
@ -35,3 +38,6 @@ check_untyped_defs = True
|
|||
|
||||
[mypy-tornado.test.gen_test]
|
||||
check_untyped_defs = True
|
||||
|
||||
[mypy-tornado.test.testing_test]
|
||||
check_untyped_defs = True
|
||||
|
|
|
@ -21,7 +21,7 @@ import sys
|
|||
import unittest
|
||||
|
||||
from tornado import gen
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.httpclient import AsyncHTTPClient, HTTPResponse
|
||||
from tornado.httpserver import HTTPServer
|
||||
from tornado.ioloop import IOLoop, TimeoutError
|
||||
from tornado import netutil
|
||||
|
@ -29,12 +29,21 @@ from tornado.platform.asyncio import AsyncIOMainLoop
|
|||
from tornado.process import Subprocess
|
||||
from tornado.log import app_log
|
||||
from tornado.util import raise_exc_info, basestring_type
|
||||
from tornado.web import Application
|
||||
|
||||
import typing
|
||||
from typing import Tuple, Any, Callable, Type, Dict, Union, Coroutine, Optional
|
||||
from types import TracebackType
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
_ExcInfoTuple = Tuple[Optional[Type[BaseException]], Optional[BaseException],
|
||||
Optional[TracebackType]]
|
||||
|
||||
|
||||
_NON_OWNED_IOLOOPS = AsyncIOMainLoop
|
||||
|
||||
|
||||
def bind_unused_port(reuse_port=False):
|
||||
def bind_unused_port(reuse_port: bool=False) -> Tuple[socket.socket, int]:
|
||||
"""Binds a server socket to an available port on localhost.
|
||||
|
||||
Returns a tuple (socket, port).
|
||||
|
@ -49,17 +58,20 @@ def bind_unused_port(reuse_port=False):
|
|||
return sock, port
|
||||
|
||||
|
||||
def get_async_test_timeout():
|
||||
def get_async_test_timeout() -> float:
|
||||
"""Get the global timeout setting for async tests.
|
||||
|
||||
Returns a float, the timeout in seconds.
|
||||
|
||||
.. versionadded:: 3.1
|
||||
"""
|
||||
try:
|
||||
return float(os.environ.get('ASYNC_TEST_TIMEOUT'))
|
||||
except (ValueError, TypeError):
|
||||
return 5
|
||||
env = os.environ.get('ASYNC_TEST_TIMEOUT')
|
||||
if env is not None:
|
||||
try:
|
||||
return float(env)
|
||||
except ValueError:
|
||||
pass
|
||||
return 5
|
||||
|
||||
|
||||
class _TestMethodWrapper(object):
|
||||
|
@ -71,10 +83,10 @@ class _TestMethodWrapper(object):
|
|||
necessarily errors, but we alert anyway since there is no good
|
||||
reason to return a value from a test).
|
||||
"""
|
||||
def __init__(self, orig_method):
|
||||
def __init__(self, orig_method: Callable) -> None:
|
||||
self.orig_method = orig_method
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> None:
|
||||
result = self.orig_method(*args, **kwargs)
|
||||
if isinstance(result, Generator) or inspect.iscoroutine(result):
|
||||
raise TypeError("Generator and coroutine test methods should be"
|
||||
|
@ -83,7 +95,7 @@ class _TestMethodWrapper(object):
|
|||
raise ValueError("Return value from test method ignored: %r" %
|
||||
result)
|
||||
|
||||
def __getattr__(self, name):
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Proxy all unknown attributes to the original method.
|
||||
|
||||
This is important for some of the decorators in the `unittest`
|
||||
|
@ -138,12 +150,12 @@ class AsyncTestCase(unittest.TestCase):
|
|||
# Test contents of response
|
||||
self.assertIn("FriendFeed", response.body)
|
||||
"""
|
||||
def __init__(self, methodName='runTest'):
|
||||
def __init__(self, methodName: str='runTest') -> None:
|
||||
super(AsyncTestCase, self).__init__(methodName)
|
||||
self.__stopped = False
|
||||
self.__running = False
|
||||
self.__failure = None
|
||||
self.__stop_args = None
|
||||
self.__failure = None # type: Optional[_ExcInfoTuple]
|
||||
self.__stop_args = None # type: Any
|
||||
self.__timeout = None
|
||||
|
||||
# It's easy to forget the @gen_test decorator, but if you do
|
||||
|
@ -152,12 +164,15 @@ class AsyncTestCase(unittest.TestCase):
|
|||
# make sure it's not an undecorated generator.
|
||||
setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName)))
|
||||
|
||||
def setUp(self):
|
||||
# Not used in this class itself, but used by @gen_test
|
||||
self._test_generator = None # type: Optional[Union[Generator, Coroutine]]
|
||||
|
||||
def setUp(self) -> None:
|
||||
super(AsyncTestCase, self).setUp()
|
||||
self.io_loop = self.get_new_ioloop()
|
||||
self.io_loop.make_current()
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
# Clean up Subprocess, so it can be used again with a new ioloop.
|
||||
Subprocess.uninitialize()
|
||||
self.io_loop.clear_current()
|
||||
|
@ -174,7 +189,7 @@ class AsyncTestCase(unittest.TestCase):
|
|||
# unittest machinery understands.
|
||||
self.__rethrow()
|
||||
|
||||
def get_new_ioloop(self):
|
||||
def get_new_ioloop(self) -> IOLoop:
|
||||
"""Returns the `.IOLoop` to use for this test.
|
||||
|
||||
By default, a new `.IOLoop` is created for each test.
|
||||
|
@ -187,7 +202,7 @@ class AsyncTestCase(unittest.TestCase):
|
|||
"""
|
||||
return IOLoop()
|
||||
|
||||
def _handle_exception(self, typ, value, tb):
|
||||
def _handle_exception(self, typ: Type[Exception], value: Exception, tb: TracebackType) -> bool:
|
||||
if self.__failure is None:
|
||||
self.__failure = (typ, value, tb)
|
||||
else:
|
||||
|
@ -196,21 +211,22 @@ class AsyncTestCase(unittest.TestCase):
|
|||
self.stop()
|
||||
return True
|
||||
|
||||
def __rethrow(self):
|
||||
def __rethrow(self) -> None:
|
||||
if self.__failure is not None:
|
||||
failure = self.__failure
|
||||
self.__failure = None
|
||||
raise_exc_info(failure)
|
||||
|
||||
def run(self, result=None):
|
||||
super(AsyncTestCase, self).run(result)
|
||||
def run(self, result: unittest.TestResult=None) -> unittest.TestCase:
|
||||
ret = super(AsyncTestCase, self).run(result)
|
||||
# As a last resort, if an exception escaped super.run() and wasn't
|
||||
# re-raised in tearDown, raise it here. This will cause the
|
||||
# unittest run to fail messily, but that's better than silently
|
||||
# ignoring an error.
|
||||
self.__rethrow()
|
||||
return ret
|
||||
|
||||
def stop(self, _arg=None, **kwargs):
|
||||
def stop(self, _arg: Any=None, **kwargs: Any) -> None:
|
||||
"""Stops the `.IOLoop`, causing one pending (or future) call to `wait()`
|
||||
to return.
|
||||
|
||||
|
@ -228,7 +244,7 @@ class AsyncTestCase(unittest.TestCase):
|
|||
self.__running = False
|
||||
self.__stopped = True
|
||||
|
||||
def wait(self, condition=None, timeout=None):
|
||||
def wait(self, condition: Callable[..., bool]=None, timeout: float=None) -> None:
|
||||
"""Runs the `.IOLoop` until stop is called or timeout has passed.
|
||||
|
||||
In the event of a timeout, an exception will be thrown. The
|
||||
|
@ -251,7 +267,7 @@ class AsyncTestCase(unittest.TestCase):
|
|||
|
||||
if not self.__stopped:
|
||||
if timeout:
|
||||
def timeout_func():
|
||||
def timeout_func() -> None:
|
||||
try:
|
||||
raise self.failureException(
|
||||
'Async operation timed out after %s seconds' %
|
||||
|
@ -310,7 +326,7 @@ class AsyncHTTPTestCase(AsyncTestCase):
|
|||
to do other asynchronous operations in tests, you'll probably need to use
|
||||
``stop()`` and ``wait()`` yourself.
|
||||
"""
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
super(AsyncHTTPTestCase, self).setUp()
|
||||
sock, port = bind_unused_port()
|
||||
self.__port = port
|
||||
|
@ -320,19 +336,19 @@ class AsyncHTTPTestCase(AsyncTestCase):
|
|||
self.http_server = self.get_http_server()
|
||||
self.http_server.add_sockets([sock])
|
||||
|
||||
def get_http_client(self):
|
||||
def get_http_client(self) -> AsyncHTTPClient:
|
||||
return AsyncHTTPClient()
|
||||
|
||||
def get_http_server(self):
|
||||
def get_http_server(self) -> HTTPServer:
|
||||
return HTTPServer(self._app, **self.get_httpserver_options())
|
||||
|
||||
def get_app(self):
|
||||
def get_app(self) -> Application:
|
||||
"""Should be overridden by subclasses to return a
|
||||
`tornado.web.Application` or other `.HTTPServer` callback.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def fetch(self, path, raise_error=False, **kwargs):
|
||||
def fetch(self, path: str, raise_error: bool=False, **kwargs: Any) -> HTTPResponse:
|
||||
"""Convenience method to synchronously fetch a URL.
|
||||
|
||||
The given path will be appended to the local server's host and
|
||||
|
@ -374,28 +390,28 @@ class AsyncHTTPTestCase(AsyncTestCase):
|
|||
lambda: self.http_client.fetch(url, raise_error=raise_error, **kwargs),
|
||||
timeout=get_async_test_timeout())
|
||||
|
||||
def get_httpserver_options(self):
|
||||
def get_httpserver_options(self) -> Dict[str, Any]:
|
||||
"""May be overridden by subclasses to return additional
|
||||
keyword arguments for the server.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_http_port(self):
|
||||
def get_http_port(self) -> int:
|
||||
"""Returns the port used by the server.
|
||||
|
||||
A new port is chosen for each test.
|
||||
"""
|
||||
return self.__port
|
||||
|
||||
def get_protocol(self):
|
||||
def get_protocol(self) -> str:
|
||||
return 'http'
|
||||
|
||||
def get_url(self, path):
|
||||
def get_url(self, path: str) -> str:
|
||||
"""Returns an absolute url for the given path on the test server."""
|
||||
return '%s://127.0.0.1:%s%s' % (self.get_protocol(),
|
||||
self.get_http_port(), path)
|
||||
|
||||
def tearDown(self):
|
||||
def tearDown(self) -> None:
|
||||
self.http_server.stop()
|
||||
self.io_loop.run_sync(self.http_server.close_all_connections,
|
||||
timeout=get_async_test_timeout())
|
||||
|
@ -408,14 +424,14 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
|
|||
|
||||
Interface is generally the same as `AsyncHTTPTestCase`.
|
||||
"""
|
||||
def get_http_client(self):
|
||||
def get_http_client(self) -> AsyncHTTPClient:
|
||||
return AsyncHTTPClient(force_instance=True,
|
||||
defaults=dict(validate_cert=False))
|
||||
|
||||
def get_httpserver_options(self):
|
||||
def get_httpserver_options(self) -> Dict[str, Any]:
|
||||
return dict(ssl_options=self.get_ssl_options())
|
||||
|
||||
def get_ssl_options(self):
|
||||
def get_ssl_options(self) -> Dict[str, Any]:
|
||||
"""May be overridden by subclasses to select SSL options.
|
||||
|
||||
By default includes a self-signed testing certificate.
|
||||
|
@ -428,11 +444,25 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase):
|
|||
certfile=os.path.join(module_dir, 'test', 'test.crt'),
|
||||
keyfile=os.path.join(module_dir, 'test', 'test.key'))
|
||||
|
||||
def get_protocol(self):
|
||||
def get_protocol(self) -> str:
|
||||
return 'https'
|
||||
|
||||
|
||||
def gen_test(func=None, timeout=None):
|
||||
@typing.overload
|
||||
def gen_test(*, timeout: float=None) -> Callable[[Callable[..., Union[Generator, Coroutine]]],
|
||||
Callable[..., None]]:
|
||||
pass
|
||||
|
||||
|
||||
@typing.overload # noqa: F811
|
||||
def gen_test(func: Callable[..., Union[Generator, Coroutine]]) -> Callable[..., None]:
|
||||
pass
|
||||
|
||||
|
||||
def gen_test( # noqa: F811
|
||||
func: Callable[..., Union[Generator, Coroutine]]=None, timeout: float=None,
|
||||
) -> Union[Callable[..., None],
|
||||
Callable[[Callable[..., Union[Generator, Coroutine]]], Callable[..., None]]]:
|
||||
"""Testing equivalent of ``@gen.coroutine``, to be applied to test methods.
|
||||
|
||||
``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not
|
||||
|
@ -471,7 +501,7 @@ def gen_test(func=None, timeout=None):
|
|||
if timeout is None:
|
||||
timeout = get_async_test_timeout()
|
||||
|
||||
def wrap(f):
|
||||
def wrap(f: Callable[..., Union[Generator, Coroutine]]) -> Callable[..., None]:
|
||||
# Stack up several decorators to allow us to access the generator
|
||||
# object itself. In the innermost wrapper, we capture the generator
|
||||
# and save it in an attribute of self. Next, we run the wrapped
|
||||
|
@ -482,6 +512,8 @@ def gen_test(func=None, timeout=None):
|
|||
# extensibility in the gen decorators or cancellation support.
|
||||
@functools.wraps(f)
|
||||
def pre_coroutine(self, *args, **kwargs):
|
||||
# type: (AsyncTestCase, *Any, **Any) -> Union[Generator, Coroutine]
|
||||
# Type comments used to avoid pypy3 bug.
|
||||
result = f(self, *args, **kwargs)
|
||||
if isinstance(result, Generator) or inspect.iscoroutine(result):
|
||||
self._test_generator = result
|
||||
|
@ -496,6 +528,7 @@ def gen_test(func=None, timeout=None):
|
|||
|
||||
@functools.wraps(coro)
|
||||
def post_coroutine(self, *args, **kwargs):
|
||||
# type: (AsyncTestCase, *Any, **Any) -> None
|
||||
try:
|
||||
return self.io_loop.run_sync(
|
||||
functools.partial(coro, self, *args, **kwargs),
|
||||
|
@ -507,8 +540,9 @@ def gen_test(func=None, timeout=None):
|
|||
# point where the test is stopped. The only reason the generator
|
||||
# would not be running would be if it were cancelled, which means
|
||||
# a native coroutine, so we can rely on the cr_running attribute.
|
||||
if getattr(self._test_generator, 'cr_running', True):
|
||||
self._test_generator.throw(e)
|
||||
if (self._test_generator is not None and
|
||||
getattr(self._test_generator, 'cr_running', True)):
|
||||
self._test_generator.throw(type(e), e)
|
||||
# In case the test contains an overly broad except
|
||||
# clause, we may get back here.
|
||||
# Coroutine was stopped or didn't raise a useful stack trace,
|
||||
|
@ -549,7 +583,8 @@ class ExpectLog(logging.Filter):
|
|||
.. versionchanged:: 4.3
|
||||
Added the ``logged_stack`` attribute.
|
||||
"""
|
||||
def __init__(self, logger, regex, required=True):
|
||||
def __init__(self, logger: Union[logging.Logger, basestring_type], regex: str,
|
||||
required: bool=True) -> None:
|
||||
"""Constructs an ExpectLog context manager.
|
||||
|
||||
:param logger: Logger object (or name of logger) to watch. Pass
|
||||
|
@ -567,7 +602,7 @@ class ExpectLog(logging.Filter):
|
|||
self.matched = False
|
||||
self.logged_stack = False
|
||||
|
||||
def filter(self, record):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.exc_info:
|
||||
self.logged_stack = True
|
||||
message = record.getMessage()
|
||||
|
@ -576,17 +611,18 @@ class ExpectLog(logging.Filter):
|
|||
return False
|
||||
return True
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> logging.Filter:
|
||||
self.logger.addFilter(self)
|
||||
return self
|
||||
|
||||
def __exit__(self, typ, value, tb):
|
||||
def __exit__(self, typ: Optional[Type[BaseException]], value: Optional[BaseException],
|
||||
tb: Optional[TracebackType]) -> None:
|
||||
self.logger.removeFilter(self)
|
||||
if not typ and self.required and not self.matched:
|
||||
raise Exception("did not get expected log message")
|
||||
|
||||
|
||||
def main(**kwargs):
|
||||
def main(**kwargs: Any) -> None:
|
||||
"""A simple test runner.
|
||||
|
||||
This test runner is essentially equivalent to `unittest.main` from
|
||||
|
@ -667,7 +703,7 @@ def main(**kwargs):
|
|||
# test discovery, which is incompatible with auto2to3), so don't
|
||||
# set module if we're not asking for a specific test.
|
||||
if len(argv) > 1:
|
||||
unittest.main(module=None, argv=argv, **kwargs)
|
||||
unittest.main(module=None, argv=argv, **kwargs) # type: ignore
|
||||
else:
|
||||
unittest.main(defaultTest="all", argv=argv, **kwargs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue