diff --git a/econtext/master.py b/econtext/master.py index b8077363..5867fbb5 100644 --- a/econtext/master.py +++ b/econtext/master.py @@ -6,6 +6,7 @@ sent to any context that will be used to establish additional child contexts. import commands import getpass +import imp import inspect import logging import os @@ -111,31 +112,77 @@ class ModuleResponder(object): def __repr__(self): return 'ModuleResponder(%r)' % (self._context,) + def _get_module_via_pkgutil(self, fullname): + """Attempt to fetch source code via pkgutil. In an ideal world, this + would be the only required implementation of get_module().""" + loader = pkgutil.find_loader(fullname) + LOG.debug('pkgutil.find_loader(%r) -> %r', fullname, loader) + if not loader: + return + + path = loader.get_filename(fullname) + source = loader.get_source(fullname) + if path and source: + return path, source, loader.is_package(fullname) + + def _get_module_via_sys_modules(self, fullname): + """Attempt to fetch source code via sys.modules. This is specifically + to support __main__, but it may catch a few more cases.""" + if fullname not in sys.modules: + LOG.debug('%r does not appear in sys.modules', fullname) + return + + is_pkg = hasattr(sys.modules[fullname], '__path__') + try: + source = inspect.getsource(sys.modules[fullname]) + except IOError: + # Work around inspect.getsourcelines() bug. + if not is_pkg: + raise + source = '\n' + + return (sys.modules[fullname].__file__.rstrip('co'), + source, + hasattr(sys.modules[fullname], '__path__')) + + def _get_module_via_parent_enumeration(self, fullname): + """Attempt to fetch source code by examining the module's (hopefully + less insane) parent package. Required for ansible.compat.six.""" + pkgname, _, modname = fullname.rpartition('.') + pkg = sys.modules.get(pkgname) + if pkg is None or not hasattr(pkg, '__file__'): + return + + pkg_path = os.path.dirname(pkg.__file__) + try: + fp, path, ext = imp.find_module(modname, [pkg_path]) + return path, fp.read(), False + except ImportError: + LOG.debug('imp.find_module(%r, %r) -> %s', modname, [pkg_path], e) + + get_module_methods = [_get_module_via_pkgutil, + _get_module_via_sys_modules, + _get_module_via_parent_enumeration] + def get_module(self, data): + LOG.debug('%r.get_module(%r)', self, data) if data == econtext.core._DEAD: return reply_to, fullname = data - LOG.debug('%r.get_module(%r, %r)', self, reply_to, fullname) try: - loader = pkgutil.find_loader(fullname) - LOG.debug('pkgutil.find_loader(%r) -> %r', fullname, loader) - if loader is None: - raise ImportError('pkgutil provides no loader for %r' % - (fullname,)) + for method in self.get_module_methods: + tup = method(self, fullname) + if tup: + break - path = loader.get_filename(fullname) - LOG.debug('%r.get_filename(%r) -> %r', loader, fullname, path) - - # Handle __main__ specially. - if path is None and fullname in sys.modules: - path = sys.modules[fullname].__file__.rstrip('co') - source = inspect.getsource(sys.modules[fullname]) - is_pkg = hasattr(sys.modules[fullname], '__path__') - else: - source = loader.get_source(fullname) - is_pkg = loader.is_package(fullname) + try: + path, source, is_pkg = tup + except TypeError: + raise ImportError('could not find %r' % (fullname,)) + LOG.debug('%r returned for %r: (%r, .., %r)', + method, fullname, path, is_pkg) if is_pkg: pkg_present = get_child_modules(path, fullname) LOG.debug('get_child_modules(%r, %r) -> %r', diff --git a/tests/responder_test.py b/tests/responder_test.py index a8a7734f..104b84c9 100644 --- a/tests/responder_test.py +++ b/tests/responder_test.py @@ -4,6 +4,7 @@ import subprocess import unittest import sys +import econtext.master import econtext.master import testlib @@ -11,7 +12,7 @@ import plain_old_module import simple_pkg.a -class ModuleTest(testlib.BrokerMixin, unittest.TestCase): +class GoodModulesTest(testlib.BrokerMixin, unittest.TestCase): def test_plain_old_module(self): # The simplest case: a top-level module with no interesting imports or # package machinery damage. @@ -33,7 +34,7 @@ class ModuleTest(testlib.BrokerMixin, unittest.TestCase): self.assertEquals(output, "['__main__', 50]\n") -class BrokenPackagesTest(unittest.TestCase): +class BrokenModulesTest(unittest.TestCase): def test_ansible_six_messed_up_path(self): # The copy of six.py shipped with Ansible appears in a package whose # __path__ subsequently ends up empty, which prevents pkgutil from @@ -50,4 +51,4 @@ class BrokenPackagesTest(unittest.TestCase): call = context.enqueue.mock_calls[0] reply_to, data = call[1] self.assertEquals(50, reply_to) - self.assertTrue(isinstance(data, str)) + self.assertTrue(isinstance(data, tuple))