Fix error handling with the combination of @asynchronous and @gen.coroutine.

Also make self.finish() optional for coroutines since we can finish
if the future resolves successfully.
This commit is contained in:
Ben Darnell 2013-03-23 21:25:14 -04:00
parent 8ddc760f18
commit e00e4c61fe
2 changed files with 80 additions and 1 deletions

View File

@ -647,6 +647,41 @@ class GenSequenceHandler(RequestHandler):
self.finish("3") self.finish("3")
class GenCoroutineSequenceHandler(RequestHandler):
@asynchronous
@gen.coroutine
def get(self):
self.io_loop = self.request.connection.stream.io_loop
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.write("1")
self.io_loop.add_callback((yield gen.Callback("k2")))
yield gen.Wait("k2")
self.write("2")
# reuse an old key
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.finish("3")
class GenCoroutineUnfinishedSequenceHandler(RequestHandler):
@asynchronous
@gen.coroutine
def get(self):
self.io_loop = self.request.connection.stream.io_loop
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
self.write("1")
self.io_loop.add_callback((yield gen.Callback("k2")))
yield gen.Wait("k2")
self.write("2")
# reuse an old key
self.io_loop.add_callback((yield gen.Callback("k1")))
yield gen.Wait("k1")
# just write, don't finish
self.write("3")
class GenTaskHandler(RequestHandler): class GenTaskHandler(RequestHandler):
@asynchronous @asynchronous
@gen.engine @gen.engine
@ -668,6 +703,16 @@ class GenExceptionHandler(RequestHandler):
raise Exception("oops") raise Exception("oops")
class GenCoroutineExceptionHandler(RequestHandler):
@asynchronous
@gen.coroutine
def get(self):
# This test depends on the order of the two decorators.
io_loop = self.request.connection.stream.io_loop
yield gen.Task(io_loop.add_callback)
raise Exception("oops")
class GenYieldExceptionHandler(RequestHandler): class GenYieldExceptionHandler(RequestHandler):
@asynchronous @asynchronous
@gen.engine @gen.engine
@ -688,8 +733,12 @@ class GenWebTest(AsyncHTTPTestCase):
def get_app(self): def get_app(self):
return Application([ return Application([
('/sequence', GenSequenceHandler), ('/sequence', GenSequenceHandler),
('/coroutine_sequence', GenCoroutineSequenceHandler),
('/coroutine_unfinished_sequence',
GenCoroutineUnfinishedSequenceHandler),
('/task', GenTaskHandler), ('/task', GenTaskHandler),
('/exception', GenExceptionHandler), ('/exception', GenExceptionHandler),
('/coroutine_exception', GenCoroutineExceptionHandler),
('/yield_exception', GenYieldExceptionHandler), ('/yield_exception', GenYieldExceptionHandler),
]) ])
@ -697,6 +746,14 @@ class GenWebTest(AsyncHTTPTestCase):
response = self.fetch('/sequence') response = self.fetch('/sequence')
self.assertEqual(response.body, b"123") self.assertEqual(response.body, b"123")
def test_coroutine_sequence_handler(self):
response = self.fetch('/coroutine_sequence')
self.assertEqual(response.body, b"123")
def test_coroutine_unfinished_sequence_handler(self):
response = self.fetch('/coroutine_unfinished_sequence')
self.assertEqual(response.body, b"123")
def test_task_handler(self): def test_task_handler(self):
response = self.fetch('/task?url=%s' % url_escape(self.get_url('/sequence'))) response = self.fetch('/task?url=%s' % url_escape(self.get_url('/sequence')))
self.assertEqual(response.body, b"got response: 123") self.assertEqual(response.body, b"got response: 123")
@ -707,6 +764,12 @@ class GenWebTest(AsyncHTTPTestCase):
response = self.fetch('/exception') response = self.fetch('/exception')
self.assertEqual(500, response.code) self.assertEqual(500, response.code)
def test_coroutine_exception_handler(self):
# Make sure we get an error and not a timeout
with ExpectLog(app_log, "Uncaught exception GET /coroutine_exception"):
response = self.fetch('/coroutine_exception')
self.assertEqual(500, response.code)
def test_yield_exception_handler(self): def test_yield_exception_handler(self):
response = self.fetch('/yield_exception') response = self.fetch('/yield_exception')
self.assertEqual(response.body, b'ok') self.assertEqual(response.body, b'ok')

View File

@ -73,6 +73,7 @@ import traceback
import types import types
import uuid import uuid
from tornado.concurrent import Future
from tornado import escape from tornado import escape
from tornado import httputil from tornado import httputil
from tornado import locale from tornado import locale
@ -1165,6 +1166,8 @@ def asynchronous(method):
self.finish() self.finish()
""" """
# Delay the IOLoop import because it's not available on app engine.
from tornado.ioloop import IOLoop
@functools.wraps(method) @functools.wraps(method)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
if self.application._wsgi: if self.application._wsgi:
@ -1172,7 +1175,20 @@ def asynchronous(method):
self._auto_finish = False self._auto_finish = False
with stack_context.ExceptionStackContext( with stack_context.ExceptionStackContext(
self._stack_context_handle_exception): self._stack_context_handle_exception):
return method(self, *args, **kwargs) result = method(self, *args, **kwargs)
if isinstance(result, Future):
# If @asynchronous is used with @gen.coroutine, (but
# not @gen.engine), we can automatically finish the
# request when the future resolves. Additionally,
# the Future will swallow any exceptions so we need
# to throw them back out to the stack context to finish
# the request.
def future_complete(f):
f.result()
if not self._finished:
self.finish()
IOLoop.current().add_future(result, future_complete)
return result
return wrapper return wrapper