From 57a3f83fc6b6fa4d9c207dc078a337260863ff99 Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sun, 20 May 2012 22:08:59 -0700 Subject: [PATCH] Prevent leak of StackContexts in repeated gen.engine functions. Internally, StackContexts now return a deactivation callback, which can be used to prevent that StackContext from propagating further. This is used in gen.engine because the decorator doesn't know which arguments are callbacks that need to be wrapped outside of its ExceptionStackContext. This is deliberately undocumented for now. Closes #507. --- tornado/gen.py | 9 ++++++--- tornado/stack_context.py | 29 ++++++++++++++++++++--------- tornado/test/gen_test.py | 18 ++++++++++++++++++ tornado/test/stack_context_test.py | 27 +++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 12 deletions(-) diff --git a/tornado/gen.py b/tornado/gen.py index 752c3f24..506697d7 100644 --- a/tornado/gen.py +++ b/tornado/gen.py @@ -113,13 +113,14 @@ def engine(func): if runner is not None: return runner.handle_exception(typ, value, tb) return False - with ExceptionStackContext(handle_exception): + with ExceptionStackContext(handle_exception) as deactivate: gen = func(*args, **kwargs) if isinstance(gen, types.GeneratorType): - runner = Runner(gen) + runner = Runner(gen, deactivate) runner.run() return assert gen is None, gen + deactivate() # no yield, so we're done return wrapper @@ -285,8 +286,9 @@ class Runner(object): Maintains information about pending callbacks and their results. """ - def __init__(self, gen): + def __init__(self, gen, deactivate_stack_context): self.gen = gen + self.deactivate_stack_context = deactivate_stack_context self.yield_point = _NullYieldPoint() self.pending_callbacks = set() self.results = {} @@ -351,6 +353,7 @@ class Runner(object): raise LeakedCallbackError( "finished without waiting for callbacks %r" % self.pending_callbacks) + self.deactivate_stack_context() return except Exception: self.finished = True diff --git a/tornado/stack_context.py b/tornado/stack_context.py index df186999..3e0bea85 100644 --- a/tornado/stack_context.py +++ b/tornado/stack_context.py @@ -71,6 +71,7 @@ from __future__ import absolute_import, division, with_statement import contextlib import functools import itertools +import operator import sys import threading @@ -95,23 +96,25 @@ class StackContext(object): with StackContext(my_context): ''' - def __init__(self, context_factory): + def __init__(self, context_factory, _active_cell=None): self.context_factory = context_factory + self.active_cell = _active_cell or [True] # Note that some of this code is duplicated in ExceptionStackContext # below. ExceptionStackContext is more common and doesn't need # the full generality of this class. def __enter__(self): self.old_contexts = _state.contexts - # _state.contexts is a tuple of (class, arg) pairs + # _state.contexts is a tuple of (class, arg, active_cell) tuples _state.contexts = (self.old_contexts + - ((StackContext, self.context_factory),)) + ((StackContext, self.context_factory, self.active_cell),)) try: self.context = self.context_factory() self.context.__enter__() except Exception: _state.contexts = self.old_contexts raise + return lambda: operator.setitem(self.active_cell, 0, False) def __exit__(self, type, value, traceback): try: @@ -133,13 +136,16 @@ class ExceptionStackContext(object): If the exception handler returns true, the exception will be consumed and will not be propagated to other exception handlers. ''' - def __init__(self, exception_handler): + def __init__(self, exception_handler, _active_cell=None): self.exception_handler = exception_handler + self.active_cell = _active_cell or [True] def __enter__(self): self.old_contexts = _state.contexts _state.contexts = (self.old_contexts + - ((ExceptionStackContext, self.exception_handler),)) + ((ExceptionStackContext, self.exception_handler, + self.active_cell),)) + return lambda: operator.setitem(self.active_cell, 0, False) def __exit__(self, type, value, traceback): try: @@ -186,7 +192,9 @@ def wrap(fn): callback(*args, **kwargs) return if not _state.contexts: - new_contexts = [cls(arg) for (cls, arg) in contexts] + new_contexts = [cls(arg, active_cell) + for (cls, arg, active_cell) in contexts + if active_cell[0]] # If we're moving down the stack, _state.contexts is a prefix # of contexts. For each element of contexts not in that prefix, # create a new StackContext object. @@ -198,10 +206,13 @@ def wrap(fn): for a, b in itertools.izip(_state.contexts, contexts))): # contexts have been removed or changed, so start over new_contexts = ([NullContext()] + - [cls(arg) for (cls, arg) in contexts]) + [cls(arg, active_cell) + for (cls, arg, active_cell) in contexts + if active_cell[0]]) else: - new_contexts = [cls(arg) - for (cls, arg) in contexts[len(_state.contexts):]] + new_contexts = [cls(arg, active_cell) + for (cls, arg, active_cell) in contexts[len(_state.contexts):] + if active_cell[0]] if len(new_contexts) > 1: with _nested(*new_contexts): callback(*args, **kwargs) diff --git a/tornado/test/gen_test.py b/tornado/test/gen_test.py index 86d7d0d6..198190cb 100644 --- a/tornado/test/gen_test.py +++ b/tornado/test/gen_test.py @@ -249,6 +249,24 @@ class GenTest(AsyncTestCase): self.stop() self.run_gen(f) + def test_stack_context_leak(self): + # regression test: repeated invocations of a gen-based + # function should not result in accumulated stack_contexts + from tornado import stack_context + @gen.engine + def inner(callback): + yield gen.Task(self.io_loop.add_callback) + callback() + @gen.engine + def outer(): + for i in xrange(10): + yield gen.Task(inner) + stack_increase = len(stack_context._state.contexts) - initial_stack_depth + self.assertTrue(stack_increase <= 2) + self.stop() + initial_stack_depth = len(stack_context._state.contexts) + self.run_gen(outer) + class GenSequenceHandler(RequestHandler): @asynchronous diff --git a/tornado/test/stack_context_test.py b/tornado/test/stack_context_test.py index e682f6a5..f3572875 100644 --- a/tornado/test/stack_context_test.py +++ b/tornado/test/stack_context_test.py @@ -93,5 +93,32 @@ class StackContextTest(AsyncTestCase, LogTrapTestCase): library_function(final_callback) self.wait() + def test_deactivate(self): + deactivate_callbacks = [] + def f1(): + with StackContext(functools.partial(self.context, 'c1')) as c1: + deactivate_callbacks.append(c1) + self.io_loop.add_callback(f2) + def f2(): + with StackContext(functools.partial(self.context, 'c2')) as c2: + deactivate_callbacks.append(c2) + self.io_loop.add_callback(f3) + def f3(): + with StackContext(functools.partial(self.context, 'c3')) as c3: + deactivate_callbacks.append(c3) + self.io_loop.add_callback(f4) + def f4(): + self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3']) + deactivate_callbacks[1]() + # deactivating a context doesn't remove it immediately, + # but it will be missing from the next iteration + self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3']) + self.io_loop.add_callback(f5) + def f5(): + self.assertEqual(self.active_contexts, ['c1', 'c3']) + self.stop() + self.io_loop.add_callback(f1) + self.wait() + if __name__ == '__main__': unittest.main()