importer: module whitelist/blacklist support
Hoped to avoid it, but it's the obvious solution for Ansible.
This commit is contained in:
parent
4a20a57552
commit
68b34fa8f2
|
@ -394,7 +394,7 @@ class Importer(object):
|
|||
|
||||
:param context: Context to communicate via.
|
||||
"""
|
||||
def __init__(self, router, context, core_src):
|
||||
def __init__(self, router, context, core_src, whitelist=(), blacklist=()):
|
||||
self._context = context
|
||||
self._present = {'mitogen': [
|
||||
'mitogen.compat',
|
||||
|
@ -407,6 +407,15 @@ class Importer(object):
|
|||
'mitogen.utils',
|
||||
]}
|
||||
self._lock = threading.Lock()
|
||||
self.whitelist = whitelist or ['']
|
||||
self.blacklist = list(blacklist) + [
|
||||
# 2.x generates needless imports for 'builtins', while 3.x does the
|
||||
# same for '__builtin__'. The correct one is built-in, the other
|
||||
# always a negative round-trip.
|
||||
'builtins',
|
||||
'__builtin__',
|
||||
]
|
||||
|
||||
# Presence of an entry in this map indicates in-flight GET_MODULE.
|
||||
self._callbacks = {}
|
||||
router.add_handler(self._on_load_module, LOAD_MODULE)
|
||||
|
@ -451,12 +460,9 @@ class Importer(object):
|
|||
finally:
|
||||
del _tls.running
|
||||
|
||||
def _load_module_hacks(self, fullname):
|
||||
if fullname in ('builtins', '__builtin__'):
|
||||
# Python 2.x will generate needless imports for 'builtins', while
|
||||
# Python 3.x will generate needless imports for '__builtin__'. The
|
||||
# correct one is already present in sys.modules, the other is
|
||||
# always a negative round-trip.
|
||||
def _refuse_imports(self, fullname):
|
||||
if ((not any(fullname.startswith(s) for s in self.whitelist)) or
|
||||
(any(fullname.startswith(s) for s in self.blacklist))):
|
||||
raise ImportError('Refused')
|
||||
|
||||
f = sys._getframe(2)
|
||||
|
@ -515,7 +521,7 @@ class Importer(object):
|
|||
|
||||
def load_module(self, fullname):
|
||||
_v and LOG.debug('Importer.load_module(%r)', fullname)
|
||||
self._load_module_hacks(fullname)
|
||||
self._refuse_imports(fullname)
|
||||
|
||||
event = threading.Event()
|
||||
self._request_module(fullname, event.set)
|
||||
|
@ -1260,7 +1266,7 @@ class ExternalContext(object):
|
|||
if debug:
|
||||
enable_debug_logging()
|
||||
|
||||
def _setup_importer(self, core_src_fd):
|
||||
def _setup_importer(self, core_src_fd, whitelist, blacklist):
|
||||
if core_src_fd:
|
||||
with os.fdopen(101, 'r', 1) as fp:
|
||||
core_size = int(fp.readline())
|
||||
|
@ -1271,7 +1277,9 @@ class ExternalContext(object):
|
|||
else:
|
||||
core_src = None
|
||||
|
||||
self.importer = Importer(self.router, self.parent, core_src)
|
||||
self.importer = Importer(self.router, self.parent, core_src,
|
||||
whitelist, blacklist)
|
||||
self.router.importer = self.importer
|
||||
sys.meta_path.append(self.importer)
|
||||
|
||||
def _setup_package(self, context_id, parent_ids):
|
||||
|
@ -1328,12 +1336,13 @@ class ExternalContext(object):
|
|||
self.dispatch_stopped = True
|
||||
|
||||
def main(self, parent_ids, context_id, debug, profiling, log_level,
|
||||
in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True):
|
||||
in_fd=100, out_fd=1, core_src_fd=101, setup_stdio=True,
|
||||
whitelist=(), blacklist=()):
|
||||
self._setup_master(profiling, parent_ids[0], context_id, in_fd, out_fd)
|
||||
try:
|
||||
try:
|
||||
self._setup_logging(debug, log_level)
|
||||
self._setup_importer(core_src_fd)
|
||||
self._setup_importer(core_src_fd, whitelist, blacklist)
|
||||
self._setup_package(context_id, parent_ids)
|
||||
if setup_stdio:
|
||||
self._setup_stdio()
|
||||
|
@ -1342,7 +1351,7 @@ class ExternalContext(object):
|
|||
|
||||
sys.executable = os.environ.pop('ARGV0', sys.executable)
|
||||
_v and LOG.debug('Connected to %s; my ID is %r, PID is %r',
|
||||
self.parent, context_id, os.getpid())
|
||||
self.parent, context_id, os.getpid())
|
||||
_v and LOG.debug('Recovered sys.executable: %r', sys.executable)
|
||||
|
||||
_profile_hook('main', self._dispatch_calls)
|
||||
|
|
|
@ -341,17 +341,17 @@ def run(dest, router, args, deadline=None, econtext=None):
|
|||
fp.write('#!%s\n' % (sys.executable,))
|
||||
fp.write(inspect.getsource(mitogen.core))
|
||||
fp.write('\n')
|
||||
fp.write('ExternalContext().main%r\n' % ((
|
||||
parent_ids, # parent_ids
|
||||
context_id, # context_id
|
||||
router.debug, # debug
|
||||
router.profiling, # profiling
|
||||
logging.getLogger().level, # log_level
|
||||
sock2.fileno(), # in_fd
|
||||
sock2.fileno(), # out_fd
|
||||
None, # core_src_fd
|
||||
False, # setup_stdio
|
||||
),))
|
||||
fp.write('ExternalContext().main(**%r)\n' % ({
|
||||
'parent_ids': parent_ids,
|
||||
'context_id': context_id,
|
||||
'debug': router.debug,
|
||||
'profiling': router.profiling,
|
||||
'log_level': mitogen.parent.get_log_level(),
|
||||
'in_fd': sock2.fileno(),
|
||||
'out_fd': sock2.fileno(),
|
||||
'core_src_fd': None,
|
||||
'setup_stdio': False,
|
||||
},))
|
||||
finally:
|
||||
fp.close()
|
||||
|
||||
|
|
|
@ -441,6 +441,8 @@ class ModuleResponder(object):
|
|||
self._router = router
|
||||
self._finder = ModuleFinder()
|
||||
self._cache = {} # fullname -> pickled
|
||||
self.blacklist = []
|
||||
self.whitelist = []
|
||||
router.add_handler(self._on_get_module, mitogen.core.GET_MODULE)
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -448,6 +450,12 @@ class ModuleResponder(object):
|
|||
|
||||
MAIN_RE = re.compile(r'^if\s+__name__\s*==\s*.__main__.\s*:', re.M)
|
||||
|
||||
def whitelist_prefix(self, fullname):
|
||||
self.whitelist.append(fullname)
|
||||
|
||||
def blacklist_prefix(self, fullname):
|
||||
self.blacklist.append(fullname)
|
||||
|
||||
def neutralize_main(self, src):
|
||||
"""Given the source for the __main__ module, try to find where it
|
||||
begins conditional execution based on a "if __name__ == '__main__'"
|
||||
|
@ -458,6 +466,9 @@ class ModuleResponder(object):
|
|||
return src
|
||||
|
||||
def _build_tuple(self, fullname):
|
||||
if fullname in self._blacklist:
|
||||
raise ImportError('blacklisted')
|
||||
|
||||
if fullname in self._cache:
|
||||
return self._cache[fullname]
|
||||
|
||||
|
|
|
@ -63,6 +63,10 @@ class Argv(object):
|
|||
return ' '.join(map(self.escape, self.argv))
|
||||
|
||||
|
||||
def get_log_level():
|
||||
return (LOG.level or logging.getLogger().level or logging.INFO)
|
||||
|
||||
|
||||
def minimize_source(source):
|
||||
subber = lambda match: '""' + ('\n' * match.group(0).count('\n'))
|
||||
source = DOCSTRING_RE.sub(subber, source)
|
||||
|
@ -336,14 +340,17 @@ class Stream(mitogen.core.Stream):
|
|||
def get_preamble(self):
|
||||
parent_ids = mitogen.parent_ids[:]
|
||||
parent_ids.insert(0, mitogen.context_id)
|
||||
|
||||
source = inspect.getsource(mitogen.core)
|
||||
source += '\nExternalContext().main%r\n' % ((
|
||||
parent_ids, # parent_ids
|
||||
self.remote_id, # context_id
|
||||
self.debug,
|
||||
self.profiling,
|
||||
LOG.level or logging.getLogger().level or logging.INFO,
|
||||
),)
|
||||
source += '\nExternalContext().main(**%r)\n' % ({
|
||||
'parent_ids': parent_ids,
|
||||
'context_id': self.remote_id,
|
||||
'debug': self.debug,
|
||||
'profiling': self.profiling,
|
||||
'log_level': get_log_level(),
|
||||
'whitelist': self._router.get_module_whitelist(),
|
||||
'blacklist': self._router.get_module_blacklist(),
|
||||
},)
|
||||
|
||||
compressed = zlib.compress(minimize_source(source))
|
||||
return str(len(compressed)) + '\n' + compressed
|
||||
|
@ -385,6 +392,16 @@ class ChildIdAllocator(object):
|
|||
class Router(mitogen.core.Router):
|
||||
context_class = mitogen.core.Context
|
||||
|
||||
def get_module_blacklist(self):
|
||||
if mitogen.context_id == 0:
|
||||
return self.responder.blacklist
|
||||
return self.importer.blacklist
|
||||
|
||||
def get_module_whitelist(self):
|
||||
if mitogen.context_id == 0:
|
||||
return self.responder.whitelist
|
||||
return self.importer.whitelist
|
||||
|
||||
def allocate_id(self):
|
||||
return self.id_allocator.allocate()
|
||||
|
||||
|
|
Loading…
Reference in New Issue