SSH working
* Get rid of persistent functions for now. * Split select into read/write sides for unidirectional SSH IO. * Put more of Loop in a try/except.
This commit is contained in:
parent
1a30570057
commit
e62b891b9a
228
econtext.py
228
econtext.py
|
@ -90,7 +90,6 @@ def CreateChild(*args):
|
|||
if not pid:
|
||||
os.dup2(childfp.fileno(), 0)
|
||||
os.dup2(childfp.fileno(), 1)
|
||||
sys.stderr = open('milf2', 'w', 1)
|
||||
childfp.close()
|
||||
parentfp.close()
|
||||
os.execvp(args[0], args)
|
||||
|
@ -121,21 +120,6 @@ class Formatter(logging.Formatter):
|
|||
return p + ('{%s} %s' % (os.getpid(), s))
|
||||
|
||||
|
||||
class PartialFunction(object):
|
||||
'''
|
||||
Partial function implementation.
|
||||
'''
|
||||
def __init__(self, fn, *partial_args):
|
||||
self.fn = fn
|
||||
self.partial_args = partial_args
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.fn(*(self.partial_args+args), **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return 'PartialFunction(%r, *%r)' % (self.fn, self.partial_args)
|
||||
|
||||
|
||||
class Channel(object):
|
||||
def __init__(self, stream, handle):
|
||||
self._stream = stream
|
||||
|
@ -268,11 +252,12 @@ class MasterModuleResponder(object):
|
|||
def __init__(self, stream):
|
||||
self._stream = stream
|
||||
|
||||
def GetModule(self, killed, (_, (reply_handle, fullname))):
|
||||
LOG.debug('SlaveModuleImporter.GetModule(%r, %r)', killed, fullname)
|
||||
def GetModule(self, killed, data):
|
||||
if killed:
|
||||
return
|
||||
|
||||
_, (reply_handle, fullname) = data
|
||||
LOG.debug('SlaveModuleImporter.GetModule(%r, %r)', killed, fullname)
|
||||
mod = sys.modules.get(fullname)
|
||||
if mod:
|
||||
source = zlib.compress(inspect.getsource(mod))
|
||||
|
@ -285,12 +270,24 @@ class MasterModuleResponder(object):
|
|||
#
|
||||
|
||||
|
||||
class BasicStream(object):
|
||||
class Side(object):
|
||||
def __init__(self, stream, fd):
|
||||
self.stream = stream
|
||||
self.fd = fd
|
||||
|
||||
def __repr__(self):
|
||||
return '<fd %r of %r>' % (self.fd, self.stream)
|
||||
|
||||
def fileno(self):
|
||||
return self._fd
|
||||
return self.fd
|
||||
|
||||
|
||||
class BasicStream(object):
|
||||
read_side = None
|
||||
write_side = None
|
||||
|
||||
def Disconnect(self):
|
||||
LOG.debug('%r: disconnect on %r fd %d', self._broker, self, self._fd)
|
||||
LOG.debug('%r: disconnect on %r', self._broker, self)
|
||||
self._broker.RemoveStream(self)
|
||||
|
||||
def ReadMore(self):
|
||||
|
@ -325,11 +322,9 @@ class Stream(BasicStream):
|
|||
|
||||
self._pickler_file = cStringIO.StringIO()
|
||||
self._pickler = cPickle.Pickler(self._pickler_file, protocol=2)
|
||||
self._pickler.persistent_id = self._CheckFunctionPerID
|
||||
|
||||
self._unpickler_file = cStringIO.StringIO()
|
||||
self._unpickler = cPickle.Unpickler(self._unpickler_file)
|
||||
self._unpickler.persistent_load = self._LoadFunctionFromPerID
|
||||
|
||||
def Pickle(self, obj):
|
||||
'''
|
||||
|
@ -365,37 +360,6 @@ class Stream(BasicStream):
|
|||
self._unpickler_file.truncate(0)
|
||||
return data
|
||||
|
||||
def _CheckFunctionPerID(self, obj):
|
||||
'''
|
||||
Return None or a persistent ID for an object.
|
||||
Please see the cPickle documentation.
|
||||
|
||||
Args:
|
||||
obj: object
|
||||
|
||||
Returns:
|
||||
str
|
||||
'''
|
||||
if isinstance(obj, (types.FunctionType, types.MethodType)):
|
||||
pid = 'FUNC:' + repr(obj)
|
||||
self._func_refs[per_id] = obj
|
||||
return pid
|
||||
|
||||
def _LoadFunctionFromPerID(self, pid):
|
||||
'''
|
||||
Load an object from a persistent ID.
|
||||
Please see the cPickle documentation.
|
||||
|
||||
Args:
|
||||
pid: str
|
||||
|
||||
Returns:
|
||||
object
|
||||
'''
|
||||
if not pid.startswith('FUNC:'):
|
||||
raise CorruptMessageError('unrecognized persistent ID received: %r', pid)
|
||||
return PartialFunction(self._CallPersistentWhatsit, pid)
|
||||
|
||||
def AllocHandle(self):
|
||||
'''
|
||||
Allocate a unique handle for this stream.
|
||||
|
@ -434,7 +398,7 @@ class Stream(BasicStream):
|
|||
'''
|
||||
LOG.debug('%r.Receive()', self)
|
||||
|
||||
buf = os.read(self._fd, 4096)
|
||||
buf = os.read(self.read_side.fd, 4096)
|
||||
if not buf:
|
||||
return self.Disconnect()
|
||||
|
||||
|
@ -484,7 +448,7 @@ class Stream(BasicStream):
|
|||
IOError
|
||||
'''
|
||||
LOG.debug('%r.Transmit()', self)
|
||||
written = os.write(self._fd, self._output_buf[:4096])
|
||||
written = os.write(self.write_side.fd, self._output_buf[:4096])
|
||||
self._output_buf = self._output_buf[written:]
|
||||
|
||||
def WriteMore(self):
|
||||
|
@ -513,32 +477,41 @@ class Stream(BasicStream):
|
|||
|
||||
def Disconnect(self):
|
||||
'''
|
||||
Close our associated file descriptor and tell any registered callbacks
|
||||
that the connection has been destroyed.
|
||||
Close our associated file descriptor and tell registered callbacks the
|
||||
connection has been destroyed.
|
||||
'''
|
||||
LOG.debug('%r.Disconnect()', self)
|
||||
try:
|
||||
os.close(self._fd)
|
||||
except OSError, e:
|
||||
LOG.debug('%r.Disconnect(): did not close fd %s: %s',
|
||||
self, self._fd, e)
|
||||
|
||||
self._fd = None
|
||||
if self._context.GetStream() is self:
|
||||
self._context.SetStream(None)
|
||||
|
||||
try:
|
||||
os.close(self.read_side.fd)
|
||||
except OSError, e:
|
||||
LOG.debug('%r.Disconnect(): did not close fd %s: %s',
|
||||
self, self.read_side.fd, e)
|
||||
|
||||
if self.read_side.fd != self.write_side.fd:
|
||||
try:
|
||||
os.close(self.write_side.fd)
|
||||
except OSError, e:
|
||||
LOG.debug('%r.Disconnect(): did not close fd %s: %s',
|
||||
self, self.write_side.fd, e)
|
||||
|
||||
self.read_side.fd = None
|
||||
self.write_side.fd = None
|
||||
for handle, (persist, fn) in self._handle_map.iteritems():
|
||||
LOG.debug('%r.Disconnect(): stale callback handle=%r; fn=%r',
|
||||
self, handle, fn)
|
||||
fn(True, None)
|
||||
|
||||
@classmethod
|
||||
def Accept(cls, context, fd):
|
||||
def Accept(cls, context, rfd, wfd):
|
||||
'''
|
||||
|
||||
'''
|
||||
stream = cls(context)
|
||||
stream._fd = os.dup(fd)
|
||||
stream.read_side = Side(stream, os.dup(rfd))
|
||||
stream.write_side = Side(stream, os.dup(wfd))
|
||||
context.SetStream(stream)
|
||||
context.broker.Register(context)
|
||||
return stream
|
||||
|
@ -550,7 +523,8 @@ class Stream(BasicStream):
|
|||
|
||||
LOG.debug('%r.Connect()', self)
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self._fd = sock.fileno()
|
||||
self.read_side = Side(self, sock.fileno())
|
||||
self.write_side = Side(self, sock.fileno())
|
||||
sock.connect(self._context.parent_addr)
|
||||
self.Enqueue(0, self._context.name)
|
||||
|
||||
|
@ -635,9 +609,11 @@ class LocalStream(Stream):
|
|||
def Connect(self):
|
||||
LOG.debug('%r.Connect()', self)
|
||||
pid, sock = CreateChild(*self.GetBootCommand())
|
||||
self._fd = os.dup(sock.fileno())
|
||||
self.read_side = Side(self, os.dup(sock.fileno()))
|
||||
self.write_side = self.read_side
|
||||
sock.close()
|
||||
LOG.debug('%r.Connect(): child process stdin/stdout=%r', self, self._fd)
|
||||
LOG.debug('%r.Connect(): child process stdin/stdout=%r',
|
||||
self, self.read_side.fd)
|
||||
|
||||
source = inspect.getsource(sys.modules[__name__])
|
||||
source += '\nExternalContextMain(%r, %r, %r)\n' % (
|
||||
|
@ -648,8 +624,8 @@ class LocalStream(Stream):
|
|||
compressed = zlib.compress(source)
|
||||
|
||||
preamble = str(len(compressed)) + '\n' + compressed
|
||||
write_all(self._fd, preamble)
|
||||
assert os.read(self._fd, 3) == 'OK\n'
|
||||
write_all(self.write_side.fd, preamble)
|
||||
assert os.read(self.read_side.fd, 3) == 'OK\n'
|
||||
|
||||
|
||||
class SSHStream(LocalStream):
|
||||
|
@ -751,16 +727,17 @@ class Context(object):
|
|||
class Waker(BasicStream):
|
||||
def __init__(self, broker):
|
||||
self._broker = broker
|
||||
self._rfd, self._wfd = os.pipe()
|
||||
self._fd = self._rfd
|
||||
rfd, wfd = os.pipe()
|
||||
self.read_side = Side(self, rfd)
|
||||
self.write_side = Side(self, wfd)
|
||||
broker.AddStream(self)
|
||||
|
||||
def Wake(self):
|
||||
os.write(self._wfd, ' ')
|
||||
os.write(self.write_side.fd, ' ')
|
||||
|
||||
def Receive(self):
|
||||
LOG.debug('%r: waking %r', self, self._broker)
|
||||
os.read(self._rfd, 1)
|
||||
os.read(self.read_side.fd, 1)
|
||||
|
||||
|
||||
class Listener(BasicStream):
|
||||
|
@ -770,7 +747,7 @@ class Listener(BasicStream):
|
|||
self._sock.bind(address or ('0.0.0.0', 0))
|
||||
self._sock.listen(backlog)
|
||||
self._listen_addr = self._sock.getsockname()
|
||||
self._fd = self._sock.fileno()
|
||||
self.read_side = Side(self, self._sock.fileno())
|
||||
broker.AddStream(self)
|
||||
|
||||
def Receive(self):
|
||||
|
@ -795,7 +772,6 @@ class Broker(object):
|
|||
self._waker = Waker(self)
|
||||
|
||||
self._thread = threading.Thread(target=self._Loop, name='Broker')
|
||||
self._thread.setDaemon(True)
|
||||
self._thread.start()
|
||||
|
||||
def CreateListener(self, address=None, backlog=30):
|
||||
|
@ -809,16 +785,15 @@ class Broker(object):
|
|||
|
||||
def UpdateStream(self, stream, wake=False):
|
||||
LOG.debug('UpdateStream(%r, wake=%s)', stream, wake)
|
||||
fileno = stream.fileno()
|
||||
if fileno is not None and stream.ReadMore():
|
||||
self._readers.add(stream)
|
||||
if stream.ReadMore() and stream.read_side.fileno():
|
||||
self._readers.add(stream.read_side)
|
||||
else:
|
||||
self._readers.discard(stream)
|
||||
self._readers.discard(stream.read_side)
|
||||
|
||||
if fileno is not None and stream.WriteMore():
|
||||
self._writers.add(stream)
|
||||
if stream.WriteMore() and stream.write_side.fileno():
|
||||
self._writers.add(stream.write_side)
|
||||
else:
|
||||
self._writers.discard(stream)
|
||||
self._writers.discard(stream.write_side)
|
||||
|
||||
if wake:
|
||||
self._waker.Wake()
|
||||
|
@ -836,8 +811,9 @@ class Broker(object):
|
|||
'''
|
||||
Put a context under control of this broker.
|
||||
'''
|
||||
LOG.debug('%r.Register(%r) -> fd=%r', self, context,
|
||||
context.GetStream().fileno())
|
||||
LOG.debug('%r.Register(%r) -> r=%r w=%r', self, context,
|
||||
context.GetStream().read_side,
|
||||
context.GetStream().write_side)
|
||||
self.AddStream(context.GetStream())
|
||||
self._contexts[context.name] = context
|
||||
return context
|
||||
|
@ -855,7 +831,7 @@ class Broker(object):
|
|||
context.SetStream(LocalStream(context)).Connect()
|
||||
return self.Register(context)
|
||||
|
||||
def GetRemote(self, hostname, username, name=None):
|
||||
def GetRemote(self, hostname, username, name=None, python_path=None):
|
||||
'''
|
||||
Return the named remote context, or create it if it doesn't exist.
|
||||
'''
|
||||
|
@ -864,51 +840,56 @@ class Broker(object):
|
|||
(username, os.getenv('HOSTNAME'), os.getpid())
|
||||
|
||||
context = Context(self, name, hostname, username)
|
||||
context.SetStream(SSHStream(context)).Connect()
|
||||
stream = SSHStream(context)
|
||||
if python_path:
|
||||
stream.python_path = python_path
|
||||
context.SetStream(stream)
|
||||
stream.Connect()
|
||||
return self.Register(context)
|
||||
|
||||
def _Loop(self):
|
||||
try:
|
||||
self.Loop()
|
||||
except Exception:
|
||||
LOG.exception('Loop() crashed')
|
||||
def _LoopOnce(self):
|
||||
LOG.debug('%r.Loop()', self)
|
||||
self._lock.acquire()
|
||||
self._lock.release()
|
||||
|
||||
def Loop(self):
|
||||
#LOG.debug('readers = %r', self._readers)
|
||||
#LOG.debug('rfds = %r', [r.fileno() for r in self._readers])
|
||||
#LOG.debug('writers = %r', self._writers)
|
||||
#LOG.debug('wfds = %r', [w.fileno() for w in self._writers])
|
||||
rsides, wsides, _ = select.select(self._readers, self._writers, ())
|
||||
for side in rsides:
|
||||
LOG.debug('%r: POLLIN for %r', self, side.stream)
|
||||
side.stream.Receive()
|
||||
self.UpdateStream(side.stream)
|
||||
|
||||
for side in wsides:
|
||||
LOG.debug('%r: POLLOUT for %r', self, side.stream)
|
||||
side.stream.Transmit()
|
||||
self.UpdateStream(side.stream)
|
||||
|
||||
def _Loop(self):
|
||||
'''
|
||||
Handle stream events until Finalize() is called.
|
||||
'''
|
||||
while not self._dead:
|
||||
LOG.debug('%r.Loop()', self)
|
||||
self._lock.acquire()
|
||||
self._lock.release()
|
||||
try:
|
||||
while not self._dead:
|
||||
self._LoopOnce()
|
||||
|
||||
#LOG.debug('readers = %r', self._readers)
|
||||
#LOG.debug('rfds = %r', [r.fileno() for r in self._readers])
|
||||
#LOG.debug('writers = %r', self._writers)
|
||||
rstrms, wstrms, _ = select.select(self._readers, self._writers, ())
|
||||
for stream in rstrms:
|
||||
LOG.debug('%r: POLLIN for %r', self, stream)
|
||||
stream.Receive()
|
||||
self.UpdateStream(stream)
|
||||
|
||||
for stream in wstrms:
|
||||
LOG.debug('%r: POLLOUT for %r', self, stream)
|
||||
stream.Transmit()
|
||||
self.UpdateStream(stream)
|
||||
for context in self._contexts.itervalues():
|
||||
stream = context.GetStream()
|
||||
if stream:
|
||||
stream.Disconnect()
|
||||
except Exception:
|
||||
LOG.exception('Loop() crashed')
|
||||
|
||||
def Finalize(self):
|
||||
'''
|
||||
Tell all active streams to disconnect.
|
||||
'''
|
||||
self._dead = True
|
||||
self._waker.Wake()
|
||||
self._lock.acquire()
|
||||
try:
|
||||
for name, context in self._contexts.iteritems():
|
||||
stream = context.GetStream()
|
||||
if stream:
|
||||
stream.Disconnect()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._lock.release()
|
||||
|
||||
def __repr__(self):
|
||||
return 'econtext.Broker(<contexts=%s>)' % (self._contexts.keys(),)
|
||||
|
@ -918,19 +899,20 @@ def ExternalContextMain(context_name, parent_addr, key):
|
|||
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
|
||||
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.getLogger('').handlers[0].formatter = Formatter(False)
|
||||
LOG.debug('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key)
|
||||
|
||||
# os.wait() # Reap the first stage.
|
||||
os.wait() # Reap the first stage.
|
||||
os.dup2(100, 0)
|
||||
os.close(100)
|
||||
|
||||
broker = Broker()
|
||||
context = Context(broker, 'parent', parent_addr=parent_addr, key=key)
|
||||
|
||||
stream = Stream.Accept(context, 0)
|
||||
stream = Stream.Accept(context, 0, 1)
|
||||
os.close(0)
|
||||
os.close(1)
|
||||
|
||||
# stream = context.SetStream(Stream(context))
|
||||
# stream.
|
||||
|
@ -942,7 +924,7 @@ def ExternalContextMain(context_name, parent_addr, key):
|
|||
|
||||
for call_info in Channel(stream, CALL_FUNCTION):
|
||||
LOG.debug('ExternalContextMain(): CALL_FUNCTION %r', call_info)
|
||||
(reply_handle, mod_name, class_name, func_name, args, kwargs) = call_info
|
||||
reply_handle, mod_name, class_name, func_name, args, kwargs = call_info
|
||||
|
||||
try:
|
||||
fn = getattr(__import__(mod_name), func_name)
|
||||
|
|
Loading…
Reference in New Issue