diff --git a/tornado/gen.py b/tornado/gen.py index 752c3f24..506697d7 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -113,13 +113,14 @@ def engine(func): if runner is not None: return runner.handle_exception(typ, value, tb) return False - with ExceptionStackContext(handle_exception): + with ExceptionStackContext(handle_exception) as deactivate: gen = func(*args, **kwargs) if isinstance(gen, types.GeneratorType): - runner = Runner(gen) + runner = Runner(gen, deactivate) runner.run() return assert gen is None, gen + deactivate() # no yield, so we're done return wrapper @@ -285,8 +286,9 @@ class Runner(object): Maintains information about pending callbacks and their results. """ - def __init__(self, gen): + def __init__(self, gen, deactivate_stack_context): self.gen = gen + self.deactivate_stack_context = deactivate_stack_context self.yield_point = _NullYieldPoint() self.pending_callbacks = set() self.results = {} @@ -351,6 +353,7 @@ class Runner(object): raise LeakedCallbackError( "finished without waiting for callbacks %r" % self.pending_callbacks) + self.deactivate_stack_context() return except Exception: self.finished = True diff --git a/tornado/stack_context.py b/tornado/stack_context.py index df186999..3e0bea85 100644 --- a/tornado/stack_context.py +++ b/tornado/stack_context.py @@ -71,6 +71,7 @@ from __future__ import absolute_import, division, with_statement import contextlib import functools import itertools +import operator import sys import threading @@ -95,23 +96,25 @@ class StackContext(object): with StackContext(my_context): ''' - def __init__(self, context_factory): + def __init__(self, context_factory, _active_cell=None): self.context_factory = context_factory + self.active_cell = _active_cell or [True] # Note that some of this code is duplicated in ExceptionStackContext # below. ExceptionStackContext is more common and doesn't need # the full generality of this class. def __enter__(self): self.old_contexts = _state.contexts - # _state.contexts is a tuple of (class, arg) pairs + # _state.contexts is a tuple of (class, arg, active_cell) tuples _state.contexts = (self.old_contexts + - ((StackContext, self.context_factory),)) + ((StackContext, self.context_factory, self.active_cell),)) try: self.context = self.context_factory() self.context.__enter__() except Exception: _state.contexts = self.old_contexts raise + return lambda: operator.setitem(self.active_cell, 0, False) def __exit__(self, type, value, traceback): try: @@ -133,13 +136,16 @@ class ExceptionStackContext(object): If the exception handler returns true, the exception will be consumed and will not be propagated to other exception handlers. ''' - def __init__(self, exception_handler): + def __init__(self, exception_handler, _active_cell=None): self.exception_handler = exception_handler + self.active_cell = _active_cell or [True] def __enter__(self): self.old_contexts = _state.contexts _state.contexts = (self.old_contexts + - ((ExceptionStackContext, self.exception_handler),)) + ((ExceptionStackContext, self.exception_handler, + self.active_cell),)) + return lambda: operator.setitem(self.active_cell, 0, False) def __exit__(self, type, value, traceback): try: @@ -186,7 +192,9 @@ def wrap(fn): callback(*args, **kwargs) return if not _state.contexts: - new_contexts = [cls(arg) for (cls, arg) in contexts] + new_contexts = [cls(arg, active_cell) + for (cls, arg, active_cell) in contexts + if active_cell[0]] # If we're moving down the stack, _state.contexts is a prefix # of contexts. For each element of contexts not in that prefix, # create a new StackContext object. @@ -198,10 +206,13 @@ def wrap(fn): for a, b in itertools.izip(_state.contexts, contexts))): # contexts have been removed or changed, so start over new_contexts = ([NullContext()] + - [cls(arg) for (cls, arg) in contexts]) + [cls(arg, active_cell) + for (cls, arg, active_cell) in contexts + if active_cell[0]]) else: - new_contexts = [cls(arg) - for (cls, arg) in contexts[len(_state.contexts):]] + new_contexts = [cls(arg, active_cell) + for (cls, arg, active_cell) in contexts[len(_state.contexts):] + if active_cell[0]] if len(new_contexts) > 1: with _nested(*new_contexts): callback(*args, **kwargs) diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index 86d7d0d6..198190cb 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -249,6 +249,24 @@ class GenTest(AsyncTestCase): self.stop() self.run_gen(f) + def test_stack_context_leak(self): + # regression test: repeated invocations of a gen-based + # function should not result in accumulated stack_contexts + from tornado import stack_context + @gen.engine + def inner(callback): + yield gen.Task(self.io_loop.add_callback) + callback() + @gen.engine + def outer(): + for i in xrange(10): + yield gen.Task(inner) + stack_increase = len(stack_context._state.contexts) - initial_stack_depth + self.assertTrue(stack_increase <= 2) + self.stop() + initial_stack_depth = len(stack_context._state.contexts) + self.run_gen(outer) + class GenSequenceHandler(RequestHandler): @asynchronous diff --git a/tornado/test/stack_context_test.py b/tornado/test/stack_context_test.py index e682f6a5..f3572875 100644 --- a/tornado/test/stack_context_test.py +++ b/tornado/test/stack_context_test.py @@ -93,5 +93,32 @@ class StackContextTest(AsyncTestCase, LogTrapTestCase): library_function(final_callback) self.wait() + def test_deactivate(self): + deactivate_callbacks = [] + def f1(): + with StackContext(functools.partial(self.context, 'c1')) as c1: + deactivate_callbacks.append(c1) + self.io_loop.add_callback(f2) + def f2(): + with StackContext(functools.partial(self.context, 'c2')) as c2: + deactivate_callbacks.append(c2) + self.io_loop.add_callback(f3) + def f3(): + with StackContext(functools.partial(self.context, 'c3')) as c3: + deactivate_callbacks.append(c3) + self.io_loop.add_callback(f4) + def f4(): + self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3']) + deactivate_callbacks[1]() + # deactivating a context doesn't remove it immediately, + # but it will be missing from the next iteration + self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3']) + self.io_loop.add_callback(f5) + def f5(): + self.assertEqual(self.active_contexts, ['c1', 'c3']) + self.stop() + self.io_loop.add_callback(f1) + self.wait() + if __name__ == '__main__': unittest.main()