Prevent leak of StackContexts in repeated gen.engine functions.
Internally, StackContexts now return a deactivation callback, which can be used to prevent that StackContext from propagating further. This is used in gen.engine because the decorator doesn't know which arguments are callbacks that need to be wrapped outside of its ExceptionStackContext. This is deliberately undocumented for now. Closes #507.
This commit is contained in:
parent
1be0cc4c2c
commit
57a3f83fc6
|
@ -113,13 +113,14 @@ def engine(func):
|
|||
if runner is not None:
|
||||
return runner.handle_exception(typ, value, tb)
|
||||
return False
|
||||
with ExceptionStackContext(handle_exception):
|
||||
with ExceptionStackContext(handle_exception) as deactivate:
|
||||
gen = func(*args, **kwargs)
|
||||
if isinstance(gen, types.GeneratorType):
|
||||
runner = Runner(gen)
|
||||
runner = Runner(gen, deactivate)
|
||||
runner.run()
|
||||
return
|
||||
assert gen is None, gen
|
||||
deactivate()
|
||||
# no yield, so we're done
|
||||
return wrapper
|
||||
|
||||
|
@ -285,8 +286,9 @@ class Runner(object):
|
|||
|
||||
Maintains information about pending callbacks and their results.
|
||||
"""
|
||||
def __init__(self, gen):
|
||||
def __init__(self, gen, deactivate_stack_context):
|
||||
self.gen = gen
|
||||
self.deactivate_stack_context = deactivate_stack_context
|
||||
self.yield_point = _NullYieldPoint()
|
||||
self.pending_callbacks = set()
|
||||
self.results = {}
|
||||
|
@ -351,6 +353,7 @@ class Runner(object):
|
|||
raise LeakedCallbackError(
|
||||
"finished without waiting for callbacks %r" %
|
||||
self.pending_callbacks)
|
||||
self.deactivate_stack_context()
|
||||
return
|
||||
except Exception:
|
||||
self.finished = True
|
||||
|
|
|
@ -71,6 +71,7 @@ from __future__ import absolute_import, division, with_statement
|
|||
import contextlib
|
||||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
import sys
|
||||
import threading
|
||||
|
||||
|
@ -95,23 +96,25 @@ class StackContext(object):
|
|||
|
||||
with StackContext(my_context):
|
||||
'''
|
||||
def __init__(self, context_factory):
|
||||
def __init__(self, context_factory, _active_cell=None):
|
||||
self.context_factory = context_factory
|
||||
self.active_cell = _active_cell or [True]
|
||||
|
||||
# Note that some of this code is duplicated in ExceptionStackContext
|
||||
# below. ExceptionStackContext is more common and doesn't need
|
||||
# the full generality of this class.
|
||||
def __enter__(self):
|
||||
self.old_contexts = _state.contexts
|
||||
# _state.contexts is a tuple of (class, arg) pairs
|
||||
# _state.contexts is a tuple of (class, arg, active_cell) tuples
|
||||
_state.contexts = (self.old_contexts +
|
||||
((StackContext, self.context_factory),))
|
||||
((StackContext, self.context_factory, self.active_cell),))
|
||||
try:
|
||||
self.context = self.context_factory()
|
||||
self.context.__enter__()
|
||||
except Exception:
|
||||
_state.contexts = self.old_contexts
|
||||
raise
|
||||
return lambda: operator.setitem(self.active_cell, 0, False)
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
try:
|
||||
|
@ -133,13 +136,16 @@ class ExceptionStackContext(object):
|
|||
If the exception handler returns true, the exception will be
|
||||
consumed and will not be propagated to other exception handlers.
|
||||
'''
|
||||
def __init__(self, exception_handler):
|
||||
def __init__(self, exception_handler, _active_cell=None):
|
||||
self.exception_handler = exception_handler
|
||||
self.active_cell = _active_cell or [True]
|
||||
|
||||
def __enter__(self):
|
||||
self.old_contexts = _state.contexts
|
||||
_state.contexts = (self.old_contexts +
|
||||
((ExceptionStackContext, self.exception_handler),))
|
||||
((ExceptionStackContext, self.exception_handler,
|
||||
self.active_cell),))
|
||||
return lambda: operator.setitem(self.active_cell, 0, False)
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
try:
|
||||
|
@ -186,7 +192,9 @@ def wrap(fn):
|
|||
callback(*args, **kwargs)
|
||||
return
|
||||
if not _state.contexts:
|
||||
new_contexts = [cls(arg) for (cls, arg) in contexts]
|
||||
new_contexts = [cls(arg, active_cell)
|
||||
for (cls, arg, active_cell) in contexts
|
||||
if active_cell[0]]
|
||||
# If we're moving down the stack, _state.contexts is a prefix
|
||||
# of contexts. For each element of contexts not in that prefix,
|
||||
# create a new StackContext object.
|
||||
|
@ -198,10 +206,13 @@ def wrap(fn):
|
|||
for a, b in itertools.izip(_state.contexts, contexts))):
|
||||
# contexts have been removed or changed, so start over
|
||||
new_contexts = ([NullContext()] +
|
||||
[cls(arg) for (cls, arg) in contexts])
|
||||
[cls(arg, active_cell)
|
||||
for (cls, arg, active_cell) in contexts
|
||||
if active_cell[0]])
|
||||
else:
|
||||
new_contexts = [cls(arg)
|
||||
for (cls, arg) in contexts[len(_state.contexts):]]
|
||||
new_contexts = [cls(arg, active_cell)
|
||||
for (cls, arg, active_cell) in contexts[len(_state.contexts):]
|
||||
if active_cell[0]]
|
||||
if len(new_contexts) > 1:
|
||||
with _nested(*new_contexts):
|
||||
callback(*args, **kwargs)
|
||||
|
|
|
@ -249,6 +249,24 @@ class GenTest(AsyncTestCase):
|
|||
self.stop()
|
||||
self.run_gen(f)
|
||||
|
||||
def test_stack_context_leak(self):
|
||||
# regression test: repeated invocations of a gen-based
|
||||
# function should not result in accumulated stack_contexts
|
||||
from tornado import stack_context
|
||||
@gen.engine
|
||||
def inner(callback):
|
||||
yield gen.Task(self.io_loop.add_callback)
|
||||
callback()
|
||||
@gen.engine
|
||||
def outer():
|
||||
for i in xrange(10):
|
||||
yield gen.Task(inner)
|
||||
stack_increase = len(stack_context._state.contexts) - initial_stack_depth
|
||||
self.assertTrue(stack_increase <= 2)
|
||||
self.stop()
|
||||
initial_stack_depth = len(stack_context._state.contexts)
|
||||
self.run_gen(outer)
|
||||
|
||||
|
||||
class GenSequenceHandler(RequestHandler):
|
||||
@asynchronous
|
||||
|
|
|
@ -93,5 +93,32 @@ class StackContextTest(AsyncTestCase, LogTrapTestCase):
|
|||
library_function(final_callback)
|
||||
self.wait()
|
||||
|
||||
def test_deactivate(self):
|
||||
deactivate_callbacks = []
|
||||
def f1():
|
||||
with StackContext(functools.partial(self.context, 'c1')) as c1:
|
||||
deactivate_callbacks.append(c1)
|
||||
self.io_loop.add_callback(f2)
|
||||
def f2():
|
||||
with StackContext(functools.partial(self.context, 'c2')) as c2:
|
||||
deactivate_callbacks.append(c2)
|
||||
self.io_loop.add_callback(f3)
|
||||
def f3():
|
||||
with StackContext(functools.partial(self.context, 'c3')) as c3:
|
||||
deactivate_callbacks.append(c3)
|
||||
self.io_loop.add_callback(f4)
|
||||
def f4():
|
||||
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
|
||||
deactivate_callbacks[1]()
|
||||
# deactivating a context doesn't remove it immediately,
|
||||
# but it will be missing from the next iteration
|
||||
self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
|
||||
self.io_loop.add_callback(f5)
|
||||
def f5():
|
||||
self.assertEqual(self.active_contexts, ['c1', 'c3'])
|
||||
self.stop()
|
||||
self.io_loop.add_callback(f1)
|
||||
self.wait()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue