diff --git a/mitogen/core.py b/mitogen/core.py index d60fbafc..2b2ba589 100644 --- a/mitogen/core.py +++ b/mitogen/core.py @@ -509,11 +509,6 @@ class Importer(object): # Presence of an entry in this map indicates in-flight GET_MODULE. self._callbacks = {} - router.add_handler( - fn=self._on_load_module, - handle=LOAD_MODULE, - policy=has_parent_authority, - ) self._cache = {} if core_src: self._cache['mitogen.core'] = ( @@ -523,6 +518,14 @@ class Importer(object): zlib.compress(core_src, 9), [], ) + self._install_handler(router) + + def _install_handler(self, router): + router.add_handler( + fn=self._on_load_module, + handle=LOAD_MODULE, + policy=has_parent_authority, + ) def __repr__(self): return 'Importer()' @@ -1542,7 +1545,10 @@ class ExternalContext(object): enable_debug_logging() def _setup_importer(self, importer, core_src_fd, whitelist, blacklist): - if not importer: + if importer: + importer._install_handler(self.router) + importer._context = self.parent + else: if core_src_fd: fp = os.fdopen(101, 'r', 1) try: @@ -1559,7 +1565,6 @@ class ExternalContext(object): core_src, whitelist, blacklist) self.importer = importer - self.importer._context = self.parent # for fork(). self.router.importer = importer sys.meta_path.append(self.importer) diff --git a/tests/fork_test.py b/tests/fork_test.py index 08c099ba..c5e19adb 100644 --- a/tests/fork_test.py +++ b/tests/fork_test.py @@ -40,6 +40,15 @@ def RAND_pseudo_bytes(n=32): return buf[:] +def exercise_importer(n): + """ + Ensure the forked child has a sensible importer. + """ + sys.path.remove(testlib.DATA_DIR) + import simple_pkg.a + return simple_pkg.a.subtract_one_add_two(n) + + class ForkTest(testlib.RouterMixin, unittest2.TestCase): def test_okay(self): context = self.router.fork() @@ -57,6 +66,10 @@ class ForkTest(testlib.RouterMixin, unittest2.TestCase): self.assertNotEqual(context.call(RAND_pseudo_bytes), RAND_pseudo_bytes()) + def test_importer(self): + context = self.router.fork() + self.assertEqual(2, context.call(exercise_importer, 1)) + class DoubleChildTest(testlib.RouterMixin, unittest2.TestCase): def test_okay(self): @@ -74,6 +87,11 @@ class DoubleChildTest(testlib.RouterMixin, unittest2.TestCase): c2 = self.router.fork(via=c1) self.assertEquals(123, c2.call(ping)) + def test_importer(self): + c1 = self.router.fork(name='c1') + c2 = self.router.fork(name='c2', via=c1) + self.assertEqual(2, c2.call(exercise_importer, 1)) + if __name__ == '__main__': unittest2.main()