diff --git a/mitogen/master.py b/mitogen/master.py index 034c87a4..b7bf9c40 100644 --- a/mitogen/master.py +++ b/mitogen/master.py @@ -715,15 +715,38 @@ class Context(mitogen.core.Context): return self.call_with_deadline(None, False, fn, *args, **kwargs) -def _proxy_connect(mitogen, name, context_id, klass, kwargs): - if not isinstance(mitogen.router, Router): # TODO - mitogen.router.__class__ = Router # TODO - LOG.debug('_proxy_connect(): constructing ModuleForwarder') - ModuleForwarder(mitogen.router, mitogen.parent, mitogen.importer) - context = mitogen.router._connect( +def _local_method(): + return Stream + +def _ssh_method(): + import mitogen.ssh + return mitogen.ssh.Stream + +def _sudo_method(): + import mitogen.sudo + return mitogen.sudo.Stream + + +METHOD_NAMES = { + 'local': _local_method, + 'ssh': _ssh_method, + 'sudo': _sudo_method, +} + + +def upgrade_router(econtext): + if not isinstance(econtext.router, Router): # TODO + econtext.router.__class__ = Router # TODO + LOG.debug('_proxy_connect(): constructing ModuleForwarder') + ModuleForwarder(econtext.router, econtext.parent, econtext.importer) + + +def _proxy_connect(econtext, name, context_id, method_name, kwargs): + upgrade_router(econtext) + context = econtext.router._connect( context_id, - klass, + METHOD_NAMES[method_name](), name=name, **kwargs ) @@ -759,15 +782,13 @@ class Router(mitogen.core.Router): return self._context_by_id.get(context_id) def local(self, **kwargs): - return self.connect(Stream, **kwargs) + return self.connect('local', **kwargs) def sudo(self, **kwargs): - import mitogen.sudo - return self.connect(mitogen.sudo.Stream, **kwargs) + return self.connect('sudo', **kwargs) def ssh(self, **kwargs): - import mitogen.ssh - return self.connect(mitogen.ssh.Stream, **kwargs) + return self.connect('ssh', **kwargs) def _connect(self, context_id, klass, name=None, **kwargs): context = Context(self, context_id) @@ -779,22 +800,22 @@ class Router(mitogen.core.Router): self.register(context, stream) return context - def connect(self, klass, name=None, **kwargs): + def connect(self, method_name, name=None, **kwargs): + klass = METHOD_NAMES[method_name]() kwargs.setdefault('debug', self.debug) via = kwargs.pop('via', None) if via is not None: - return self.proxy_connect(via, klass, name=name, **kwargs) - + return self.proxy_connect(via, method_name, name=name, **kwargs) context_id = self.context_id_counter.next() return self._connect(context_id, klass, name=name, **kwargs) - def proxy_connect(self, via_context, klass, name=None, **kwargs): + def proxy_connect(self, via_context, method_name, name=None, **kwargs): context_id = self.context_id_counter.next() # Must be added prior to _proxy_connect() to avoid a race. self.add_route(context_id, via_context.context_id) name = via_context.call_with_deadline(None, True, - _proxy_connect, name, context_id, klass, kwargs + _proxy_connect, name, context_id, method_name, kwargs ) # name = '%s.%s' % (via_context.name, name) context = Context(self, context_id, name=name)