From 90a7cedac49c85699aa3dfbc0cc4f03279c0dc27 Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sat, 11 Aug 2018 16:40:07 -0400 Subject: [PATCH] testing: Type-annotate the module --- setup.cfg | 6 +++ tornado/testing.py | 130 +++++++++++++++++++++++++++++---------------- 2 files changed, 89 insertions(+), 47 deletions(-) diff --git a/setup.cfg b/setup.cfg index 26d19f2f..13a4ce3f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,6 +19,9 @@ disallow_untyped_defs = True [mypy-tornado.gen] disallow_untyped_defs = True +[mypy-tornado.testing] +disallow_untyped_defs = True + # It's generally too tedious to require type annotations in tests, but # we do want to type check them as much as type inference allows. [mypy-tornado.test.util_test] @@ -35,3 +38,6 @@ check_untyped_defs = True [mypy-tornado.test.gen_test] check_untyped_defs = True + +[mypy-tornado.test.testing_test] +check_untyped_defs = True diff --git a/tornado/testing.py b/tornado/testing.py index 4521ea9c..40a6e759 100644 --- a/tornado/testing.py +++ b/tornado/testing.py @@ -21,7 +21,7 @@ import sys import unittest from tornado import gen -from tornado.httpclient import AsyncHTTPClient +from tornado.httpclient import AsyncHTTPClient, HTTPResponse from tornado.httpserver import HTTPServer from tornado.ioloop import IOLoop, TimeoutError from tornado import netutil @@ -29,12 +29,21 @@ from tornado.platform.asyncio import AsyncIOMainLoop from tornado.process import Subprocess from tornado.log import app_log from tornado.util import raise_exc_info, basestring_type +from tornado.web import Application + +import typing +from typing import Tuple, Any, Callable, Type, Dict, Union, Coroutine, Optional +from types import TracebackType + +if typing.TYPE_CHECKING: + _ExcInfoTuple = Tuple[Optional[Type[BaseException]], Optional[BaseException], + Optional[TracebackType]] _NON_OWNED_IOLOOPS = AsyncIOMainLoop -def bind_unused_port(reuse_port=False): +def bind_unused_port(reuse_port: bool=False) -> Tuple[socket.socket, int]: """Binds a server socket to an available port on localhost. Returns a tuple (socket, port). @@ -49,17 +58,20 @@ def bind_unused_port(reuse_port=False): return sock, port -def get_async_test_timeout(): +def get_async_test_timeout() -> float: """Get the global timeout setting for async tests. Returns a float, the timeout in seconds. .. versionadded:: 3.1 """ - try: - return float(os.environ.get('ASYNC_TEST_TIMEOUT')) - except (ValueError, TypeError): - return 5 + env = os.environ.get('ASYNC_TEST_TIMEOUT') + if env is not None: + try: + return float(env) + except ValueError: + pass + return 5 class _TestMethodWrapper(object): @@ -71,10 +83,10 @@ class _TestMethodWrapper(object): necessarily errors, but we alert anyway since there is no good reason to return a value from a test). """ - def __init__(self, orig_method): + def __init__(self, orig_method: Callable) -> None: self.orig_method = orig_method - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> None: result = self.orig_method(*args, **kwargs) if isinstance(result, Generator) or inspect.iscoroutine(result): raise TypeError("Generator and coroutine test methods should be" @@ -83,7 +95,7 @@ class _TestMethodWrapper(object): raise ValueError("Return value from test method ignored: %r" % result) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """Proxy all unknown attributes to the original method. This is important for some of the decorators in the `unittest` @@ -138,12 +150,12 @@ class AsyncTestCase(unittest.TestCase): # Test contents of response self.assertIn("FriendFeed", response.body) """ - def __init__(self, methodName='runTest'): + def __init__(self, methodName: str='runTest') -> None: super(AsyncTestCase, self).__init__(methodName) self.__stopped = False self.__running = False - self.__failure = None - self.__stop_args = None + self.__failure = None # type: Optional[_ExcInfoTuple] + self.__stop_args = None # type: Any self.__timeout = None # It's easy to forget the @gen_test decorator, but if you do @@ -152,12 +164,15 @@ class AsyncTestCase(unittest.TestCase): # make sure it's not an undecorated generator. setattr(self, methodName, _TestMethodWrapper(getattr(self, methodName))) - def setUp(self): + # Not used in this class itself, but used by @gen_test + self._test_generator = None # type: Optional[Union[Generator, Coroutine]] + + def setUp(self) -> None: super(AsyncTestCase, self).setUp() self.io_loop = self.get_new_ioloop() self.io_loop.make_current() - def tearDown(self): + def tearDown(self) -> None: # Clean up Subprocess, so it can be used again with a new ioloop. Subprocess.uninitialize() self.io_loop.clear_current() @@ -174,7 +189,7 @@ class AsyncTestCase(unittest.TestCase): # unittest machinery understands. self.__rethrow() - def get_new_ioloop(self): + def get_new_ioloop(self) -> IOLoop: """Returns the `.IOLoop` to use for this test. By default, a new `.IOLoop` is created for each test. @@ -187,7 +202,7 @@ class AsyncTestCase(unittest.TestCase): """ return IOLoop() - def _handle_exception(self, typ, value, tb): + def _handle_exception(self, typ: Type[Exception], value: Exception, tb: TracebackType) -> bool: if self.__failure is None: self.__failure = (typ, value, tb) else: @@ -196,21 +211,22 @@ class AsyncTestCase(unittest.TestCase): self.stop() return True - def __rethrow(self): + def __rethrow(self) -> None: if self.__failure is not None: failure = self.__failure self.__failure = None raise_exc_info(failure) - def run(self, result=None): - super(AsyncTestCase, self).run(result) + def run(self, result: unittest.TestResult=None) -> unittest.TestCase: + ret = super(AsyncTestCase, self).run(result) # As a last resort, if an exception escaped super.run() and wasn't # re-raised in tearDown, raise it here. This will cause the # unittest run to fail messily, but that's better than silently # ignoring an error. self.__rethrow() + return ret - def stop(self, _arg=None, **kwargs): + def stop(self, _arg: Any=None, **kwargs: Any) -> None: """Stops the `.IOLoop`, causing one pending (or future) call to `wait()` to return. @@ -228,7 +244,7 @@ class AsyncTestCase(unittest.TestCase): self.__running = False self.__stopped = True - def wait(self, condition=None, timeout=None): + def wait(self, condition: Callable[..., bool]=None, timeout: float=None) -> None: """Runs the `.IOLoop` until stop is called or timeout has passed. In the event of a timeout, an exception will be thrown. The @@ -251,7 +267,7 @@ class AsyncTestCase(unittest.TestCase): if not self.__stopped: if timeout: - def timeout_func(): + def timeout_func() -> None: try: raise self.failureException( 'Async operation timed out after %s seconds' % @@ -310,7 +326,7 @@ class AsyncHTTPTestCase(AsyncTestCase): to do other asynchronous operations in tests, you'll probably need to use ``stop()`` and ``wait()`` yourself. """ - def setUp(self): + def setUp(self) -> None: super(AsyncHTTPTestCase, self).setUp() sock, port = bind_unused_port() self.__port = port @@ -320,19 +336,19 @@ class AsyncHTTPTestCase(AsyncTestCase): self.http_server = self.get_http_server() self.http_server.add_sockets([sock]) - def get_http_client(self): + def get_http_client(self) -> AsyncHTTPClient: return AsyncHTTPClient() - def get_http_server(self): + def get_http_server(self) -> HTTPServer: return HTTPServer(self._app, **self.get_httpserver_options()) - def get_app(self): + def get_app(self) -> Application: """Should be overridden by subclasses to return a `tornado.web.Application` or other `.HTTPServer` callback. """ raise NotImplementedError() - def fetch(self, path, raise_error=False, **kwargs): + def fetch(self, path: str, raise_error: bool=False, **kwargs: Any) -> HTTPResponse: """Convenience method to synchronously fetch a URL. The given path will be appended to the local server's host and @@ -374,28 +390,28 @@ class AsyncHTTPTestCase(AsyncTestCase): lambda: self.http_client.fetch(url, raise_error=raise_error, **kwargs), timeout=get_async_test_timeout()) - def get_httpserver_options(self): + def get_httpserver_options(self) -> Dict[str, Any]: """May be overridden by subclasses to return additional keyword arguments for the server. """ return {} - def get_http_port(self): + def get_http_port(self) -> int: """Returns the port used by the server. A new port is chosen for each test. """ return self.__port - def get_protocol(self): + def get_protocol(self) -> str: return 'http' - def get_url(self, path): + def get_url(self, path: str) -> str: """Returns an absolute url for the given path on the test server.""" return '%s://127.0.0.1:%s%s' % (self.get_protocol(), self.get_http_port(), path) - def tearDown(self): + def tearDown(self) -> None: self.http_server.stop() self.io_loop.run_sync(self.http_server.close_all_connections, timeout=get_async_test_timeout()) @@ -408,14 +424,14 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase): Interface is generally the same as `AsyncHTTPTestCase`. """ - def get_http_client(self): + def get_http_client(self) -> AsyncHTTPClient: return AsyncHTTPClient(force_instance=True, defaults=dict(validate_cert=False)) - def get_httpserver_options(self): + def get_httpserver_options(self) -> Dict[str, Any]: return dict(ssl_options=self.get_ssl_options()) - def get_ssl_options(self): + def get_ssl_options(self) -> Dict[str, Any]: """May be overridden by subclasses to select SSL options. By default includes a self-signed testing certificate. @@ -428,11 +444,25 @@ class AsyncHTTPSTestCase(AsyncHTTPTestCase): certfile=os.path.join(module_dir, 'test', 'test.crt'), keyfile=os.path.join(module_dir, 'test', 'test.key')) - def get_protocol(self): + def get_protocol(self) -> str: return 'https' -def gen_test(func=None, timeout=None): +@typing.overload +def gen_test(*, timeout: float=None) -> Callable[[Callable[..., Union[Generator, Coroutine]]], + Callable[..., None]]: + pass + + +@typing.overload # noqa: F811 +def gen_test(func: Callable[..., Union[Generator, Coroutine]]) -> Callable[..., None]: + pass + + +def gen_test( # noqa: F811 + func: Callable[..., Union[Generator, Coroutine]]=None, timeout: float=None, +) -> Union[Callable[..., None], + Callable[[Callable[..., Union[Generator, Coroutine]]], Callable[..., None]]]: """Testing equivalent of ``@gen.coroutine``, to be applied to test methods. ``@gen.coroutine`` cannot be used on tests because the `.IOLoop` is not @@ -471,7 +501,7 @@ def gen_test(func=None, timeout=None): if timeout is None: timeout = get_async_test_timeout() - def wrap(f): + def wrap(f: Callable[..., Union[Generator, Coroutine]]) -> Callable[..., None]: # Stack up several decorators to allow us to access the generator # object itself. In the innermost wrapper, we capture the generator # and save it in an attribute of self. Next, we run the wrapped @@ -482,6 +512,8 @@ def gen_test(func=None, timeout=None): # extensibility in the gen decorators or cancellation support. @functools.wraps(f) def pre_coroutine(self, *args, **kwargs): + # type: (AsyncTestCase, *Any, **Any) -> Union[Generator, Coroutine] + # Type comments used to avoid pypy3 bug. result = f(self, *args, **kwargs) if isinstance(result, Generator) or inspect.iscoroutine(result): self._test_generator = result @@ -496,6 +528,7 @@ def gen_test(func=None, timeout=None): @functools.wraps(coro) def post_coroutine(self, *args, **kwargs): + # type: (AsyncTestCase, *Any, **Any) -> None try: return self.io_loop.run_sync( functools.partial(coro, self, *args, **kwargs), @@ -507,8 +540,9 @@ def gen_test(func=None, timeout=None): # point where the test is stopped. The only reason the generator # would not be running would be if it were cancelled, which means # a native coroutine, so we can rely on the cr_running attribute. - if getattr(self._test_generator, 'cr_running', True): - self._test_generator.throw(e) + if (self._test_generator is not None and + getattr(self._test_generator, 'cr_running', True)): + self._test_generator.throw(type(e), e) # In case the test contains an overly broad except # clause, we may get back here. # Coroutine was stopped or didn't raise a useful stack trace, @@ -549,7 +583,8 @@ class ExpectLog(logging.Filter): .. versionchanged:: 4.3 Added the ``logged_stack`` attribute. """ - def __init__(self, logger, regex, required=True): + def __init__(self, logger: Union[logging.Logger, basestring_type], regex: str, + required: bool=True) -> None: """Constructs an ExpectLog context manager. :param logger: Logger object (or name of logger) to watch. Pass @@ -567,7 +602,7 @@ class ExpectLog(logging.Filter): self.matched = False self.logged_stack = False - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: if record.exc_info: self.logged_stack = True message = record.getMessage() @@ -576,17 +611,18 @@ class ExpectLog(logging.Filter): return False return True - def __enter__(self): + def __enter__(self) -> logging.Filter: self.logger.addFilter(self) return self - def __exit__(self, typ, value, tb): + def __exit__(self, typ: Optional[Type[BaseException]], value: Optional[BaseException], + tb: Optional[TracebackType]) -> None: self.logger.removeFilter(self) if not typ and self.required and not self.matched: raise Exception("did not get expected log message") -def main(**kwargs): +def main(**kwargs: Any) -> None: """A simple test runner. This test runner is essentially equivalent to `unittest.main` from @@ -667,7 +703,7 @@ def main(**kwargs): # test discovery, which is incompatible with auto2to3), so don't # set module if we're not asking for a specific test. if len(argv) > 1: - unittest.main(module=None, argv=argv, **kwargs) + unittest.main(module=None, argv=argv, **kwargs) # type: ignore else: unittest.main(defaultTest="all", argv=argv, **kwargs)