813 lines
22 KiB
Python
Executable File
813 lines
22 KiB
Python
Executable File
#!/usr/bin/env python2.5
|
|
|
|
'''
|
|
Python External Execution Contexts.
|
|
'''
|
|
|
|
import atexit
|
|
import cPickle
|
|
import cStringIO
|
|
import commands
|
|
import getpass
|
|
import hmac
|
|
import imp
|
|
import inspect
|
|
import os
|
|
import select
|
|
import sha
|
|
import signal
|
|
import socket
|
|
import struct
|
|
import sys
|
|
import syslog
|
|
import textwrap
|
|
import threading
|
|
import traceback
|
|
import types
|
|
import zlib
|
|
|
|
|
|
#
|
|
# Module-level data.
|
|
#
|
|
|
|
# Stream handle on which module requests travel: SlaveModuleImporter sends
# the wanted module name to it and LocalStream registers
# SlaveModuleImporter.GetModule as its callback.
GET_MODULE = 0L
|
|
# Stream handle on which function-call requests travel: Context sends call
# tuples to it and ExternalContextMain() services it through a Channel.
CALL_FUNCTION = 1L
|
|
|
|
# When true, Log() writes its messages to stderr; when false Log() is a no-op.
DEBUG = True
|
|
|
|
|
|
#
|
|
# Exceptions.
|
|
#
|
|
|
|
class ContextError(Exception):
    '''
    Raised when a problem occurs with a context.

    Args:
      fmt: message text, optionally containing %-style format specifiers.
      *args: values interpolated into fmt.
    '''

    def __init__(self, fmt, *args):
        # Interpolate only when values were supplied, so that a literal message
        # containing '%' cannot itself raise while the error is being built.
        if args:
            fmt = fmt % args
        Exception.__init__(self, fmt)
|
|
|
|
class ChannelError(ContextError):
    '''
    Raised when a channel dies or has been closed.

    Channel.Receive() raises this when the entry it dequeues is flagged as
    killed or closed.
    '''
|
|
|
|
class StreamError(ContextError):
    '''
    Raised when a stream cannot be established.

    Also raised when the connection is lost during a call and when an
    unpickle of a non-permitted module is attempted; base class of the more
    specific stream errors in this module.
    '''
|
|
|
|
class CorruptMessageError(StreamError):
    '''
    Raised when a corrupt message is received on a stream.

    Covers a bad MAC, an unknown handle, an undecodable payload and an
    unrecognized persistent ID.
    '''
|
|
|
|
class TimeoutError(StreamError):
    '''
    Raised when a timeout occurs on a stream.

    Context.EnqueueAwaitReply() raises this when no reply arrives before the
    deadline.
    '''
|
|
|
|
|
|
#
|
|
# Helpers.
|
|
#
|
|
|
|
def Log(fmt, *args):
    '''
    Write one debugging line to stderr, but only while DEBUG is true.

    The line carries this process's PID, its parent's PID and the formatted
    message with every 'econtext.' substring removed.

    Args:
      fmt: %-style format string.
      *args: values interpolated into fmt.
    '''
    if not DEBUG:
        return
    message = (fmt % args).replace('econtext.', '')
    sys.stderr.write('%d (%d): %s\n' % (os.getpid(), os.getppid(), message))
|
|
|
|
|
|
def CreateChild(*args):
    '''
    Create a child process whose stdin/stdout is connected to a socket.

    Args:
      *args: executable name and process arguments.

    Returns:
      pid, sock: the child's PID and the parent's end of the socket pair.
    '''
    parent_sock, child_sock = socket.socketpair()
    pid = os.fork()
    if not pid:
        # Child: this branch must never return into the caller's code, so it
        # either becomes the new program or exits.
        try:
            os.dup2(child_sock.fileno(), 0)
            os.dup2(child_sock.fileno(), 1)
            parent_sock.close()
            child_sock.close()
            os.execvp(args[0], args)
        except Exception:
            traceback.print_exc()
        os._exit(1)

    # Parent: drop our copy of the child's end.  Keeping it open leaks a
    # descriptor and prevents EOF from being seen once the child exits.
    child_sock.close()
    return pid, parent_sock
