diff --git a/econtext.py b/econtext.py index 10da048a..9733f792 100755 --- a/econtext.py +++ b/econtext.py @@ -42,21 +42,21 @@ DEBUG = True # class ContextError(Exception): - 'Raised when a problem occurs with a context.' - def __init__(self, fmt, *args): - Exception.__init__(self, fmt % args) + 'Raised when a problem occurs with a context.' + def __init__(self, fmt, *args): + Exception.__init__(self, fmt % args) class ChannelError(ContextError): - 'Raised when a channel dies or has been closed.' + 'Raised when a channel dies or has been closed.' class StreamError(ContextError): - 'Raised when a stream cannot be established.' + 'Raised when a stream cannot be established.' class CorruptMessageError(StreamError): - 'Raised when a corrupt message is received on a stream.' + 'Raised when a corrupt message is received on a stream.' class TimeoutError(StreamError): - 'Raised when a timeout occurs on a stream.' + 'Raised when a timeout occurs on a stream.' # @@ -64,166 +64,167 @@ class TimeoutError(StreamError): # def Log(fmt, *args): - if DEBUG: - sys.stderr.write('%d (%d): %s\n' % (os.getpid(), os.getppid(), - (fmt%args).replace('econtext.', ''))) + if DEBUG: + sys.stderr.write('%d (%d): %s\n' % (os.getpid(), os.getppid(), + (fmt % args).replace('econtext.', ''))) def CreateChild(*args): - ''' - Create a child process whose stdin/stdout is connected to a socket. + ''' + Create a child process whose stdin/stdout is connected to a socket. - Args: - *args: executable name and process arguments. + Args: + *args: executable name and process arguments. - Returns: - pid, sock - ''' - sock1, sock2 = socket.socketpair() - pid = os.fork() - if not pid: - for pair in ((0, sock1), (1, sock2)): - os.dup2(sock2.fileno(), pair[0]) - os.close(pair[1].fileno()) - os.execvp(args[0], args) - raise SystemExit - return pid, sock1 + Returns: + pid, sock + ''' + sock1, sock2 = socket.socketpair() + pid = os.fork() + if not pid: + for pair in ((0, sock1), (1, sock2)): + os.dup2(sock2.fileno(), pair[0]) + os.close(pair[1].fileno()) + os.execvp(args[0], args) + raise SystemExit + return pid, sock1 class PartialFunction(object): - ''' - Partial function implementation. - ''' - def __init__(self, fn, *partial_args): - self.fn = fn - self.partial_args = partial_args + ''' + Partial function implementation. + ''' + def __init__(self, fn, *partial_args): + self.fn = fn + self.partial_args = partial_args - def __call__(self, *args, **kwargs): - return self.fn(*(self.partial_args+args), **kwargs) + def __call__(self, *args, **kwargs): + return self.fn(*(self.partial_args+args), **kwargs) - def __repr__(self): - return 'PartialFunction(%r, *%r)' % (self.fn, self.partial_args) + def __repr__(self): + return 'PartialFunction(%r, *%r)' % (self.fn, self.partial_args) class Channel(object): - def __init__(self, stream, handle): - self._stream = stream - self._handle = handle - self._wake_event = threading.Event() - self._queue_lock = threading.Lock() - self._queue = [] - self._stream.AddHandleCB(self._InternalReceive, handle) + def __init__(self, stream, handle): + self._stream = stream + self._handle = handle + self._wake_event = threading.Event() + self._queue_lock = threading.Lock() + self._queue = [] + self._stream.AddHandleCB(self._InternalReceive, handle) - def _InternalReceive(self, killed, data): - ''' - Callback from the stream object; appends a tuple of - (killed-or-closed, data) to the internal queue and wakes the internal - event. + def _InternalReceive(self, killed, data): + ''' + Callback from the stream object; appends a tuple of + (killed-or-closed, data) to the internal queue and wakes the internal + event. - Args: - # Has the Stream object lost its connection? - killed: bool - data: ( - # Has the remote Channel had Close() called? - bool, - # The object passed to the remote Send() - object - ) - ''' - Log('%r._InternalReceive(%r, %r)', self, killed, data) - self._queue_lock.acquire() - try: - self._queue.append((killed or data[0], killed or data[1])) - self._wake_event.set() - finally: - self._queue_lock.release() + Args: + # Has the Stream object lost its connection? + killed: bool + data: ( + # Has the remote Channel had Close() called? + bool, + # The object passed to the remote Send() + object + ) + ''' + Log('%r._InternalReceive(%r, %r)', self, killed, data) + self._queue_lock.acquire() + try: + self._queue.append((killed or data[0], killed or data[1])) + self._wake_event.set() + finally: + self._queue_lock.release() - def Close(self): - ''' - Indicate this channel is closed to the remote side. - ''' - Log('%r.Close()', self) - self._stream.Enqueue(handle, (True, None)) + def Close(self): + ''' + Indicate this channel is closed to the remote side. + ''' + Log('%r.Close()', self) + self._stream.Enqueue(handle, (True, None)) - def Send(self, data): - ''' - Send the given object to the remote side. - ''' - Log('%r.Send(%r)', self, data) - self._stream.Enqueue(handle, (False, data)) + def Send(self, data): + ''' + Send the given object to the remote side. + ''' + Log('%r.Send(%r)', self, data) + self._stream.Enqueue(handle, (False, data)) - def Receive(self, timeout=None): - ''' - Receive the next object to arrive on this channel, or return if the - optional timeout is reached. + def Receive(self, timeout=None): + ''' + Receive the next object to arrive on this channel, or return if the + optional timeout is reached. - Args: - timeout: float + Args: + timeout: float - Returns: - object - ''' - Log('%r.Receive(%r)', self, timeout) - if not self._queue: - self._wake_event.wait(timeout) - if not self._wake_event.isSet(): - return + Returns: + object + ''' + Log('%r.Receive(%r)', self, timeout) + if not self._queue: + self._wake_event.wait(timeout) + if not self._wake_event.isSet(): + return - self._queue_lock.acquire() - try: - self._wake_event.clear() - Log('%r.Receive() queue is %r', self, self._queue) - closed, data = self._queue.pop(0) - Log('%r.Receive() got closed=%r, data=%r', self, closed, data) - if closed: - raise ChannelError('Channel is closed.') - return data - finally: - self._queue_lock.release() + self._queue_lock.acquire() + try: + self._wake_event.clear() + Log('%r.Receive() queue is %r', self, self._queue) + closed, data = self._queue.pop(0) + Log('%r.Receive() got closed=%r, data=%r', self, closed, data) + if closed: + raise ChannelError('Channel is closed.') + return data + finally: + self._queue_lock.release() - def __iter__(self): - ''' - Return an iterator that yields objects arriving on this channel, until the - channel dies or is closed. - ''' - while True: - try: - yield self.Receive() - except ChannelError: - return + def __iter__(self): + ''' + Return an iterator that yields objects arriving on this channel, until + the channel dies or is closed. + ''' + while True: + try: + yield self.Receive() + except ChannelError: + return - def __repr__(self): - return 'econtext.Channel(%r, %r)' % (self._stream, self._handle) + def __repr__(self): + return 'econtext.Channel(%r, %r)' % (self._stream, self._handle) class SlaveModuleImporter(object): - ''' - Import protocol implementation that fetches modules from the parent process. - ''' - - def __init__(self, context): ''' - Initialise a new instance. - - Args: - context: Context instance this importer will communicate via. + Import protocol implementation that fetches modules from the parent + process. ''' - self._context = context - def find_module(self, fullname, path=None): - if not imp.find_module(fullname): - return self + def __init__(self, context): + ''' + Initialise a new instance. - def load_module(self, fullname): - kind, data = self._context.EnqueueAwaitReply(GET_MODULE, fullname) + Args: + context: Context instance this importer will communicate via. + ''' + self._context = context - def GetModule(cls, killed, fullname): - Log('%r.GetModule(%r, %r)', cls, killed, fullname) - if killed: - return - if fullname in sys.modules: - pass - GetModule = classmethod(GetModule) + def find_module(self, fullname, path=None): + if not imp.find_module(fullname): + return self + + def load_module(self, fullname): + kind, data = self._context.EnqueueAwaitReply(GET_MODULE, fullname) + + def GetModule(cls, killed, fullname): + Log('%r.GetModule(%r, %r)', cls, killed, fullname) + if killed: + return + if fullname in sys.modules: + pass + GetModule = classmethod(GetModule) # @@ -231,582 +232,582 @@ class SlaveModuleImporter(object): # class Stream(object): - def __init__(self, context): - ''' - Initialize a new Stream instance. + def __init__(self, context): + ''' + Initialize a new Stream instance. - Args: - context: econtext.Context - ''' - self._context = context + Args: + context: econtext.Context + ''' + self._context = context - self._input_buf = self._output_buf = '' - self._input_buf_lock = threading.Lock() - self._output_buf_lock = threading.Lock() - self._rhmac = hmac.new(context.key, digestmod=sha.new) - self._whmac = self._rhmac.copy() + self._input_buf = self._output_buf = '' + self._input_buf_lock = threading.Lock() + self._output_buf_lock = threading.Lock() + self._rhmac = hmac.new(context.key, digestmod=sha.new) + self._whmac = self._rhmac.copy() - self._last_handle = 0 - self._handle_map = {} - self._handle_lock = threading.Lock() + self._last_handle = 0 + self._handle_map = {} + self._handle_lock = threading.Lock() - self._func_refs = {} - self._func_ref_lock = threading.Lock() + self._func_refs = {} + self._func_ref_lock = threading.Lock() - self._pickler_file = cStringIO.StringIO() - self._pickler = cPickle.Pickler(self._pickler_file) - self._pickler.persistent_id = self._CheckFunctionPerID + self._pickler_file = cStringIO.StringIO() + self._pickler = cPickle.Pickler(self._pickler_file) + self._pickler.persistent_id = self._CheckFunctionPerID - self._unpickler_file = cStringIO.StringIO() - self._unpickler = cPickle.Unpickler(self._unpickler_file) - self._unpickler.persistent_load = self._LoadFunctionFromPerID + self._unpickler_file = cStringIO.StringIO() + self._unpickler = cPickle.Unpickler(self._unpickler_file) + self._unpickler.persistent_load = self._LoadFunctionFromPerID - def Pickle(self, obj): - ''' - Serialize the given object using the pickler. + def Pickle(self, obj): + ''' + Serialize the given object using the pickler. - Args: - obj: object + Args: + obj: object - Returns: - str - ''' - self._pickler.dump(obj) - data = self._pickler_file.getvalue() - self._pickler_file.seek(0) - self._pickler_file.truncate(0) - return data + Returns: + str + ''' + self._pickler.dump(obj) + data = self._pickler_file.getvalue() + self._pickler_file.seek(0) + self._pickler_file.truncate(0) + return data - def Unpickle(self, data): - ''' - Unserialize the given string using the unpickler. + def Unpickle(self, data): + ''' + Unserialize the given string using the unpickler. - Args: - data: str + Args: + data: str - Returns: - object - ''' - Log('%r.Unpickle(%r)', self, data) - self._unpickler_file.write(data) - self._unpickler_file.seek(0) - data = self._unpickler.load() - self._unpickler_file.seek(0) - self._unpickler_file.truncate(0) - return data + Returns: + object + ''' + Log('%r.Unpickle(%r)', self, data) + self._unpickler_file.write(data) + self._unpickler_file.seek(0) + data = self._unpickler.load() + self._unpickler_file.seek(0) + self._unpickler_file.truncate(0) + return data - def _CheckFunctionPerID(self, obj): - ''' - Return None or a persistent ID for an object. - Please see the cPickle documentation. + def _CheckFunctionPerID(self, obj): + ''' + Return None or a persistent ID for an object. + Please see the cPickle documentation. - Args: - obj: object + Args: + obj: object - Returns: - str - ''' - if isinstance(obj, (types.FunctionType, types.MethodType)): - pid = 'FUNC:' + repr(obj) - self._func_refs[per_id] = obj - return pid + Returns: + str + ''' + if isinstance(obj, (types.FunctionType, types.MethodType)): + pid = 'FUNC:' + repr(obj) + self._func_refs[per_id] = obj + return pid - def _LoadFunctionFromPerID(self, pid): - ''' - Load an object from a persistent ID. - Please see the cPickle documentation. + def _LoadFunctionFromPerID(self, pid): + ''' + Load an object from a persistent ID. + Please see the cPickle documentation. - Args: - pid: str + Args: + pid: str - Returns: - object - ''' - if not pid.startswith('FUNC:'): - raise CorruptMessageError('unrecognized persistent ID received: %r', pid) - return PartialFunction(self._CallPersistentWhatsit, pid) + Returns: + object + ''' + if not pid.startswith('FUNC:'): + raise CorruptMessageError('unrecognized persistent ID received: %r', pid) + return PartialFunction(self._CallPersistentWhatsit, pid) - def AllocHandle(self): - ''' - Allocate a unique handle for this stream. + def AllocHandle(self): + ''' + Allocate a unique handle for this stream. - Returns: - long - ''' - self._handle_lock.acquire() - try: - self._last_handle += 1L - finally: - self._handle_lock.release() - return self._last_handle + Returns: + long + ''' + self._handle_lock.acquire() + try: + self._last_handle += 1L + finally: + self._handle_lock.release() + return self._last_handle - def AddHandleCB(self, fn, handle, persist=True): - ''' - Invoke a function for all messages with the given handle. + def AddHandleCB(self, fn, handle, persist=True): + ''' + Invoke a function for all messages with the given handle. - Args: - fn: callable - handle: long - persist: False to only receive a single message. - ''' - Log('%r.AddHandleCB(%r, %r, persist=%r)', self, fn, handle, persist) - self._handle_lock.acquire() - try: - self._handle_map[handle] = persist, fn - finally: - self._handle_lock.release() + Args: + fn: callable + handle: long + persist: False to only receive a single message. + ''' + Log('%r.AddHandleCB(%r, %r, persist=%r)', self, fn, handle, persist) + self._handle_lock.acquire() + try: + self._handle_map[handle] = persist, fn + finally: + self._handle_lock.release() - def Receive(self): - ''' - Handle the next complete message on the stream. Raise CorruptMessageError - or IOError on failure. - ''' - Log('%r.Receive()', self) + def Receive(self): + ''' + Handle the next complete message on the stream. Raise + CorruptMessageError or IOError on failure. + ''' + Log('%r.Receive()', self) - self._input_buf += os.read(self._rfd, 4096) - if len(self._input_buf) < 24: - return + self._input_buf += os.read(self._rfd, 4096) + if len(self._input_buf) < 24: + return - msg_mac = self._input_buf[:20] - msg_len = struct.unpack('>L', self._input_buf[20:24])[0] - if len(self._input_buf) < msg_len-24: - return + msg_mac = self._input_buf[:20] + msg_len = struct.unpack('>L', self._input_buf[20:24])[0] + if len(self._input_buf) < msg_len-24: + return - self._rhmac.update(self._input_buf[20:msg_len+24]) - expected_mac = self._rhmac.digest() - if msg_mac != expected_mac: - raise CorruptMessageError('%r got invalid MAC: expected %r, got %r', - self, msg_mac.encode('hex'), - expected_mac.encode('hex')) + self._rhmac.update(self._input_buf[20:msg_len+24]) + expected_mac = self._rhmac.digest() + if msg_mac != expected_mac: + raise CorruptMessageError('%r got invalid MAC: expected %r, got %r', + self, msg_mac.encode('hex'), + expected_mac.encode('hex')) - try: - handle, data = self.Unpickle(self._input_buf[24:msg_len+24]) - self._input_buf = self._input_buf[msg_len+4:] - handle = long(handle) + try: + handle, data = self.Unpickle(self._input_buf[24:msg_len+24]) + self._input_buf = self._input_buf[msg_len+4:] + handle = long(handle) - Log('%r.Receive(): decoded handle=%r; data=%r', self, handle, data) - persist, fn = self._handle_map[handle] - if not persist: - del self._handle_map[handle] - except KeyError, ex: - raise CorruptMessageError('%r got invalid handle: %r', self, handle) - except (TypeError, ValueError), ex: - raise CorruptMessageError('%r got invalid message: %s', self, ex) + Log('%r.Receive(): decoded handle=%r; data=%r', self, handle, data) + persist, fn = self._handle_map[handle] + if not persist: + del self._handle_map[handle] + except KeyError, ex: + raise CorruptMessageError('%r got invalid handle: %r', self, handle) + except (TypeError, ValueError), ex: + raise CorruptMessageError('%r got invalid message: %s', self, ex) - fn(False, data) + fn(False, data) - def Transmit(self): - ''' - Transmit buffered messages. + def Transmit(self): + ''' + Transmit buffered messages. - Returns: - bool: more data left in bufer? + Returns: + bool: more data left in bufer? - Raises: - IOError - ''' - Log('%r.Transmit()', self) - written = os.write(self._fd, self._output_buf[:4096]) - self._output_buf = self._output_buf[written:] - return bool(self._output_buf) + Raises: + IOError + ''' + Log('%r.Transmit()', self) + written = os.write(self._fd, self._output_buf[:4096]) + self._output_buf = self._output_buf[written:] + return bool(self._output_buf) - def Enqueue(self, handle, obj): - ''' - Serialize an object, send it to the given handle, and tell our context's - broker we have output. + def Enqueue(self, handle, obj): + ''' + Serialize an object, send it to the given handle, and tell our context's + broker we have output. - Args: - handle: long - obj: object - ''' - Log('%r.Enqueue(%r, %r)', self, handle, obj) + Args: + handle: long + obj: object + ''' + Log('%r.Enqueue(%r, %r)', self, handle, obj) - self._output_buf_lock.acquire() - try: - encoded = self.Pickle((handle, obj)) - msg = struct.pack('>L', len(encoded)) + encoded - self._whmac.update(msg) - self._output_buf += self._whmac.digest() + msg - finally: - self._output_buf_lock.release() - self._context.broker.Register(self._context) + self._output_buf_lock.acquire() + try: + encoded = self.Pickle((handle, obj)) + msg = struct.pack('>L', len(encoded)) + encoded + self._whmac.update(msg) + self._output_buf += self._whmac.digest() + msg + finally: + self._output_buf_lock.release() + self._context.broker.Register(self._context) - def Disconnect(self): - ''' - Close our associated file descriptor and tell any registered callbacks - that the connection has been destroyed. - ''' - Log('%r.Disconnect()', self) - try: - os.close(self._fd) - except OSError, e: - Log('%r.Disconnect(): did not close fd %s: %s', self, self._fd, e) + def Disconnect(self): + ''' + Close our associated file descriptor and tell any registered callbacks + that the connection has been destroyed. + ''' + Log('%r.Disconnect()', self) + try: + os.close(self._fd) + except OSError, e: + Log('%r.Disconnect(): did not close fd %s: %s', self, self._fd, e) - for handle, (persist, fn) in self._handle_map.iteritems(): - Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r', - self, handle, fn) - fn(True, None) + for handle, (persist, fn) in self._handle_map.iteritems(): + Log('%r.Disconnect(): killing stale callback handle=%r; fn=%r', + self, handle, fn) + fn(True, None) - @classmethod - def Accept(cls, context, sock): - ''' - - ''' - stream = cls(context) - context.SetStream() - broker.Register(context) + @classmethod + def Accept(cls, context, sock): + ''' + + ''' + stream = cls(context) + context.SetStream() + broker.Register(context) - def Connect(self): - ''' - Connect to a Broker at the address specified in our associated Context. - ''' + def Connect(self): + ''' + Connect to a Broker at the address specified in our associated Context. + ''' - Log('%r.Connect()', self) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._fd = sock.fileno() - sock.connect(self._context.parent_addr) - self.Enqueue(0, self._context.name) + Log('%r.Connect()', self) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._fd = sock.fileno() + sock.connect(self._context.parent_addr) + self.Enqueue(0, self._context.name) - def fileno(self): - return self._fd + def fileno(self): + return self._fd - def __repr__(self): - return 'econtext.%s()' %\ - (self.__class__.__name__, self._context) + def __repr__(self): + return 'econtext.%s()' %\ + (self.__class__.__name__, self._context) class LocalStream(Stream): - ''' - Base for streams capable of starting new slaves. - ''' - - python_path = property( - lambda self: getattr(self, '_python_path', sys.executable), - lambda self, path: setattr(self, '_python_path', path), - doc='The path to the remote Python interpreter.') - - def __init__(self, context): - super(LocalStream, self).__init__(context) - self._permitted_modules = {} - self._unpickler.find_global = self._FindGlobal - self.AddHandleCB(SlaveModuleImporter.GetModule, handle=GET_MODULE) - - def _FindGlobal(self, module_name, class_name): ''' - See the cPickle documentation: given a module and class name, determine - whether class referred to is safe for unpickling. - - Args: - module_name: str - class_name: str - - Returns: - classobj or type + Base for streams capable of starting new slaves. ''' - if module_name not in self._permitted_modules: - raise StreamError('context %r attempted to unpickle %r in module %r', - self._context, class_name, module_name) - return getattr(sys.modules[module_name], class_name) - def AllowModule(self, module_name): - ''' - Add the given module to the list of permitted modules. + python_path = property( + lambda self: getattr(self, '_python_path', sys.executable), + lambda self, path: setattr(self, '_python_path', path), + doc='The path to the remote Python interpreter.') - Args: - module_name: str - ''' - self._permitted_modules.add(module_name) + def __init__(self, context): + super(LocalStream, self).__init__(context) + self._permitted_modules = {} + self._unpickler.find_global = self._FindGlobal + self.AddHandleCB(SlaveModuleImporter.GetModule, handle=GET_MODULE) - # Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe, - # then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced - # with the context name. Optimized for source size. - def _FirstStage(): - import os,sys,zlib - R,W=os.pipe() - if os.fork(): - os.dup2(0,100) - os.dup2(R,0) - os.close(R) - os.close(W) - os.execv(sys.executable,(CONTEXT_NAME,)) - else: - os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input()))) - print 'OK' - sys.exit(0) + def _FindGlobal(self, module_name, class_name): + ''' + See the cPickle documentation: given a module and class name, determine + whether class referred to is safe for unpickling. - def GetBootCommand(self): - source = inspect.getsource(self._FirstStage) - source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:])) - source = source.replace(' ', '\t') - source = source.replace('CONTEXT_NAME', repr(self._context.name)) - return [ self.python_path, '-c', - 'exec "%s".decode("hex")' % (source.encode('hex'),) ] + Args: + module_name: str + class_name: str - def __repr__(self): - return '%s(%s)' % (self.__class__.__name__, self._context) + Returns: + classobj or type + ''' + if module_name not in self._permitted_modules: + raise StreamError('context %r attempted to unpickle %r in module %r', + self._context, class_name, module_name) + return getattr(sys.modules[module_name], class_name) - def Connect(self): - Log('%r.Connect()', self) - pid, sock = CreateChild(*self.GetBootCommand()) - self._fd = sock.fileno() - Log('%r.Connect(): child process stdin/stdout=%r', self, self._fd) + def AllowModule(self, module_name): + ''' + Add the given module to the list of permitted modules. - source = inspect.getsource(sys.modules[__name__]) - source += '\nExternalContextMain(%r, %r, %r)\n' %\ - (self._context.name, self._context.broker._listen_addr, - self._context.key) - compressed = zlib.compress(source) + Args: + module_name: str + ''' + self._permitted_modules.add(module_name) - preamble = str(len(compressed)) + '\n' + compressed - sock.sendall(preamble) - assert os.read(self._fd, 3) == 'OK\n' + # Hexed and passed to 'python -c'. It forks, dups 0->100, creates a pipe, + # then execs a new interpreter with a custom argv. CONTEXT_NAME is replaced + # with the context name. Optimized for source size. + def _FirstStage(): + import os,sys,zlib + R,W=os.pipe() + if os.fork(): + os.dup2(0,100) + os.dup2(R,0) + os.close(R) + os.close(W) + os.execv(sys.executable,(CONTEXT_NAME,)) + else: + os.fdopen(W,'wb',0).write(zlib.decompress(sys.stdin.read(input()))) + print 'OK' + sys.exit(0) + + def GetBootCommand(self): + source = inspect.getsource(self._FirstStage) + source = textwrap.dedent('\n'.join(source.strip().split('\n')[1:])) + source = source.replace(' ', '\t') + source = source.replace('CONTEXT_NAME', repr(self._context.name)) + return [ self.python_path, '-c', + 'exec "%s".decode("hex")' % (source.encode('hex'),) ] + + def __repr__(self): + return '%s(%s)' % (self.__class__.__name__, self._context) + + def Connect(self): + Log('%r.Connect()', self) + pid, sock = CreateChild(*self.GetBootCommand()) + self._fd = sock.fileno() + Log('%r.Connect(): child process stdin/stdout=%r', self, self._fd) + + source = inspect.getsource(sys.modules[__name__]) + source += '\nExternalContextMain(%r, %r, %r)\n' %\ + (self._context.name, self._context.broker._listen_addr, + self._context.key) + compressed = zlib.compress(source) + + preamble = str(len(compressed)) + '\n' + compressed + sock.sendall(preamble) + assert os.read(self._fd, 3) == 'OK\n' class SSHStream(LocalStream): - ssh_path = property( - lambda self: getattr(self, '_ssh_path', 'ssh'), - lambda self, path: setattr(self, '_ssh_path', path), - doc='The path to the SSH binary.') + ssh_path = property( + lambda self: getattr(self, '_ssh_path', 'ssh'), + lambda self, path: setattr(self, '_ssh_path', path), + doc='The path to the SSH binary.') - def GetBootCommand(self): - bits = [self.ssh_path] - if self._context.username: - bits += ['-l', self._context.username] - bits.append(self._context.hostname) - return bits + map(commands.mkarg, super(SSHStream, self).GetBootCommand()) + def GetBootCommand(self): + bits = [self.ssh_path] + if self._context.username: + bits += ['-l', self._context.username] + bits.append(self._context.hostname) + return bits + map(commands.mkarg, super(SSHStream, self).GetBootCommand()) class Context(object): - ''' - Represents a remote context regardless of connection method. - ''' - - def __init__(self, broker, name=None, hostname=None, username=None, key=None, - parent_addr=None): - self.broker = broker - self.name = name - self.hostname = hostname - self.username = username - self.parent_addr = parent_addr - if key: - self.key = key - else: - self.key = file('/dev/urandom', 'rb').read(16).encode('hex') - - def GetStream(self): - return self._stream - - def SetStream(self, stream): - self._stream = stream - return stream - - def EnqueueAwaitReply(self, handle, deadline, data): ''' - Send a message to the given handle and wait for a response with an optional - timeout. The message contains (reply_handle, data), where reply_handle is - the handle on which this function expects its reply. + Represents a remote context regardless of connection method. ''' - Log('%r.EnqueueAwaitReply(%r, %r, %r)', self, handle, deadline, data) - reply_handle = self._stream.AllocHandle() - reply_event = threading.Event() - container = [] - def _Receive(killed, data): - Log('%r._Receive(%r, %r)', self, killed, data) - container.extend([killed, data]) - reply_event.set() + def __init__(self, broker, name=None, hostname=None, username=None, key=None, + parent_addr=None): + self.broker = broker + self.name = name + self.hostname = hostname + self.username = username + self.parent_addr = parent_addr + if key: + self.key = key + else: + self.key = file('/dev/urandom', 'rb').read(16).encode('hex') - self._stream.AddHandleCB(_Receive, reply_handle, persist=False) - self._stream.Enqueue(CALL_FUNCTION, (False, (reply_handle,) + data)) + def GetStream(self): + return self._stream - reply_event.wait(deadline) - if not reply_event.isSet(): - self.Disconnect() - raise TimeoutError('deadline exceeded.') + def SetStream(self, stream): + self._stream = stream + return stream - killed, data = container - if killed: - raise StreamError('lost connection during call.') + def EnqueueAwaitReply(self, handle, deadline, data): + ''' + Send a message to the given handle and wait for a response with an + optional timeout. The message contains (reply_handle, data), where + reply_handle is the handle on which this function expects its reply. + ''' + Log('%r.EnqueueAwaitReply(%r, %r, %r)', self, handle, deadline, data) + reply_handle = self._stream.AllocHandle() + reply_event = threading.Event() + container = [] - Log('%r._EnqueueAwaitReply(): got reply: %r', self, data) - return data + def _Receive(killed, data): + Log('%r._Receive(%r, %r)', self, killed, data) + container.extend([killed, data]) + reply_event.set() - def CallWithDeadline(self, fn, deadline, *args, **kwargs): - Log('%r.CallWithDeadline(%r, %r, *%r, **%r)', self, fn, deadline, args, - kwargs) + self._stream.AddHandleCB(_Receive, reply_handle, persist=False) + self._stream.Enqueue(CALL_FUNCTION, (False, (reply_handle,) + data)) - use_channel = bool(kwargs.pop('use_channel', False)) - if isinstance(fn, types.MethodType) and \ - isinstance(fn.im_self, (type, types.ClassType)): - fn_class = fn.im_self.__name__ - else: - fn_class = None + reply_event.wait(deadline) + if not reply_event.isSet(): + self.Disconnect() + raise TimeoutError('deadline exceeded.') - call = (use_channel, fn.__module__, fn_class, fn.__name__, args, kwargs) - success, result = self.EnqueueAwaitReply(CALL_FUNCTION, deadline, call) + killed, data = container + if killed: + raise StreamError('lost connection during call.') - if success: - return result - else: - exc_obj, traceback = result - exc_obj.real_traceback = traceback - raise exc_obj + Log('%r._EnqueueAwaitReply(): got reply: %r', self, data) + return data - def Call(self, fn, *args, **kwargs): - return self.CallWithDeadline(fn, None, *args, **kwargs) + def CallWithDeadline(self, fn, deadline, *args, **kwargs): + Log('%r.CallWithDeadline(%r, %r, *%r, **%r)', self, fn, deadline, args, + kwargs) - def __repr__(self): - bits = map(repr, filter(None, [self.name, self.hostname, self.username])) - return 'Context(%s)' % ', '.join(bits) + use_channel = bool(kwargs.pop('use_channel', False)) + if isinstance(fn, types.MethodType) and \ + isinstance(fn.im_self, (type, types.ClassType)): + fn_class = fn.im_self.__name__ + else: + fn_class = None + + call = (use_channel, fn.__module__, fn_class, fn.__name__, args, kwargs) + success, result = self.EnqueueAwaitReply(CALL_FUNCTION, deadline, call) + + if success: + return result + else: + exc_obj, traceback = result + exc_obj.real_traceback = traceback + raise exc_obj + + def Call(self, fn, *args, **kwargs): + return self.CallWithDeadline(fn, None, *args, **kwargs) + + def __repr__(self): + bits = map(repr, filter(None, [self.name, self.hostname, self.username])) + return 'Context(%s)' % ', '.join(bits) class Broker(object): - ''' - Context broker: this is responsible for keeping track of contexts, any - stream that is associated with them, and for I/O multiplexing. - ''' - - def __init__(self): - self._dead = False - self._poller = select.poll() - self._poller_fd_map = {} - self._poller_lock = threading.Lock() - self._contexts = {} - - self._wake_rfd, self._wake_wfd = os.pipe() - self._listen_sock = None - self._poller.register(self._wake_rfd) - - self._thread = threading.Thread(target=self.Loop, name='Broker') - self._thread.setDaemon(True) - self._thread.start() - - def CreateListener(self, address=None, backlog=30): ''' - Create a socket to accept connections from newly spawned contexts. - Args: - address: The IPv4 address tuple to listen on. - backlog: Number of connections to accept while broker thread is busy. + Context broker: this is responsible for keeping track of contexts, any + stream that is associated with them, and for I/O multiplexing. ''' - self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._listen_sock.bind(address or ('0.0.0.0', 0)) - self._listen_sock.listen(backlog) - self._listen_addr = self._listen_sock.getsockname() - self._poller.register(self._listen_sock) - def Register(self, context): - ''' - Put a context under control of this broker. - ''' - Log('%r.Register(%r) -> fd=%r', self, context, context.GetStream().fileno()) - self._poller_lock.acquire() - os.write(self._wake_wfd, ' ') - try: - self._contexts[context.name] = context - self._poller.register(context.GetStream()) - self._poller_fd_map[context.GetStream().fileno()] = context.GetStream() - finally: - self._poller_lock.release() - return context + def __init__(self): + self._dead = False + self._poller = select.poll() + self._poller_fd_map = {} + self._poller_lock = threading.Lock() + self._contexts = {} - def GetLocal(self, name): - ''' - Return the named local context, or create it if it doesn't exist. + self._wake_rfd, self._wake_wfd = os.pipe() + self._listen_sock = None + self._poller.register(self._wake_rfd) - Args: - name: 'my-local-context' - Returns: - econtext.Context - ''' - context = Context(self, name) - context.SetStream(LocalStream(context)).Connect() - return self.Register(context) + self._thread = threading.Thread(target=self.Loop, name='Broker') + self._thread.setDaemon(True) + self._thread.start() - def GetRemote(self, hostname, username, name=None): - ''' - Return the named remote context, or create it if it doesn't exist. - ''' - if name is None: - name = 'econtext[%s@%s:%d]' %\ - (username, os.getenv('HOSTNAME'), os.getpid()) + def CreateListener(self, address=None, backlog=30): + ''' + Create a socket to accept connections from newly spawned contexts. + Args: + address: The IPv4 address tuple to listen on. + backlog: Number of connections to accept while broker thread is busy. + ''' + self._listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._listen_sock.bind(address or ('0.0.0.0', 0)) + self._listen_sock.listen(backlog) + self._listen_addr = self._listen_sock.getsockname() + self._poller.register(self._listen_sock) - context = Context(self, name, hostname, username) - context.SetStream(SSHStream(context)).Connect() - return self.Register(context) + def Register(self, context): + ''' + Put a context under control of this broker. + ''' + Log('%r.Register(%r) -> fd=%r', self, context, context.GetStream().fileno()) + self._poller_lock.acquire() + os.write(self._wake_wfd, ' ') + try: + self._contexts[context.name] = context + self._poller.register(context.GetStream()) + self._poller_fd_map[context.GetStream().fileno()] = context.GetStream() + finally: + self._poller_lock.release() + return context - def Loop(self): - ''' - Handle stream events until Finalize() is called. - ''' - while not self._dead: - Log('%r.Loop()', self) - self._poller_lock.acquire() - self._poller_lock.release() - for fd, event in self._poller.poll(): - if fd == self._wake_rfd: - Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd) - os.read(self._wake_rfd, 1) - continue - elif self._listen_sock and fd == self._listen_sock.fileno(): - context = Context(broker) - Stream.Accept(context, self._listen_sock.accept()) - continue + def GetLocal(self, name): + ''' + Return the named local context, or create it if it doesn't exist. - obj = self._poller_fd_map[fd] - if event & select.POLLHUP: - Log('%r: POLLHUP for %d, %r', self, fd, obj) - obj.Disconnect() - elif event & select.POLLIN: - Log('%r: POLLIN for %d, %r', self, fd, obj) - obj.Receive() - elif event & select.POLLOUT: - Log('%r: POLLOUT for %d, %r', self, fd, obj) - if not obj.Transmit(): # If no output buffered, unset POLLOUT. - self._poller.unregister(obj) - self._poller.register(obj, select.POLLIN) - elif event & select.POLLNVAL: - Log('%r: POLLNVAL for %d, %r', self, fd, obj) - obj.Disconnect() - self._poller.unregister(obj) + Args: + name: 'my-local-context' + Returns: + econtext.Context + ''' + context = Context(self, name) + context.SetStream(LocalStream(context)).Connect() + return self.Register(context) - def Finalize(self): - ''' - Tell all active streams to disconnect. - ''' - self._dead = True - self._poller_lock.acquire() - try: - for name, context in self._contexts.iteritems(): - context.GetStream().Disconnect() - finally: - self._poller_lock.release() + def GetRemote(self, hostname, username, name=None): + ''' + Return the named remote context, or create it if it doesn't exist. + ''' + if name is None: + name = 'econtext[%s@%s:%d]' %\ + (username, os.getenv('HOSTNAME'), os.getpid()) - def __repr__(self): - return 'econtext.Broker()' % (self._contexts.keys(),) + context = Context(self, name, hostname, username) + context.SetStream(SSHStream(context)).Connect() + return self.Register(context) + + def Loop(self): + ''' + Handle stream events until Finalize() is called. + ''' + while not self._dead: + Log('%r.Loop()', self) + self._poller_lock.acquire() + self._poller_lock.release() + for fd, event in self._poller.poll(): + if fd == self._wake_rfd: + Log('%r: got event on wake_rfd=%d.', self, self._wake_rfd) + os.read(self._wake_rfd, 1) + continue + elif self._listen_sock and fd == self._listen_sock.fileno(): + context = Context(broker) + Stream.Accept(context, self._listen_sock.accept()) + continue + + obj = self._poller_fd_map[fd] + if event & select.POLLHUP: + Log('%r: POLLHUP for %d, %r', self, fd, obj) + obj.Disconnect() + elif event & select.POLLIN: + Log('%r: POLLIN for %d, %r', self, fd, obj) + obj.Receive() + elif event & select.POLLOUT: + Log('%r: POLLOUT for %d, %r', self, fd, obj) + if not obj.Transmit(): # If no output buffered, unset POLLOUT. + self._poller.unregister(obj) + self._poller.register(obj, select.POLLIN) + elif event & select.POLLNVAL: + Log('%r: POLLNVAL for %d, %r', self, fd, obj) + obj.Disconnect() + self._poller.unregister(obj) + + def Finalize(self): + ''' + Tell all active streams to disconnect. + ''' + self._dead = True + self._poller_lock.acquire() + try: + for name, context in self._contexts.iteritems(): + context.GetStream().Disconnect() + finally: + self._poller_lock.release() + + def __repr__(self): + return 'econtext.Broker()' % (self._contexts.keys(),) def ExternalContextMain(context_name, parent_addr, key): - syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID) - syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),)) - Log('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key) + syslog.openlog('%s:%s' % (getpass.getuser(), context_name), syslog.LOG_PID) + syslog.syslog('initializing (parent=%s)' % (os.getenv('SSH_CLIENT'),)) + Log('ExternalContextMain(%r, %r, %r)', context_name, parent_addr, key) - os.wait() # Reap the first stage. - os.dup2(100, 0) - os.close(100) + os.wait() # Reap the first stage. + os.dup2(100, 0) + os.close(100) - broker = Broker() - context = Context(broker, 'parent', parent_addr=parent_addr, key=key) + broker = Broker() + context = Context(broker, 'parent', parent_addr=parent_addr, key=key) - stream = context.SetStream(Stream(context)) - stream.Connect() - broker.Register(context) + stream = context.SetStream(Stream(context)) + stream.Connect() + broker.Register(context) - for call_info in Channel(stream, CALL_FUNCTION): - Log('ExternalContextMain(): CALL_FUNCTION %r', call_info) - reply_handle, mod_name, func_name, args, kwargs = call_info - fn = getattr(__import__(mod_name), func_name) + for call_info in Channel(stream, CALL_FUNCTION): + Log('ExternalContextMain(): CALL_FUNCTION %r', call_info) + reply_handle, mod_name, func_name, args, kwargs = call_info + fn = getattr(__import__(mod_name), func_name) - try: - stream.Enqueue(reply_handle, (True, fn(*args, **kwargs))) - except Exception, e: - stream.Enqueue(reply_handle, (False, (e, traceback.extract_stack()))) + try: + stream.Enqueue(reply_handle, (True, fn(*args, **kwargs))) + except Exception, e: + stream.Enqueue(reply_handle, (False, (e, traceback.extract_stack()))) diff --git a/econtext_test.py b/econtext_test.py index 513ae6dc..4a1c2215 100755 --- a/econtext_test.py +++ b/econtext_test.py @@ -2,25 +2,25 @@ """ def DoStuff(): - import time - file('/tmp/foobar', 'w').write(time.ctime()) + import time + file('/tmp/foobar', 'w').write(time.ctime()) localhost = pyrpc.SSHConnection('localhost') localhost.Connect() try: - ret = localhost.Evaluate(DoStuff) + ret = localhost.Evaluate(DoStuff) except OSError, e: - + Tests - - Test Channel objects to destruction. - - External contexts sometimes don't appear to die during a crash. This needs - tested to destruction. - - Test reconnecting to previously idle-killed contexts. - - Test remote context longevity to destruction. They should never stay - around after parent dies or disconnects. + - Test Channel objects to destruction. + - External contexts sometimes don't appear to die during a crash. This needs + tested to destruction. + - Test reconnecting to previously idle-killed contexts. + - Test remote context longevity to destruction. They should never stay + around after parent dies or disconnects. """ @@ -35,53 +35,53 @@ import econtext # class GetModuleImportsTestCase(unittest.TestCase): - # This must be kept in sync with our actual imports. - IMPORTS = [ - ('econtext', 'econtext'), - ('sys', 'PythonSystemModule'), - ('sys', 'sys'), - ('unittest', 'unittest') - ] + # This must be kept in sync with our actual imports. + IMPORTS = [ + ('econtext', 'econtext'), + ('sys', 'PythonSystemModule'), + ('sys', 'sys'), + ('unittest', 'unittest') + ] - def setUp(self): - global PythonSystemModule - import sys as PythonSystemModule + def setUp(self): + global PythonSystemModule + import sys as PythonSystemModule - def tearDown(Self): - global PythonSystemModule - del PythonSystemModule + def tearDown(Self): + global PythonSystemModule + del PythonSystemModule - def testImports(self): - self.assertEqual(set(self.IMPORTS), - set(econtext.GetModuleImports(sys.modules[__name__]))) + def testImports(self): + self.assertEqual(set(self.IMPORTS), + set(econtext.GetModuleImports(sys.modules[__name__]))) class BuildPartialModuleTestCase(unittest.TestCase): - def testNullModule(self): - """Pass empty sequences; result should contain nothing but a hash bang line - and whitespace.""" + def testNullModule(self): + """Pass empty sequences; result should contain nothing but a hash bang line + and whitespace.""" - lines = econtext.BuildPartialModule([], []).strip().split('\n') + lines = econtext.BuildPartialModule([], []).strip().split('\n') - self.assert_(lines[0].startswith('#!')) - self.assert_('import' not in lines[1:]) + self.assert_(lines[0].startswith('#!')) + self.assert_('import' not in lines[1:]) - def testPassingMethodTypeFails(self): - """Pass an instance method and ensure we refuse it.""" + def testPassingMethodTypeFails(self): + """Pass an instance method and ensure we refuse it.""" - self.assertRaises(TypeError, econtext.BuildPartialModule, - [self.testPassingMethodTypeFails], []) + self.assertRaises(TypeError, econtext.BuildPartialModule, + [self.testPassingMethodTypeFails], []) - @staticmethod - def exampleStaticMethod(): - pass + @staticmethod + def exampleStaticMethod(): + pass - def testStaticMethodGetsUnwrapped(self): - "Ensure that @staticmethod decorators are stripped." + def testStaticMethodGetsUnwrapped(self): + "Ensure that @staticmethod decorators are stripped." - dct = {} - exec econtext.BuildPartialModule([self.exampleStaticMethod], []) in dct - self.assertFalse(isinstance(dct['exampleStaticMethod'], staticmethod)) + dct = {} + exec econtext.BuildPartialModule([self.exampleStaticMethod], []) in dct + self.assertFalse(isinstance(dct['exampleStaticMethod'], staticmethod)) @@ -90,33 +90,33 @@ class BuildPartialModuleTestCase(unittest.TestCase): # class StreamTestBase: - """This defines rules that should remain true for all Stream subclasses. We - test in this manner to guard against a subclass breaking Stream's - polymorphism (e.g. overriding a method with the wrong prototype). + """This defines rules that should remain true for all Stream subclasses. We + test in this manner to guard against a subclass breaking Stream's + polymorphism (e.g. overriding a method with the wrong prototype). - def testCommandLine(self): - print self.driver.command_line - """ + def testCommandLine(self): + print self.driver.command_line + """ class SSHStreamTestCase(unittest.TestCase, StreamTestBase): - DRIVER_CLASS = econtext.SSHStream + DRIVER_CLASS = econtext.SSHStream - def setUp(self): - # Stubs. + def setUp(self): + # Stubs. - # Instance initialization. - self.stream = econtext.SSHStream('localhost', 'test-agent') + # Instance initialization. + self.stream = econtext.SSHStream('localhost', 'test-agent') - def tearDown(self): - pass + def tearDown(self): + pass - def testConstructor(self): - pass + def testConstructor(self): + pass class TCPStreamTestCase(unittest.TestCase, StreamTestBase): - pass + pass # @@ -124,4 +124,4 @@ class TCPStreamTestCase(unittest.TestCase, StreamTestBase): # if __name__ == '__main__': - unittest.main() + unittest.main()