Prevent leak of StackContexts in repeated gen.engine functions.

Internally, StackContexts now return a deactivation callback,
which can be used to prevent that StackContext from propagating
further.  This is used in gen.engine because the decorator doesn't know
which arguments are callbacks that need to be wrapped outside of its
ExceptionStackContext.  This is deliberately undocumented for now.

Closes #507.
This commit is contained in:
Ben Darnell 2012-05-20 22:08:59 -07:00
parent 1be0cc4c2c
commit 57a3f83fc6
4 changed files with 71 additions and 12 deletions

View File

@ -113,13 +113,14 @@ def engine(func):
if runner is not None:
return runner.handle_exception(typ, value, tb)
return False
with ExceptionStackContext(handle_exception):
with ExceptionStackContext(handle_exception) as deactivate:
gen = func(*args, **kwargs)
if isinstance(gen, types.GeneratorType):
runner = Runner(gen)
runner = Runner(gen, deactivate)
runner.run()
return
assert gen is None, gen
deactivate()
# no yield, so we're done
return wrapper
@ -285,8 +286,9 @@ class Runner(object):
Maintains information about pending callbacks and their results.
"""
def __init__(self, gen):
def __init__(self, gen, deactivate_stack_context):
self.gen = gen
self.deactivate_stack_context = deactivate_stack_context
self.yield_point = _NullYieldPoint()
self.pending_callbacks = set()
self.results = {}
@ -351,6 +353,7 @@ class Runner(object):
raise LeakedCallbackError(
"finished without waiting for callbacks %r" %
self.pending_callbacks)
self.deactivate_stack_context()
return
except Exception:
self.finished = True

View File

@ -71,6 +71,7 @@ from __future__ import absolute_import, division, with_statement
import contextlib
import functools
import itertools
import operator
import sys
import threading
@ -95,23 +96,25 @@ class StackContext(object):
with StackContext(my_context):
'''
def __init__(self, context_factory):
def __init__(self, context_factory, _active_cell=None):
self.context_factory = context_factory
self.active_cell = _active_cell or [True]
# Note that some of this code is duplicated in ExceptionStackContext
# below. ExceptionStackContext is more common and doesn't need
# the full generality of this class.
def __enter__(self):
self.old_contexts = _state.contexts
# _state.contexts is a tuple of (class, arg) pairs
# _state.contexts is a tuple of (class, arg, active_cell) tuples
_state.contexts = (self.old_contexts +
((StackContext, self.context_factory),))
((StackContext, self.context_factory, self.active_cell),))
try:
self.context = self.context_factory()
self.context.__enter__()
except Exception:
_state.contexts = self.old_contexts
raise
return lambda: operator.setitem(self.active_cell, 0, False)
def __exit__(self, type, value, traceback):
try:
@ -133,13 +136,16 @@ class ExceptionStackContext(object):
If the exception handler returns true, the exception will be
consumed and will not be propagated to other exception handlers.
'''
def __init__(self, exception_handler):
def __init__(self, exception_handler, _active_cell=None):
self.exception_handler = exception_handler
self.active_cell = _active_cell or [True]
def __enter__(self):
self.old_contexts = _state.contexts
_state.contexts = (self.old_contexts +
((ExceptionStackContext, self.exception_handler),))
((ExceptionStackContext, self.exception_handler,
self.active_cell),))
return lambda: operator.setitem(self.active_cell, 0, False)
def __exit__(self, type, value, traceback):
try:
@ -186,7 +192,9 @@ def wrap(fn):
callback(*args, **kwargs)
return
if not _state.contexts:
new_contexts = [cls(arg) for (cls, arg) in contexts]
new_contexts = [cls(arg, active_cell)
for (cls, arg, active_cell) in contexts
if active_cell[0]]
# If we're moving down the stack, _state.contexts is a prefix
# of contexts. For each element of contexts not in that prefix,
# create a new StackContext object.
@ -198,10 +206,13 @@ def wrap(fn):
for a, b in itertools.izip(_state.contexts, contexts))):
# contexts have been removed or changed, so start over
new_contexts = ([NullContext()] +
[cls(arg) for (cls, arg) in contexts])
[cls(arg, active_cell)
for (cls, arg, active_cell) in contexts
if active_cell[0]])
else:
new_contexts = [cls(arg)
for (cls, arg) in contexts[len(_state.contexts):]]
new_contexts = [cls(arg, active_cell)
for (cls, arg, active_cell) in contexts[len(_state.contexts):]
if active_cell[0]]
if len(new_contexts) > 1:
with _nested(*new_contexts):
callback(*args, **kwargs)

View File

@ -249,6 +249,24 @@ class GenTest(AsyncTestCase):
self.stop()
self.run_gen(f)
def test_stack_context_leak(self):
# regression test: repeated invocations of a gen-based
# function should not result in accumulated stack_contexts
from tornado import stack_context
@gen.engine
def inner(callback):
yield gen.Task(self.io_loop.add_callback)
callback()
@gen.engine
def outer():
for i in xrange(10):
yield gen.Task(inner)
stack_increase = len(stack_context._state.contexts) - initial_stack_depth
self.assertTrue(stack_increase <= 2)
self.stop()
initial_stack_depth = len(stack_context._state.contexts)
self.run_gen(outer)
class GenSequenceHandler(RequestHandler):
@asynchronous

View File

@ -93,5 +93,32 @@ class StackContextTest(AsyncTestCase, LogTrapTestCase):
library_function(final_callback)
self.wait()
def test_deactivate(self):
deactivate_callbacks = []
def f1():
with StackContext(functools.partial(self.context, 'c1')) as c1:
deactivate_callbacks.append(c1)
self.io_loop.add_callback(f2)
def f2():
with StackContext(functools.partial(self.context, 'c2')) as c2:
deactivate_callbacks.append(c2)
self.io_loop.add_callback(f3)
def f3():
with StackContext(functools.partial(self.context, 'c3')) as c3:
deactivate_callbacks.append(c3)
self.io_loop.add_callback(f4)
def f4():
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
deactivate_callbacks[1]()
# deactivating a context doesn't remove it immediately,
# but it will be missing from the next iteration
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
self.io_loop.add_callback(f5)
def f5():
self.assertEqual(self.active_contexts, ['c1', 'c3'])
self.stop()
self.io_loop.add_callback(f1)
self.wait()
if __name__ == '__main__':
unittest.main()