|
|
|
|
|
|
class PartialFunction(object):
    '''
    Callable pairing a function with leading positional arguments.

    Calling an instance invokes fn with the remembered arguments followed by
    the positional arguments of the call; keyword arguments pass through
    unchanged.
    '''

    def __init__(self, fn, *partial_args):
        self.fn, self.partial_args = fn, partial_args

    def __call__(self, *args, **kwargs):
        positional = self.partial_args + args
        return self.fn(*positional, **kwargs)

    def __repr__(self):
        described = (self.fn, self.partial_args)
        return 'PartialFunction(%r, *%r)' % described
|
|
|
|
|
|
class Channel(object):
    '''
    Message queue bound to a single handle on a Stream.

    Send() and Close() enqueue (closed, data) pairs on the stream for this
    channel's handle; pairs arriving on that handle are queued locally until
    Receive() collects them.
    '''

    def __init__(self, stream, handle):
        '''
        Initialise a new instance and subscribe it to the stream.

        Args:
          stream: Stream the channel sends and receives on.
          handle: long; stream handle identifying this channel.
        '''
        self._stream = stream
        self._handle = handle
        self._wake_event = threading.Event()
        self._queue_lock = threading.Lock()
        self._queue = []
        self._stream.AddHandleCB(self._InternalReceive, handle)

    def _InternalReceive(self, killed, data):
        '''
        Callback from the stream object; appends a tuple of
        (killed-or-closed, data) to the internal queue and wakes the internal
        event.

        Args:
          # Has the Stream object lost its connection?
          killed: bool
          data: (
            # Has the remote Channel had Close() called?
            bool,
            # The object passed to the remote Send()
            object
          )
        '''
        Log('%r._InternalReceive(%r, %r)', self, killed, data)
        self._queue_lock.acquire()
        try:
            # When killed is true `data` is never indexed (short-circuit), so
            # the queued entry is simply (True, True).
            self._queue.append((killed or data[0], killed or data[1]))
            self._wake_event.set()
        finally:
            self._queue_lock.release()

    def Close(self):
        '''
        Indicate this channel is closed to the remote side.
        '''
        Log('%r.Close()', self)
        self._stream.Enqueue(self._handle, (True, None))

    def Send(self, data):
        '''
        Send the given object to the remote side.

        Args:
          data: object
        '''
        Log('%r.Send(%r)', self, data)
        self._stream.Enqueue(self._handle, (False, data))

    def Receive(self, timeout=None):
        '''
        Receive the next object to arrive on this channel, or return if the
        optional timeout is reached.

        Args:
          timeout: float

        Returns:
          object, or None when the timeout expired with nothing queued.

        Raises:
          ChannelError: the dequeued entry marks the channel killed or closed.
        '''
        Log('%r.Receive(%r)', self, timeout)
        if not self._queue:
            self._wake_event.wait(timeout)
            if not self._wake_event.isSet():
                return

        self._queue_lock.acquire()
        try:
            self._wake_event.clear()
            Log('%r.Receive() queue is %r', self, self._queue)
            closed, data = self._queue.pop(0)
            Log('%r.Receive() got closed=%r, data=%r', self, closed, data)
            if closed:
                raise ChannelError('Channel is closed.')
            return data
        finally:
            self._queue_lock.release()

    def __iter__(self):
        '''
        Return an iterator that yields objects arriving on this channel, until
        the channel dies or is closed.
        '''
        while True:
            try:
                yield self.Receive()
            except ChannelError:
                return

    def __repr__(self):
        return 'econtext.Channel(%r, %r)' % (self._stream, self._handle)
