fork: ensure importer handle is installed on the new router.

This commit is contained in:
David Wilson 2018-04-09 16:37:51 +01:00
parent e9f94e1bbb
commit 3682ac6e29
2 changed files with 30 additions and 7 deletions

View File

@ -509,11 +509,6 @@ class Importer(object):
# Presence of an entry in this map indicates in-flight GET_MODULE.
self._callbacks = {}
router.add_handler(
fn=self._on_load_module,
handle=LOAD_MODULE,
policy=has_parent_authority,
)
self._cache = {}
if core_src:
self._cache['mitogen.core'] = (
@ -523,6 +518,14 @@ class Importer(object):
zlib.compress(core_src, 9),
[],
)
self._install_handler(router)
def _install_handler(self, router):
router.add_handler(
fn=self._on_load_module,
handle=LOAD_MODULE,
policy=has_parent_authority,
)
def __repr__(self):
return 'Importer()'
@ -1542,7 +1545,10 @@ class ExternalContext(object):
enable_debug_logging()
def _setup_importer(self, importer, core_src_fd, whitelist, blacklist):
if not importer:
if importer:
importer._install_handler(self.router)
importer._context = self.parent
else:
if core_src_fd:
fp = os.fdopen(101, 'r', 1)
try:
@ -1559,7 +1565,6 @@ class ExternalContext(object):
core_src, whitelist, blacklist)
self.importer = importer
self.importer._context = self.parent # for fork().
self.router.importer = importer
sys.meta_path.append(self.importer)

View File

@ -40,6 +40,15 @@ def RAND_pseudo_bytes(n=32):
return buf[:]
def exercise_importer(n):
"""
Ensure the forked child has a sensible importer.
"""
sys.path.remove(testlib.DATA_DIR)
import simple_pkg.a
return simple_pkg.a.subtract_one_add_two(n)
class ForkTest(testlib.RouterMixin, unittest2.TestCase):
def test_okay(self):
context = self.router.fork()
@ -57,6 +66,10 @@ class ForkTest(testlib.RouterMixin, unittest2.TestCase):
self.assertNotEqual(context.call(RAND_pseudo_bytes),
RAND_pseudo_bytes())
def test_importer(self):
context = self.router.fork()
self.assertEqual(2, context.call(exercise_importer, 1))
class DoubleChildTest(testlib.RouterMixin, unittest2.TestCase):
def test_okay(self):
@ -74,6 +87,11 @@ class DoubleChildTest(testlib.RouterMixin, unittest2.TestCase):
c2 = self.router.fork(via=c1)
self.assertEquals(123, c2.call(ping))
def test_importer(self):
c1 = self.router.fork(name='c1')
c2 = self.router.fork(name='c2', via=c1)
self.assertEqual(2, c2.call(exercise_importer, 1))
if __name__ == '__main__':
unittest2.main()