From a02d4f67b9276f836345c6a7f6f8b8675999818d Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Mon, 12 Oct 2015 13:45:15 -0400 Subject: [PATCH] Update gen_test for native coroutines. --- tornado/test/testing_test.py | 50 ++++++++++++++++++++++++++++++++++-- tornado/testing.py | 24 ++++++++++++----- 2 files changed, 65 insertions(+), 9 deletions(-) diff --git a/tornado/test/testing_test.py b/tornado/test/testing_test.py index ded2569b..e00058ac 100644 --- a/tornado/test/testing_test.py +++ b/tornado/test/testing_test.py @@ -5,11 +5,11 @@ from __future__ import absolute_import, division, print_function, with_statement from tornado import gen, ioloop from tornado.log import app_log from tornado.testing import AsyncTestCase, gen_test, ExpectLog -from tornado.test.util import unittest - +from tornado.test.util import unittest, skipBefore35, exec_test import contextlib import os import traceback +import warnings @contextlib.contextmanager @@ -86,6 +86,26 @@ class AsyncTestCaseWrapperTest(unittest.TestCase): self.assertEqual(len(result.errors), 1) self.assertIn("should be decorated", result.errors[0][1]) + @skipBefore35 + def test_undecorated_coroutine(self): + namespace = exec_test(globals(), locals(), """ + class Test(AsyncTestCase): + async def test_coro(self): + pass + """) + + test_class = namespace['Test'] + test = test_class('test_coro') + result = unittest.TestResult() + + # Silence "RuntimeWarning: coroutine 'test_coro' was never awaited". + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + test.run(result) + + self.assertEqual(len(result.errors), 1) + self.assertIn("should be decorated", result.errors[0][1]) + def test_undecorated_generator_with_skip(self): class Test(AsyncTestCase): @unittest.skip("don't run this") @@ -228,5 +248,31 @@ class GenTest(AsyncTestCase): test_with_kwargs(self, test='test') self.finished = True + @skipBefore35 + def test_native_coroutine(self): + namespace = exec_test(globals(), locals(), """ + @gen_test + async def test(self): + self.finished = True + """) + + namespace['test'](self) + + @skipBefore35 + def test_native_coroutine_timeout(self): + # Set a short timeout and exceed it. + namespace = exec_test(globals(), locals(), """ + @gen_test(timeout=0.1) + async def test(self): + await gen.sleep(1) + """) + + try: + namespace['test'](self) + self.fail("did not get expected exception") + except ioloop.TimeoutError: + self.finished = True + + if __name__ == '__main__': unittest.main() diff --git a/tornado/testing.py b/tornado/testing.py index f5e9f153..54d76fe4 100644 --- a/tornado/testing.py +++ b/tornado/testing.py @@ -34,6 +34,7 @@ from tornado.log import gen_log, app_log from tornado.stack_context import ExceptionStackContext from tornado.util import raise_exc_info, basestring_type import functools +import inspect import logging import os import re @@ -51,6 +52,12 @@ try: except ImportError: from types import GeneratorType +if sys.version_info >= (3, 5): + iscoroutine = inspect.iscoroutine + iscoroutinefunction = inspect.iscoroutinefunction +else: + iscoroutine = iscoroutinefunction = lambda f: False + # Tornado's own test suite requires the updated unittest module # (either py27+ or unittest2) so tornado.test.util enforces # this requirement, but for other users of tornado.testing we want @@ -123,9 +130,9 @@ class _TestMethodWrapper(object): def __call__(self, *args, **kwargs): result = self.orig_method(*args, **kwargs) - if isinstance(result, GeneratorType): - raise TypeError("Generator test methods should be decorated with " - "tornado.testing.gen_test") + if isinstance(result, GeneratorType) or iscoroutine(result): + raise TypeError("Generator and coroutine test methods should be" + " decorated with tornado.testing.gen_test") elif result is not None: raise ValueError("Return value from test method ignored: %r" % result) @@ -499,13 +506,16 @@ def gen_test(func=None, timeout=None): @functools.wraps(f) def pre_coroutine(self, *args, **kwargs): result = f(self, *args, **kwargs) - if isinstance(result, GeneratorType): + if isinstance(result, GeneratorType) or iscoroutine(result): self._test_generator = result else: self._test_generator = None return result - coro = gen.coroutine(pre_coroutine) + if iscoroutinefunction(f): + coro = pre_coroutine + else: + coro = gen.coroutine(pre_coroutine) @functools.wraps(coro) def post_coroutine(self, *args, **kwargs): @@ -515,8 +525,8 @@ def gen_test(func=None, timeout=None): timeout=timeout) except TimeoutError as e: # run_sync raises an error with an unhelpful traceback. - # If we throw it back into the generator the stack trace - # will be replaced by the point where the test is stopped. + # Throw it back into the generator or coroutine so the stack + # trace is replaced by the point where the test is stopped. self._test_generator.throw(e) # In case the test contains an overly broad except clause, # we may get back here. In this case re-raise the original