|
|
|
|
|
|
class SlaveModuleImporter(object):
    '''
    Import protocol implementation that fetches modules from the parent
    process.
    '''

    def __init__(self, context):
        '''
        Initialise a new instance.

        Args:
          context: Context instance this importer will communicate via.
        '''
        self._context = context

    def find_module(self, fullname, path=None):
        '''
        Claim a module only when it cannot be found locally.

        Args:
          fullname: str; name of the module being imported.
          path: ignored.

        Returns:
          self when imp.find_module() cannot locate the module, else None.
        '''
        # imp.find_module() signals "not found" by raising ImportError, not by
        # returning a false value.
        try:
            found = imp.find_module(fullname)
        except ImportError:
            return self
        # The first element is an open file for source/compiled modules;
        # close it rather than leaking the descriptor.
        if found[0]:
            found[0].close()
        return None

    def load_module(self, fullname):
        '''
        Ask the parent for the named module via the GET_MODULE handle.

        NOTE(review): the reply is fetched but no module object is built or
        returned yet; this method looks unfinished.

        Args:
          fullname: str
        '''
        # EnqueueAwaitReply() takes (handle, deadline, data) and concatenates
        # data onto a tuple, so pass no deadline and a 1-tuple payload.
        kind, data = self._context.EnqueueAwaitReply(GET_MODULE, None,
                                                     (fullname,))

    def GetModule(cls, killed, fullname):
        '''
        Stream callback for GET_MODULE requests.

        NOTE(review): currently a stub; nothing is sent back to the requester.

        Args:
          killed: bool; true when the stream has been disconnected.
          fullname: str; name of the requested module.
        '''
        Log('%r.GetModule(%r, %r)', cls, killed, fullname)
        if killed:
            return
        if fullname in sys.modules:
            pass
    GetModule = classmethod(GetModule)
