From e62b891b9af001cbaf1bbb2544e1e29263f892fd Mon Sep 17 00:00:00 2001 From: David Wilson Date: Tue, 9 Aug 2016 04:06:14 +0100 Subject: [PATCH] SSH working * Get rid of persistent functions for now. * Split select into read/write sides for unidirectional SSH IO. * Put more of Loop in a try/except. --- econtext.py | 228 ++++++++++++++++++++++++---------------------------- 1 file changed, 105 insertions(+), 123 deletions(-) diff --git a/econtext.py b/econtext.py index 04f0f719..b248ca5e 100644 --- a/econtext.py +++ b/econtext.py @@ -90,7 +90,6 @@ def CreateChild(*args): if not pid: os.dup2(childfp.fileno(), 0) os.dup2(childfp.fileno(), 1) - sys.stderr = open('milf2', 'w', 1) childfp.close() parentfp.close() os.execvp(args[0], args) @@ -121,21 +120,6 @@ class Formatter(logging.Formatter): return p + ('{%s} %s' % (os.getpid(), s)) -class PartialFunction(object): - ''' - Partial function implementation. - ''' - def __init__(self, fn, *partial_args): - self.fn = fn - self.partial_args = partial_args - - def __call__(self, *args, **kwargs): - return self.fn(*(self.partial_args+args), **kwargs) - - def __repr__(self): - return 'PartialFunction(%r, *%r)' % (self.fn, self.partial_args) - - class Channel(object): def __init__(self, stream, handle): self._stream = stream @@ -268,11 +252,12 @@ class MasterModuleResponder(object): def __init__(self, stream): self._stream = stream - def GetModule(self, killed, (_, (reply_handle, fullname))): - LOG.debug('SlaveModuleImporter.GetModule(%r, %r)', killed, fullname) + def GetModule(self, killed, data): if killed: return + _, (reply_handle, fullname) = data + LOG.debug('SlaveModuleImporter.GetModule(%r, %r)', killed, fullname) mod = sys.modules.get(fullname) if mod: source = zlib.compress(inspect.getsource(mod)) @@ -285,12 +270,24 @@ class MasterModuleResponder(object): # -class BasicStream(object): +class Side(object): + def __init__(self, stream, fd): + self.stream = stream + self.fd = fd + + def __repr__(self): + return '' % (self.fd, self.stream) + def fileno(self): - return self._fd + return self.fd + + +class BasicStream(object): + read_side = None + write_side = None def Disconnect(self): - LOG.debug('%r: disconnect on %r fd %d', self._broker, self, self._fd) + LOG.debug('%r: disconnect on %r', self._broker, self) self._broker.RemoveStream(self) def ReadMore(self): @@ -325,11 +322,9 @@ class Stream(BasicStream): self._pickler_file = cStringIO.StringIO() self._pickler = cPickle.Pickler(self._pickler_file, protocol=2) - self._pickler.persistent_id = self._CheckFunctionPerID self._unpickler_file = cStringIO.StringIO() self._unpickler = cPickle.Unpickler(self._unpickler_file) - self._unpickler.persistent_load = self._LoadFunctionFromPerID def Pickle(self, obj): ''' @@ -365,37 +360,6 @@ class Stream(BasicStream): self._unpickler_file.truncate(0) return data - def _CheckFunctionPerID(self, obj): - ''' - Return None or a persistent ID for an object. - Please see the cPickle documentation. - - Args: - obj: object - - Returns: - str - ''' - if isinstance(obj, (types.FunctionType, types.MethodType)): - pid = 'FUNC:' + repr(obj) - self._func_refs[per_id] = obj - return pid - - def _LoadFunctionFromPerID(self, pid): - ''' - Load an object from a persistent ID. - Please see the cPickle documentation. - - Args: - pid: str - - Returns: - object - ''' - if not pid.startswith('FUNC:'): - raise CorruptMessageError('unrecognized persistent ID received: %r', pid) - return PartialFunction(self._CallPersistentWhatsit, pid) - def AllocHandle(self): ''' Allocate a unique handle for this stream. @@ -434,7 +398,7 @@ class Stream(BasicStream): ''' LOG.debug('%r.Receive()', self) - buf = os.read(self._fd, 4096) + buf = os.read(self.read_side.fd, 4096) if not buf: return self.Disconnect() @@ -484,7 +448,7 @@ class Stream(BasicStream): IOError ''' LOG.debug('%r.Transmit()', self) - written = os.write(self._fd, self._output_buf[:4096]) + written = os.write(self.write_side.fd, self._output_buf[:4096]) self._output_buf = self._output_buf[written:] def WriteMore(self): @@ -513,32 +477,41 @@ class Stream(BasicStream): def Disconnect(self): ''' - Close our associated file descriptor and tell any registered callbacks - that the connection has been destroyed. + Close our associated file descriptor and tell registered callbacks the + connection has been destroyed. ''' LOG.debug('%r.Disconnect()', self) - try: - os.close(self._fd) - except OSError, e: - LOG.debug('%r.Disconnect(): did not close fd %s: %s', - self, self._fd, e) - - self._fd = None if self._context.GetStream() is self: self._context.SetStream(None) + try: + os.close(self.read_side.fd) + except OSError, e: + LOG.debug('%r.Disconnect(): did not close fd %s: %s', + self, self.read_side.fd, e) + + if self.read_side.fd != self.write_side.fd: + try: + os.close(self.write_side.fd) + except OSError, e: + LOG.debug('%r.Disconnect(): did not close fd %s: %s', + self, self.write_side.fd, e) + + self.read_side.fd = None + self.write_side.fd = None for handle, (persist, fn) in self._handle_map.iteritems(): LOG.debug('%r.Disconnect(): stale callback handle=%r; fn=%r', self, handle, fn) fn(True, None) @classmethod - def Accept(cls, context, fd): + def Accept(cls, context, rfd, wfd): ''' ''' stream = cls(context) - stream._fd = os.dup(fd) + stream.read_side = Side(stream, os.dup(rfd)) + stream.write_side = Side(stream, os.dup(wfd)) context.SetStream(stream) context.broker.Register(context) return stream @@ -550,7 +523,8 @@ class Stream(BasicStream): LOG.debug('%r.Connect()', self) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._fd = sock.fileno() + self.read_side = Side(self, sock.fileno()) + self.write_side = Side(self, sock.fileno()) sock.connect(self._context.parent_addr) self.Enqueue(0, self._context.name) @@ -635,9 +609,11 @@ class LocalStream(Stream): def Connect(self): LOG.debug('%r.Connect()', self) pid, sock = CreateChild(*self.GetBootCommand()) - self._fd = os.dup(sock.fileno()) + self.read_side = Side(self, os.dup(sock.fileno())) + self.write_side = self.read_side sock.close() - LOG.debug('%r.Connect(): child process stdin/stdout=%r', self, self._fd) + LOG.debug('%r.Connect(): child process stdin/stdout=%r', + self, self.read_side.fd) source = inspect.getsource(sys.modules[__name__]) source += '\nExternalContextMain(%r, %r, %r)\n' % ( @@ -648,8 +624,8 @@ class LocalStream(Stream): compressed = zlib.compress(source) preamble = str(len(compressed)) + '\n' + compressed - write_all(self._fd, preamble) - assert os.read(self._fd, 3) == 'OK\n' + write_all(self.write_side.fd, preamble) + assert os.read(self.read_side.fd, 3) == 'OK\n' class SSHStream(LocalStream): @@ -751,16 +727,17 @@ class Context(object): class Waker(BasicStream): def __init__(self, broker): self._broker = broker - self._rfd, self._wfd = os.pipe() - self._fd = self._rfd + rfd, wfd = os.pipe() + self.read_side = Side(self, rfd) + self.write_side = Side(self, wfd) broker.AddStream(self) def Wake(self): - os.write(self._wfd, ' ') + os.write(self.write_side.fd, ' ') def Receive(self): LOG.debug('%r: waking %r', self, self._broker) - os.read(self._rfd, 1) + os.read(self.read_side.fd, 1) class Listener(BasicStream): @@ -770,7 +747,7 @@ class Listener(BasicStream): self._sock.bind(address or ('0.0.0.0', 0)) self._sock.listen(backlog) self._listen_addr = self._sock.getsockname() - self._fd = self._sock.fileno() + self.read_side = Side(self, self._sock.fileno()) broker.AddStream(self) def Receive(self): @@ -795,7 +772,6 @@ class Broker(object): self._waker = Waker(self) self._thread = threading.Thread(target=self._Loop, name='Broker') - self._thread.setDaemon(True) self._thread.start() def CreateListener(self, address=None, backlog=30): @@ -809,16 +785,15 @@ class Broker(object): def UpdateStream(self, stream, wake=False): LOG.debug('UpdateStream(%r, wake=%s)', stream, wake) - fileno = stream.fileno() - if fileno is not None and stream.ReadMore(): - self._readers.add(stream) + if stream.ReadMore() and stream.read_side.fileno(): + self._readers.add(stream.read_side) else: - self._readers.discard(stream) + self._readers.discard(stream.read_side) - if fileno is not None and stream.WriteMore(): - self._writers.add(stream) + if stream.WriteMore() and stream.write_side.fileno(): + self._writers.add(stream.write_side) else: - self._writers.discard(stream) + self._writers.discard(stream.write_side) if wake: self._waker.Wake() @@ -836,8 +811,9 @@ class Broker(object): ''' Put a context under control of this broker. ''' - LOG.debug('%r.Register(%r) -> fd=%r', self, context, - context.GetStream().fileno()) + LOG.debug('%r.Register(%r) -> r=%r w=%r', self, context, + context.GetStream().read_side, + context.GetStream().write_side) self.AddStream(context.GetStream()) self._contexts[context.name] = context return context @@ -855,7 +831,7 @@ class Broker(object): context.SetStream(LocalStream(context)).Connect() return self.Register(context) - def GetRemote(self, hostname, username, name=None): + def GetRemote(self, hostname, username, name=None, python_path=None): ''' Return the named remote context, or create it if it doesn't exist. ''' @@ -864,51 +840,56 @@ class Broker(object): (username, os.getenv('HOSTNAME'), os.getpid()) context = Context(self, name, hostname, username) - context.SetStream(SSHStream(context)).Connect() + stream = SSHStream(context) + if python_path: + stream.python_path = python_path + context.SetStream(stream) + stream.Connect() return self.Register(context) - def _Loop(self): - try: - self.Loop() - except Exception: - LOG.exception('Loop() crashed') + def _LoopOnce(self): + LOG.debug('%r.Loop()', self) + self._lock.acquire() + self._lock.release() - def Loop(self): + #LOG.debug('readers = %r', self._readers) + #LOG.debug('rfds = %r', [r.fileno() for r in self._readers]) + #LOG.debug('writers = %r', self._writers) + #LOG.debug('wfds = %r', [w.fileno() for w in self._writers]) + rsides, wsides, _ = select.select(self._readers, self._writers, ()) + for side in rsides: + LOG.debug('%r: POLLIN for %r', self, side.stream) + side.stream.Receive() + self.UpdateStream(side.stream) + + for side in wsides: + LOG.debug('%r: POLLOUT for %r', self, side.stream) + side.stream.Transmit() + self.UpdateStream(side.stream) + + def _Loop(self): ''' Handle stream events until Finalize() is called. ''' - while not self._dead: - LOG.debug('%r.Loop()', self) - self._lock.acquire() - self._lock.release() + try: + while not self._dead: + self._LoopOnce() - #LOG.debug('readers = %r', self._readers) - #LOG.debug('rfds = %r', [r.fileno() for r in self._readers]) - #LOG.debug('writers = %r', self._writers) - rstrms, wstrms, _ = select.select(self._readers, self._writers, ()) - for stream in rstrms: - LOG.debug('%r: POLLIN for %r', self, stream) - stream.Receive() - self.UpdateStream(stream) - - for stream in wstrms: - LOG.debug('%r: POLLOUT for %r', self, stream) - stream.Transmit() - self.UpdateStream(stream) + for context in self._contexts.itervalues(): + stream = context.GetStream() + if stream: + stream.Disconnect() + except Exception: + LOG.exception('Loop() crashed') def Finalize(self): ''' Tell all active streams to disconnect. ''' self._dead = True + self._waker.Wake() self._lock.acquire() - try: - for name, context in self._contexts.iteritems(): - stream = context.GetStream() - if stream: - stream.Disconnect() - finally: - self._lock.release() + self._lock.release() def __repr__(self): return 'econtext.Broker()' % (self._contexts.keys(),) @@ -918,19 +899,20 @@ def ExternalContextMain(context_name, parent_addr, key): syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID) syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),)) - logging.basicConfig(level=logging.DEBUG) + logging.basicConfig(level=logging.INFO) logging.getLogger('').handlers[0].formatter = Formatter(False) LOG.debug('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key) - # os.wait() # Reap the first stage. + os.wait() # Reap the first stage. os.dup2(100, 0) os.close(100) broker = Broker() context = Context(broker, 'parent', parent_addr=parent_addr, key=key) - stream = Stream.Accept(context, 0) + stream = Stream.Accept(context, 0, 1) os.close(0) + os.close(1) # stream = context.SetStream(Stream(context)) # stream. @@ -942,7 +924,7 @@ def ExternalContextMain(context_name, parent_addr, key): for call_info in Channel(stream, CALL_FUNCTION): LOG.debug('ExternalContextMain(): CALL_FUNCTION %r', call_info) - (reply_handle, mod_name, class_name, func_name, args, kwargs) = call_info + reply_handle, mod_name, class_name, func_name, args, kwargs = call_info try: fn = getattr(__import__(mod_name), func_name)