|
|
|
|
|
|
#
|
|
# Stream implementations.
|
|
#
|
|
|
|
class Stream(object):
|
|
def __init__(self, context):
|
|
'''
|
|
Initialize a new Stream instance.
|
|
|
|
Args:
|
|
context: econtext.Context
|
|
'''
|
|
self._context = context
|
|
|
|
self._input_buf = self._output_buf = ''
|
|
self._input_buf_lock = threading.Lock()
|
|
self._output_buf_lock = threading.Lock()
|
|
self._rhmac = hmac.new(context.key, digestmod=sha.new)
|
|
self._whmac = self._rhmac.copy()
|
|
|
|
self._last_handle = 0
|
|
self._handle_map = {}
|
|
self._handle_lock = threading.Lock()
|
|
|
|
self._func_refs = {}
|
|
self._func_ref_lock = threading.Lock()
|
|
|
|
self._pickler_file = cStringIO.StringIO()
|
|
self._pickler = cPickle.Pickler(self._pickler_file)
|
|
self._pickler.persistent_id = self._CheckFunctionPerID
|
|
|
|
self._unpickler_file = cStringIO.StringIO()
|
|
self._unpickler = cPickle.Unpickler(self._unpickler_file)
|
|
self._unpickler.persistent_load = self._LoadFunctionFromPerID
|
|
|
|
def Pickle(self, obj):
|
|
'''
|
|
Serialize the given object using the pickler.
|
|
|
|
Args:
|
|
obj: object
|
|
|
|
Returns:
|
|
str
|
|
'''
|
|
self._pickler.dump(obj)
|
|
data = self._pickler_file.getvalue()
|
|
self._pickler_file.seek(0)
|
|
self._pickler_file.truncate(0)
|
|
return data
|
|
|
|
def Unpickle(self, data):
|
|
'''
|
|
Unserialize the given string using the unpickler.
|
|
|
|
Args:
|
|
data: str
|
|
|
|
Returns:
|
|
object
|
|
'''
|
|
Log('%r.Unpickle(%r)', self, data)
|
|
self._unpickler_file.write(data)
|
|
self._unpickler_file.seek(0)
|
|
data = self._unpickler.load()
|
|
self._unpickler_file.seek(0)
|
|
self._unpickler_file.truncate(0)
|
|
return data
|
|
|
|
def _CheckFunctionPerID(self, obj):
|
|
'''
|
|
Return None or a persistent ID for an object.
|
|
Please see the cPickle documentation.
|
|
|
|
Args:
|
|
obj: object
|
|
|
|
Returns:
|
|
str
|
|
'''
|
|
if isinstance(obj, (types.FunctionType, types.MethodType)):
|
|
pid = 'FUNC:' + repr(obj)
|
|
self._func_refs[per_id] = obj
|
|
return pid
|
|
|
|
def _LoadFunctionFromPerID(self, pid):
|
|
'''
|
|
Load an object from a persistent ID.
|
|
Please see the cPickle documentation.
|
|
|
|
Args:
|
|
pid: str
|
|
|
|
Returns:
|
|
object
|
|
'''
|
|
if not pid.startswith('FUNC:'):
|
|
raise CorruptMessageError('unrecognized persistent ID received: %r', pid)
|
|
return PartialFunction(self._CallPersistentWhatsit, pid)
|
|
|
|
def AllocHandle(self):
|
|
'''
|
|
Allocate a unique handle for this stream.
|
|
|
|
Returns:
|
|
long
|
|
'''
|
|
self._handle_lock.acquire()
|
|
try:
|
|
self._last_handle += 1L
|
|
finally:
|
|
self._handle_lock.release()
|
|
return self._last_handle
|
|
|
|
def AddHandleCB(self, fn, handle, persist=True):
|
|
'''
|
|
Invoke a function for all messages with the given handle.
|
|
|
|
Args:
|
|
fn: callable
|
|
handle: long
|
|
persist: False to only receive a single message.
|
|
'''
|
|
Log('%r.AddHandleCB(%r, %r, persist=%r)', self, fn, handle, persist)
|
|
self._handle_lock.acquire()
|
|
try:
|
|
self._handle_map[handle] = persist, fn
|
|
finally:
|
|
self._handle_lock.release()
|
|
|
|
def Receive(self):
|
|
'''
|
|
Handle the next complete message on the stream. Raise CorruptMessageError
|
|
or IOError on failure.
|
|
'''
|
|
Log('%r.Receive()', self)
|
|
|
|
self._input_buf += os.read(self._rfd, 4096)
|
|
if len(self._input_buf) < 24:
|
|
return
|
|
|
|
msg_mac = self._input_buf[:20]
|
|
msg_len = struct.unpack('>L', self._input_buf[20:24])[0]
|
|
if len(self._input_buf) < msg_len-24:
|
|
return
|
|
|
|
self._rhmac.update(self._input_buf[20:msg_len+24])
|
|
expected_mac = self._rhmac.digest()
|
|
if msg_mac != expected_mac:
|
|
raise CorruptMessageError('%r got invalid MAC: expected %r, got %r',
|
|
self, msg_mac.encode('hex'),
|
|
expected_mac.encode('hex'))
|
|
|
|
try:
|
|
handle, data = self.Unpickle(self._input_buf[24:msg_len+24])
|
|
self._input_buf = self._input_buf[msg_len+4:]
|
|
handle = long(handle)
|
|
|
|
Log('%r.Receive(): decoded handle=%r; data=%r', self, handle, data)
|
|
persist, fn = self._handle_map[handle]
|
|
if not persist:
|
|
del self._handle_map[handle]
|
|
except KeyError, ex:
|
|
raise CorruptMessageError('%r got invalid handle: %r', self, handle)
|
|
except (TypeError, ValueError), ex:
|
|
raise CorruptMessageError('%r got invalid message: %s', self, ex)
|
|
|
|
fn(False, data)
|
|
|
|
def Transmit(self):
|
|
'''
|
|
Transmit buffered messages.
|
|
|
|
Returns:
|
|
bool: more data left in bufer?
|
|
|
|
Raises:
|
|
IOError
|
|
'''
|
|
Log('%r.Transmit()', self)
|
|
written = os.write(self._fd, self._output_buf[:4096])
|
|
self._output_buf = self._output_buf[written:]
|
|
return bool(self._output_buf)
|
|
|
|
def Enqueue(self, handle, obj):
|
|
'''
|
|
Serialize an object, send it to the given handle, and tell our context's
|
|
broker we have output.
|
|
|
|
Args:
|
|
handle: long
|
|
obj: object
|
|
'''
|
|
Log('%r.Enqueue(%r, %r)', self, handle, obj)
|
|
|
|
self._output_buf_lock.acquire()
|
|
try:
|
|
encoded = self.Pickle((handle, obj))
|
|
msg = struct.pack('>L', len(encoded)) + encoded
|
|
self._whmac.update(msg)
|
|
self._output_buf += self._whmac.digest() + msg
|
|
finally:
|
|
self._output_buf_lock.release()
|
|
self._context.broker.Register(self._context)
|
|
|
|
def Disconnect(self):
|
|
'''
|
|
Close our associated file descriptor and tell any registered callbacks
|
|
that the connection has been destroyed.
|
|
'''
|
|
Log('%r.Disconnect()', self)
|
|
try:
|
|
os.close(self._fd)
|
|
except OSError, e:
|
|
Log('%r.Disconnect(): did not close fd %s: %s', self, self._fd, e)
|
|
|
|
for handle, (persist, fn) in self._handle_map.iteritems():
|
|
Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r',
|
|
self, handle, fn)
|
|
fn(True, None)
|
|
|
|
@classmethod
|
|
def Accept(cls, context, sock):
|
|
'''
|
|
|
|
'''
|
|
stream = cls(context)
|
|
context.SetStream()
|
|
broker.Register(context)
|
|
|
|
def Connect(self):
|
|
'''
|
|
Connect to a Broker at the address specified in our associated Context.
|
|
'''
|
|
|
|
Log('%r.Connect()', self)
|
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
self._fd = sock.fileno()
|
|
sock.connect(self._context.parent_addr)
|
|
self.Enqueue(0, self._context.name)
|
|
|
|
def fileno(self):
|
|
return self._fd
|
|
|
|
def __repr__(self):
|
|
return 'econtext.%s(<context=%r>)' %\
|
|
(self.__class__.__name__, self._context)
|
|
|
|
|
|
class LocalStream(Stream):
|
|
'''
|
|
Base for streams capable of starting new slaves.
|
|
'''
|
|
|
|
python_path = property(
|
|
lambda self: getattr(self, '_python_path', sys.executable),
|
|
lambda self, path: setattr(self, '_python_path', path),
|
|
doc='The path to the remote Python interpreter.')
|
|
|
|
def __init__(self, context):
|
|
super(LocalStream, self).__init__(context)
|
|
self._permitted_modules = {}
|
|
self._unpickler.find_global = self._FindGlobal
|
|
self.AddHandleCB(SlaveModuleImporter.GetModule, handle=GET_MODULE)
|
|
|
|
def _FindGlobal(self, module_name, class_name):
|
|
'''
|
|
See the cPickle documentation: given a module and class name, determine
|
|
whether class referred to is safe for unpickling.
|
|
|
|
Args:
|
|
module_name: str
|
|
class_name: str
|
|
|
|
Returns:
|
|
classobj or type
|
|
'''
|
|
if module_name not in self._permitted_modules:
|
|
raise StreamError('context %r attempted to unpickle %r in module %r',
|
|
self._context, class_name, module_name)
|
|
return getattr(sys.modules[module_name], class_name)
|
|
|
|
def AllowModule(self, module_name):
|
|
'''
|
|
Add the given module to the list of permitted modules.
|
|
|
|
Args:
|
|
module_name: str
|
|
'''
|
|
self._permitted_modules.add(module_name)
|
|
|
|
# Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe,
|
|
# then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced
|
|
# with the context name. Optimized for source size.
|
|
def _FirstStage():
|
|
import os,sys,zlib
|
|
R,W=os.pipe()
|
|
if os.fork():
|
|
os.dup2(0,100)
|
|
os.dup2(R,0)
|
|
os.close(R)
|
|
os.close(W)
|
|
os.execv(sys.executable,(CONTEXT_NAME,))
|
|
else:
|
|
os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input())))
|
|
print 'OK'
|
|
sys.exit(0)
|
|
|
|
def GetBootCommand(self):
|
|
source = inspect.getsource(self._FirstStage)
|
|
source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:]))
|
|
source = source.replace(' ', '\t')
|
|
source = source.replace('CONTEXT_NAME', repr(self._context.name))
|
|
return [ self.python_path, '-c',
|
|
'exec "%s".decode("hex")' % (source.encode('hex'),) ]
|
|
|
|
def __repr__(self):
|
|
return '%s(%s)' % (self.__class__.__name__, self._context)
|
|
|
|
def Connect(self):
|
|
Log('%r.Connect()', self)
|
|
pid, sock = CreateChild(*self.GetBootCommand())
|
|
self._fd = sock.fileno()
|
|
Log('%r.Connect(): child process stdin/stdout=%r', self, self._fd)
|
|
|
|
source = inspect.getsource(sys.modules[__name__])
|
|
source += '\nExternalContextMain(%r, %r, %r)\n' %\
|
|
(self._context.name, self._context.broker._listen_addr,
|
|
self._context.key)
|
|
compressed = zlib.compress(source)
|
|
|
|
preamble = str(len(compressed)) + '\n' + compressed
|
|
sock.sendall(preamble)
|
|
assert os.read(self._fd, 3) == 'OK\n'
|
|
|
|
|
|
class SSHStream(LocalStream):
    '''
    LocalStream whose boot command is run on another host through SSH.
    '''

    ssh_path = property(
        lambda self: getattr(self, '_ssh_path', 'ssh'),
        lambda self, path: setattr(self, '_ssh_path', path),
        doc='The path to the SSH binary.')

    def GetBootCommand(self):
        '''
        Return the SSH argv that runs the inherited boot command remotely.

        The command is the SSH binary, '-l <username>' when the context has a
        username, the context's hostname, then each argument of the inherited
        boot command passed through commands.mkarg().
        '''
        command = [self.ssh_path]
        if self._context.username:
            command.extend(['-l', self._context.username])
        command.append(self._context.hostname)
        python_command = super(SSHStream, self).GetBootCommand()
        command.extend([commands.mkarg(arg) for arg in python_command])
        return command
|
|
|
|
|
|
class Context(object):
    '''
    Represents a remote context regardless of connection method.
    '''

    def __init__(self, broker, name=None, hostname=None, username=None, key=None,
                 parent_addr=None):
        '''
        Initialise a new instance.

        Args:
          broker: Broker this context registers with.
          name: context name.
          hostname: host the context lives on, if any.
          username: user to connect as, if any.
          key: HMAC key shared with the peer; a random one is generated when
              omitted.
          parent_addr: address a slave uses to connect back to its parent.
        '''
        self.broker = broker
        self.name = name
        self.hostname = hostname
        self.username = username
        self.parent_addr = parent_addr
        # No stream until SetStream() is called.
        self._stream = None
        if key:
            self.key = key
        else:
            self.key = os.urandom(16).encode('hex')

    def GetStream(self):
        '''
        Return the stream set by SetStream(), or None if there is none yet.
        '''
        return self._stream

    def SetStream(self, stream):
        '''
        Associate a stream with this context and return it.
        '''
        self._stream = stream
        return stream

    def EnqueueAwaitReply(self, handle, deadline, data):
        '''
        Send a message to the given handle and wait for a response with an
        optional timeout. The message contains (reply_handle, data), where
        reply_handle is the handle on which this function expects its reply.

        Args:
          handle: long; handle the request is sent to.
          deadline: float seconds to wait, or None to wait indefinitely.
          data: tuple appended after the reply handle.

        Returns:
          The object sent back on the reply handle.

        Raises:
          TimeoutError: no reply arrived in time; the stream is disconnected.
          StreamError: the connection was lost while waiting.
        '''
        Log('%r.EnqueueAwaitReply(%r, %r, %r)', self, handle, deadline, data)
        reply_handle = self._stream.AllocHandle()
        reply_event = threading.Event()
        container = []

        def _Receive(killed, data):
            Log('%r._Receive(%r, %r)', self, killed, data)
            container.extend([killed, data])
            reply_event.set()

        self._stream.AddHandleCB(_Receive, reply_handle, persist=False)
        # The leading False is the "closed" flag of the (closed, data) pair
        # that a Channel on the receiving side expects.
        self._stream.Enqueue(handle, (False, (reply_handle,) + data))

        reply_event.wait(deadline)
        if not reply_event.isSet():
            self._stream.Disconnect()
            raise TimeoutError('deadline exceeded.')

        killed, data = container
        if killed:
            raise StreamError('lost connection during call.')

        Log('%r._EnqueueAwaitReply(): got reply: %r', self, data)
        return data

    def CallWithDeadline(self, fn, deadline, *args, **kwargs):
        '''
        Call fn(*args, **kwargs) in this context and return its result.

        The function is sent by name: its module, the name of the class it is
        bound to when it is a class-bound method, and its own name.

        Args:
          fn: function or class-bound method to call remotely.
          deadline: float seconds to wait, or None to wait indefinitely.
          use_channel: optional keyword, removed from kwargs and sent along
              as a flag.

        Returns:
          The remote function's return value.

        Raises:
          The exception object sent back by the remote side, with the remote
          traceback attached as `real_traceback`.
        '''
        Log('%r.CallWithDeadline(%r, %r, *%r, **%r)', self, fn, deadline, args,
            kwargs)

        use_channel = bool(kwargs.pop('use_channel', False))
        if isinstance(fn, types.MethodType) and \
           isinstance(fn.im_self, (type, types.ClassType)):
            fn_class = fn.im_self.__name__
        else:
            fn_class = None

        call = (use_channel, fn.__module__, fn_class, fn.__name__, args, kwargs)
        success, result = self.EnqueueAwaitReply(CALL_FUNCTION, deadline, call)

        if success:
            return result
        else:
            exc_obj, remote_traceback = result
            exc_obj.real_traceback = remote_traceback
            raise exc_obj

    def Call(self, fn, *args, **kwargs):
        '''
        Same as CallWithDeadline() with no deadline.
        '''
        return self.CallWithDeadline(fn, None, *args, **kwargs)

    def __repr__(self):
        bits = map(repr, filter(None, [self.name, self.hostname, self.username]))
        return 'Context(%s)' % ', '.join(bits)
|
|
|
|
|
|
class Broker(object):
|
|
'''
|
|
Context broker: this is responsible for keeping track of contexts, any
|
|
stream that is associated with them, and for I/O multiplexing.
|
|
'''
|
|
|
|
def __init__(self):
|
|
self._dead = False
|
|
self._poller = select.poll()
|
|
self._poller_fd_map = {}
|
|
self._poller_lock = threading.Lock()
|
|
self._contexts = {}
|
|
|
|
self._wake_rfd, self._wake_wfd = os.pipe()
|
|
self._listen_sock = None
|
|
self._poller.register(self._wake_rfd)
|
|
|
|
self._thread = threading.Thread(target=self.Loop, name='Broker')
|
|
self._thread.setDaemon(True)
|
|
self._thread.start()
|
|
|
|
def CreateListener(self, address=None, backlog=30):
|
|
'''
|
|
Create a socket to accept connections from newly spawned contexts.
|
|
Args:
|
|
address: The IPv4 address tuple to listen on.
|
|
backlog: Number of connections to accept while broker thread is busy.
|
|
'''
|
|
self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
self._listen_sock.bind(address or ('0.0.0.0', 0))
|
|
self._listen_sock.listen(backlog)
|
|
self._listen_addr = self._listen_sock.getsockname()
|
|
self._poller.register(self._listen_sock)
|
|
|
|
def Register(self, context):
|
|
'''
|
|
Put a context under control of this broker.
|
|
'''
|
|
Log('%r.Register(%r) -> fd=%r', self, context, context.GetStream().fileno())
|
|
self._poller_lock.acquire()
|
|
os.write(self._wake_wfd, ' ')
|
|
try:
|
|
self._contexts[context.name] = context
|
|
self._poller.register(context.GetStream())
|
|
self._poller_fd_map[context.GetStream().fileno()] = context.GetStream()
|
|
finally:
|
|
self._poller_lock.release()
|
|
return context
|
|
|
|
def GetLocal(self, name):
|
|
'''
|
|
Return the named local context, or create it if it doesn't exist.
|
|
|
|
Args:
|
|
name: 'my-local-context'
|
|
Returns:
|
|
econtext.Context
|
|
'''
|
|
context = Context(self, name)
|
|
context.SetStream(LocalStream(context)).Connect()
|
|
return self.Register(context)
|
|
|
|
def GetRemote(self, hostname, username, name=None):
|
|
'''
|
|
Return the named remote context, or create it if it doesn't exist.
|
|
'''
|
|
if name is None:
|
|
name = 'econtext[%s@%s:%d]' %\
|
|
(username, os.getenv('HOSTNAME'), os.getpid())
|
|
|
|
context = Context(self, name, hostname, username)
|
|
context.SetStream(SSHStream(context)).Connect()
|
|
return self.Register(context)
|
|
|
|
def Loop(self):
|
|
'''
|
|
Handle stream events until Finalize() is called.
|
|
'''
|
|
while not self._dead:
|
|
Log('%r.Loop()', self)
|
|
self._poller_lock.acquire()
|
|
self._poller_lock.release()
|
|
for fd, event in self._poller.poll():
|
|
if fd == self._wake_rfd:
|
|
Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd)
|
|
os.read(self._wake_rfd, 1)
|
|
continue
|
|
elif self._listen_sock and fd == self._listen_sock.fileno():
|
|
context = Context(broker)
|
|
Stream.Accept(context, self._listen_sock.accept())
|
|
continue
|
|
|
|
obj = self._poller_fd_map[fd]
|
|
if event & select.POLLHUP:
|
|
Log('%r: POLLHUP for %d, %r', self, fd, obj)
|
|
obj.Disconnect()
|
|
elif event & select.POLLIN:
|
|
Log('%r: POLLIN for %d, %r', self, fd, obj)
|
|
obj.Receive()
|
|
elif event & select.POLLOUT:
|
|
Log('%r: POLLOUT for %d, %r', self, fd, obj)
|
|
if not obj.Transmit(): # If no output buffered, unset POLLOUT.
|
|
self._poller.unregister(obj)
|
|
self._poller.register(obj, select.POLLIN)
|
|
elif event & select.POLLNVAL:
|
|
Log('%r: POLLNVAL for %d, %r', self, fd, obj)
|
|
obj.Disconnect()
|
|
self._poller.unregister(obj)
|
|
|
|
def Finalize(self):
|
|
'''
|
|
Tell all active streams to disconnect.
|
|
'''
|
|
self._dead = True
|
|
self._poller_lock.acquire()
|
|
try:
|
|
for name, context in self._contexts.iteritems():
|
|
context.GetStream().Disconnect()
|
|
finally:
|
|
self._poller_lock.release()
|
|
|
|
def __repr__(self):
|
|
return 'econtext.Broker(<contexts=%s>)' % (self._contexts.keys(),)
|
|
|
|
|
|
def ExternalContextMain(context_name, parent_addr, key):
|
|
syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID)
|
|
syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),))
|
|
Log('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key)
|
|
|
|
os.wait() # Reap the first stage.
|
|
os.dup2(100, 0)
|
|
os.close(100)
|
|
|
|
broker = Broker()
|
|
context = Context(broker, 'parent', parent_addr=parent_addr, key=key)
|
|
|
|
stream = context.SetStream(Stream(context))
|
|
stream.Connect()
|
|
broker.Register(context)
|
|
|
|
for call_info in Channel(stream, CALL_FUNCTION):
|
|
Log('ExternalContextMain(): CALL_FUNCTION %r', call_info)
|
|
reply_handle, mod_name, func_name, args, kwargs = call_info
|
|
fn = getattr(__import__(mod_name), func_name)
|
|
|
|
try:
|
|
stream.Enqueue(reply_handle, (True, fn(*args, **kwargs)))
|
|
except Exception, e:
|
|
stream.Enqueue(reply_handle, (False, (e, traceback.extract_stack())))
|