[WIP] Py3 improvements and extended brine version

This commit is contained in:
Oleksii Shevchuk 2020-09-07 14:57:39 +03:00
parent 37dcf7469e
commit c4fa316aca
28 changed files with 1676 additions and 685 deletions

View File

@ -200,6 +200,8 @@ dict_keys = type({}.keys())
dict_items = type({}.items())
dict_values = type({}.values())
tuple_items = type(().__iter__())
#############################################################################
# Exported Functions and Glob
#############################################################################

View File

@ -9,6 +9,7 @@ import json
from pupylib.PupyModule import config, PupyModule, PupyArgumentParser
from pupylib.utils.term import colorize
from network.lib.convcompat import as_unicode_string_deep
from defusedxml import minidom
__class_name__ = "IGDClient"
@ -25,10 +26,14 @@ class IGDCMDClient(object):
self.igdc = IGDClient(
args.source, args.url,
args.DEBUG, args.pretty_print)
args.DEBUG, args.pretty_print
)
self.log = log
def show(self, values):
values = as_unicode_string_deep(values)
if isinstance(values, dict):
column_size = max([len(x) for x in values])
fmt = '{{:<{}}}'.format(column_size)
@ -428,9 +433,6 @@ class IGDClient(PupyModule):
self.error('IGD: Not found in LAN')
return
self.cli.igdc.enableDebug(args.DEBUG)
self.cli.igdc.enablePPrint(args.pretty_print)
try:
args.func(args)
except Exception as e:

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2015, Nicolas VERDIER (contact@n1nj4.eu)
# Pupy is under the BSD 3-Clause license. see the LICENSE file at the root of the project for the detailed licence terms
# Pupy is under the BSD 3-Clause license. see the LICENSE file
# at the root of the project for the detailed licence terms
from __future__ import absolute_import
from __future__ import division
@ -13,6 +14,7 @@ from pupylib.PupyModule import (
)
from pupylib.utils.rpyc_utils import redirected_stdo
from network.lib.convcompat import as_native_string
import sys
import readline
@ -21,13 +23,14 @@ if sys.version_info.major > 2:
raw_input = input
__class_name__="InteractivePythonShell"
__class_name__ = 'InteractivePythonShell'
def enqueue_output(out, queue):
for c in iter(lambda: out.read(1), b""):
queue.put(c)
@config(cat="admin")
class InteractivePythonShell(PupyModule):
""" open an interactive python shell on the remote client """
@ -37,7 +40,9 @@ class InteractivePythonShell(PupyModule):
@classmethod
def init_argparse(cls):
cls.arg_parser = PupyArgumentParser(prog='pyshell', description=cls.__doc__)
cls.arg_parser = PupyArgumentParser(
prog='pyshell', description=cls.__doc__
)
def run(self, args):
PyShellController = self.client.remote(
@ -47,13 +52,21 @@ class InteractivePythonShell(PupyModule):
try:
with redirected_stdo(self):
old_completer = readline.get_completer()
try:
psc = PyShellController()
readline.set_completer(psc.get_completer())
completer = psc.get_completer()
def _completer_wrapper(*args, **kwargs):
value = completer(*args, **kwargs)
if value is not None:
return as_native_string(value)
readline.set_completer(_completer_wrapper)
readline.parse_and_bind('tab: complete')
while True:
cmd=raw_input(">>> ")
cmd = raw_input(">>> ")
psc.write(cmd)
finally:

View File

@ -25,20 +25,24 @@ else:
from network.lib import getLogger
logger = getLogger('pconn')
synclogger = getLogger('sync')
syncqueuelogger = getLogger('syncqueue')
from network.lib.ack import Ack
from network.lib.buffer import Buffer
from network.lib.rpc.core import Connection, consts, brine, netref
from network.lib.rpc.core.consts import (
HANDLE_PING, HANDLE_CLOSE, HANDLE_GETROOT
HANDLE_PING, HANDLE_CLOSE, HANDLE_GETROOT,
HANDLE_DIR, HANDLE_HASH, HANDLE_DEL
)
logger = getLogger('pconn')
synclogger = getLogger('sync')
syncqueuelogger = getLogger('syncqueue')
FAST_CALLS = (
HANDLE_PING, HANDLE_CLOSE, HANDLE_GETROOT
HANDLE_PING, HANDLE_CLOSE, HANDLE_GETROOT,
HANDLE_DIR, HANDLE_HASH, HANDLE_DEL
)
PY2TO3_CALLATTRS = (
@ -47,21 +51,26 @@ PY2TO3_CALLATTRS = (
'__getattribute__'
)
############# Monkeypatch brine to be buffer firendly #############
CONTROL_NOP = 0
CONTROL_ENABLE_BRINE_EXT_V1 = 1
MSG_PUPY_CONTROL = 0xF
# Monkeypatch brine to be buffer firendly
def stream_dump(obj):
BRINE_VER_1 = 1
PING_V1_CONTROL_MAGIC = b'\x00CTRL\x00V1'
def stream_dump(obj, version=0):
buf = Buffer()
brine._dump(obj, buf)
brine._dump(obj, buf, version)
return buf
# Py2: bytes == str
@brine.register(brine._dump_registry, bytes)
def _dump_bytes_to_buffer(obj, stream):
def _dump_bytes_to_buffer(obj, stream, version):
obj_len = len(obj)
if obj_len == 0:
stream.append(brine.TAG_EMPTY_STR)
@ -85,7 +94,7 @@ def _dump_bytes_to_buffer(obj, stream):
@brine.register(brine._dump_registry, Buffer)
def _dump_buffer_to_buffer(obj, stream):
def _dump_buffer_to_buffer(obj, stream, version):
stream.append(brine.TAG_STR_L4 + brine.I4.pack(len(obj)))
stream.append(obj)
@ -151,8 +160,9 @@ class SyncRequestDispatchQueue(object):
except Exception as e:
if __debug__:
syncqueuelogger.debug(
'Process task(%s) - exception: func=%s args=%s exc:%s/%s',
name, func, args, type(e), e)
'Process task(%s) - exception: func=%s args=%s '
'exc:%s/%s', name, func, args, type(e), e
)
if on_error:
on_error(e)
@ -173,14 +183,18 @@ class SyncRequestDispatchQueue(object):
except Empty:
with self._workers_lock:
if not self._closed and (self._promise or self._workers <= self._pending_workers + 1):
if not self._closed and (
self._promise or self._workers <=
self._pending_workers + 1):
again = True
else:
self._workers -= 1
if again:
if __debug__:
syncqueuelogger.debug('Wait for task to be queued(%s)', name)
syncqueuelogger.debug(
'Wait for task to be queued(%s)', name
)
task = self._queue.get()
@ -189,7 +203,9 @@ class SyncRequestDispatchQueue(object):
if __debug__:
if not task:
syncqueuelogger.debug('Worker(%s) closed by explicit request', name)
syncqueuelogger.debug(
'Worker(%s) closed by explicit request', name
)
def __call__(self, on_error, func, *args):
with self._workers_lock:
@ -217,17 +233,22 @@ class SyncRequestDispatchQueue(object):
except Full:
if __debug__:
syncqueuelogger.debug(
'Task not queued - no empty slots. Launch new worker (%s, %s)',
self, self._pending_workers)
'Task not queued - no empty slots. '
'Launch new worker (%s, %s)',
self, self._pending_workers
)
pass
if not queued or not ack.wait(timeout=self.MAX_TASK_ACK_TIME, probe=0.1):
if not queued or not ack.wait(
timeout=self.MAX_TASK_ACK_TIME, probe=0.1):
with self._workers_lock:
if self._closed:
if __debug__:
syncqueuelogger.debug(
'Queue (%s) closed, do not start new worker', self)
'Queue (%s) closed, do not start new worker',
self
)
self._workers += 1
if self._workers > self._max_workers:
@ -261,18 +282,21 @@ class SyncRequestDispatchQueue(object):
except Exception as e:
if __debug__:
syncqueuelogger.exception('Queue(%s) close: error: %s', self, e)
syncqueuelogger.exception(
'Queue(%s) close: error: %s', self, e
)
if __debug__:
syncqueuelogger.debug('Queue(%s) closed', self)
class PupyClientCababilities(object):
__slots__ = ('_storage', 'version')
__slots__ = ('_storage', '_version', '_acked')
def __init__(self, version=0):
self._storage = 0
self.version = version
self._version = version
self._acked = True
def set(self, cap):
self._storage |= cap
@ -280,16 +304,33 @@ class PupyClientCababilities(object):
def get(self, cap):
return self._storage & cap == cap
@property
def version(self):
return self._version
@version.setter
def version(self, version):
if self._version != version:
self._acked = False
self._version = version
def ack(self):
result = self._acked
self._acked = True
return result
class PupyConnection(Connection):
__slots__ = (
'_initialized', '_deinitialized', '_closing',
'_close_lock', '_sync_events_lock',
'_async_events_lock', '_sync_events',
'_sync_raw_replies', '_sync_raw_exceptions',
'_last_recv', '_ping', '_ping_timeout',
'_serve_timeout', '_last_ping', '_default_serve_timeout',
'_queue', '_config', '_timer_event', '_timer_event_last',
'_client_capabilities'
'_client_capabilities', '_3to2_mode'
)
def __repr__(self):
@ -351,7 +392,10 @@ class PupyConnection(Connection):
def _on_sync_request_exception(self, exc):
if __debug__:
logger.exception('Connection(%s) - sync request exception %s', self, exc)
logger.exception(
'Connection(%s) - sync request exception %s',
self, exc
)
if not isinstance(exc, EOFError):
logger.exception('%s: %s', self, exc)
@ -379,7 +423,7 @@ class PupyConnection(Connection):
if ping is not None:
try:
self._serve_timeout = int(ping)
except:
except ValueError:
self._serve_timeout = 10
self._ping = ping and ping not in (
@ -391,11 +435,19 @@ class PupyConnection(Connection):
if timeout:
try:
self._ping_timeout = int(timeout)
except:
except ValueError:
self._ping_timeout = 2
return self.get_pings()
def _handle_ping(self, data):
if data.startswith(PING_V1_CONTROL_MAGIC):
payload = brine.load(data[len(PING_V1_CONTROL_MAGIC):])
self._dispatch_pupy_control(*payload)
return b''
return data
def get_pings(self):
if self._ping:
return self._serve_timeout, self._ping_timeout
@ -412,8 +464,10 @@ class PupyConnection(Connection):
if __debug__:
trace = traceback.extract_stack()
if len(trace) >= 4:
synclogger.debug('Sync request wait(%s): %s / %s:%s %s (%s)',
self, seq, *trace[-4])
synclogger.debug(
'Sync request wait(%s): %s / %s:%s %s (%s)',
self, seq, *trace[-4]
)
self._sync_events[seq].wait()
@ -439,7 +493,9 @@ class PupyConnection(Connection):
'Dispatch sync reply(%s): %s - start', self, seq)
Connection._dispatch_reply(
self, seq, self._sync_raw_replies.pop(seq))
self, seq, self._sync_raw_replies.pop(seq),
self._client_capabilities.version
)
if __debug__:
synclogger.debug(
@ -448,15 +504,20 @@ class PupyConnection(Connection):
if is_exception:
if __debug__:
synclogger.debug(
'Dispatch sync exception(%s): %s - start', self, seq)
'Dispatch sync exception(%s): %s - start', self, seq
)
synclogger.debug(
'Dispatch sync exception(%s): %s - handler = %s(%s) args = %s',
self, seq,
self._HANDLERS[handler], handler,
repr(args))
'Dispatch sync exception(%s): %s - handler = %s(%s) '
'args = %s',
self, seq,
self._HANDLERS[handler], handler,
repr(args)
)
Connection._dispatch_exception(
self, seq, self._sync_raw_exceptions.pop(seq))
self, seq, self._sync_raw_exceptions.pop(seq),
self._client_capabilities.version
)
if __debug__:
synclogger.debug(
@ -476,9 +537,16 @@ class PupyConnection(Connection):
else:
return obj
def _send_control(self, *args):
seq = next(self._seqcounter)
self._send(consts.MSG_PUPY_CONTROL, seq, args)
def _send_control(self, code, data=None, timeout=None):
# Use PING command to send controls
# For compatibility
payload = brine.dump((code, data))
payload.insert(PING_V1_CONTROL_MAGIC)
return self.async_request(
consts.HANDLE_PING, payload, timeout=timeout
)
def _py2to3_conv(self, handler, args):
if handler in (consts.HANDLE_GETATTR, consts.HANDLE_DELATTR):
@ -527,7 +595,9 @@ class PupyConnection(Connection):
cls = netref.builtin_classes_cache[typeinfo]
else:
info = self.sync_request(consts.HANDLE_INSPECT, oid)
cls = netref.class_factory(clsname, modname, info)
cls = netref.class_factory(
clsname, modname, info
)
self._netref_classes_cache[typeinfo] = cls
# print("Use inspect netref", typeinfo, "as", cls, "info", info)
return cls(weakref.ref(self), oid)
@ -549,7 +619,14 @@ class PupyConnection(Connection):
self._sync_events[seq] = Ack()
self._send(consts.MSG_REQUEST, seq, (handler, self._box(args)))
self._send(
consts.MSG_REQUEST, seq, (
handler, self._box(
args, self._client_capabilities.version
)
),
self._client_capabilities.version
)
if __debug__:
synclogger.debug('Request submitted(%s): %s', self, seq)
@ -582,7 +659,10 @@ class PupyConnection(Connection):
logger.debug(
'Dispatch async reply(%s): %s - start', self, seq)
Connection._dispatch_reply(self, seq, raw)
Connection._dispatch_reply(
self, seq, raw,
self._client_capabilities.version
)
if __debug__:
logger.debug(
@ -603,13 +683,19 @@ class PupyConnection(Connection):
if __debug__:
logger.debug(
'Dispatch sync exception(%s): %s - pass',
self, seq)
self, seq
)
self._sync_events[seq].set()
else:
if __debug__:
logger.debug(
'Dispatch async reply(%s): %s - start', self, seq)
Connection._dispatch_exception(self, seq, raw)
Connection._dispatch_exception(
self, seq, raw,
self._client_capabilities.version
)
if __debug__:
logger.debug(
'Dispatch async reply(%s): %s - complete', self, seq)
@ -623,8 +709,10 @@ class PupyConnection(Connection):
if __debug__:
trace = traceback.extract_stack()
if len(trace) >= 2:
logger.debug('Connection(%s) - close - start (at: %s:%s %s(%s))',
self, *trace[-2])
logger.debug(
'Connection(%s) - close - start (at: %s:%s %s(%s))',
self, *trace[-2]
)
try:
self._async_request(consts.HANDLE_CLOSE)
@ -654,13 +742,18 @@ class PupyConnection(Connection):
if self._channel and hasattr(self._channel, 'wake'):
if __debug__:
logger.debug('Connection(%s) - wake buf_in (%s)', self, self._channel)
logger.debug(
'Connection(%s) - wake buf_in (%s)',
self, self._channel
)
self._channel.wake()
except Exception as e:
if __debug__:
logger.debug('Connection(%s) - cleanup exception - %s', self, e)
logger.debug(
'Connection(%s) - cleanup exception - %s', self, e
)
pass
if __debug__:
@ -715,15 +808,11 @@ class PupyConnection(Connection):
raise NotImplementedError('Serve method should not be used!')
def _init_service_with_notify(self, timeout):
def check_timeout():
def check_timeout(promise):
now = time.time()
logger.debug('Check timeout(%s) - start', self)
promise = self.async_request(
consts.HANDLE_PING, b'ping', timeout=timeout
)
while (time.time() - now < timeout) and not self.closed:
if promise.expired:
logger.info('Check timeout(%s) - failed', self)
@ -737,8 +826,12 @@ class PupyConnection(Connection):
time.sleep(1)
if self._local_root:
promise = self._send_control(
CONTROL_ENABLE_BRINE_EXT_V1, timeout=timeout
)
t = Thread(
target=check_timeout,
target=check_timeout, args=(promise,),
name="PupyConnection({}) Timeout check".format(self)
)
t.daemon = True
@ -758,7 +851,8 @@ class PupyConnection(Connection):
self._queue(
self._on_sync_request_exception,
self._init_service_with_notify,
timeout)
timeout
)
def loop(self):
if __debug__:
@ -816,7 +910,7 @@ class PupyConnection(Connection):
data = None
for async_event_id in self._async_callbacks:
for async_event_id in self._async_callbacks.keys():
async_event = self._async_callbacks.get(async_event_id, None)
if not async_event:
continue
@ -875,13 +969,23 @@ class PupyConnection(Connection):
return data
def _dispatch_pupy_control(self, args):
def is_extended(self):
return self._client_capabilities.version > 0
def _dispatch_pupy_control(self, code, *args):
if __debug__:
logger.debug(
'Processing pupy brine control: args: %s', args
)
pass
if code == CONTROL_ENABLE_BRINE_EXT_V1:
self._client_capabilities.version = 1
if not self._client_capabilities.ack():
self._send_control(CONTROL_ENABLE_BRINE_EXT_V1)
if __debug__:
logger.debug('Client supports brine extensions V1')
def _dispatch(self, data):
if __debug__:
@ -893,13 +997,11 @@ class PupyConnection(Connection):
if __debug__:
logger.debug('Dispatch(%s) - data (%s)', self, len(data))
msg, seq, args = brine._load(data)
msg, seq, args = brine._load(
data, self._client_capabilities.version
)
if msg == MSG_PUPY_CONTROL:
self._dispatch_pupy_control(args)
return
elif msg == consts.MSG_REQUEST:
if msg == consts.MSG_REQUEST:
if __debug__:
logger.debug(
'Processing message request, type(%s): '
@ -910,11 +1012,15 @@ class PupyConnection(Connection):
handler = args[0]
if handler in FAST_CALLS:
self._dispatch_request(seq, args)
self._dispatch_request(
seq, args, self._client_capabilities.version
)
else:
self._queue(
self._on_sync_request_exception,
self._dispatch_request, seq, args)
self._dispatch_request,
seq, args, self._client_capabilities.version
)
else:
if __debug__:
@ -947,7 +1053,7 @@ class PupyConnection(Connection):
if __debug__:
logger.debug('Dispatch(%s) - no data', self)
for async_event_id in self._async_callbacks:
for async_event_id in self._async_callbacks.keys():
async_event = self._async_callbacks.get(async_event_id)
if not async_event:
continue
@ -975,6 +1081,8 @@ class PupyConnection(Connection):
return now
Connection._HANDLERS[consts.HANDLE_PING] = _handle_ping
class PupyConnectionThread(Thread):
def __init__(self, *args, **kwargs):
@ -989,7 +1097,10 @@ class PupyConnectionThread(Thread):
self.name = 'PupyConnection({}) Thread'.format(self.connection)
if __debug__:
logger.debug('Create connection(%s) thread completed', self.connection)
logger.debug(
'Create connection(%s) thread completed',
self.connection
)
def run(self):
if __debug__:

View File

@ -18,6 +18,9 @@ from network.lib import Proxy
from network.lib.proxies import ProxyInfo
from network.lib.utils import HostInfo, TransportInfo
from network.lib.convcompat import as_unicode_string_deep
from network.lib.rpc.core.brine import (
register_named_tuple as brine_register_named_tuple
)
from umsgpack import Ext, packb, unpackb
@ -30,20 +33,22 @@ KNOWN_NAMED_TUPLES = (
)
def register_named_tuple(code, type):
def register_named_tuple(code, ntype):
MSG_TYPES_PACK[type] = lambda obj: Ext(
code, packb(tuple(x for x in obj)))
MSG_TYPES_UNPACK[code] = lambda obj: type(
MSG_TYPES_UNPACK[code] = lambda obj: ntype(
*unpackb(obj.data))
brine_register_named_tuple(code, ntype)
def register_string(type, code, name):
MSG_TYPES_PACK[type] = lambda obj: Ext(code, '')
def register_string(ntype, code, name):
MSG_TYPES_PACK[ntype] = lambda obj: Ext(code, '')
MSG_TYPES_UNPACK[code] = lambda obj: name
for idx, type in enumerate(KNOWN_NAMED_TUPLES):
register_named_tuple(idx, type)
for idx, ntype in enumerate(KNOWN_NAMED_TUPLES):
register_named_tuple(idx, ntype)
SPECIAL_TYPES_OFFT = len(KNOWN_NAMED_TUPLES)

View File

@ -33,11 +33,9 @@ import logging
try:
from pupylib import getLogger
logger = getLogger('dnscnc')
except:
except ImportError:
logger = logging.getLogger('dnscnc')
blocks_logger = logger.getChild('whitelist')
import socket
import socketserver
import binascii
@ -69,6 +67,9 @@ from network.lib.compat import (
as_byte, is_int, is_str, xrange
)
blocks_logger = logger.getChild('whitelist')
SUPPORTED_METHODS = {
QTYPE.A: A,
QTYPE.AAAA: AAAA,
@ -78,7 +79,7 @@ SUPPORTED_METHODS = {
def convert_node(node):
try:
return str(netaddr.IPAddress(node))
except:
except netaddr.core.AddrFormatError:
return int(node, 16)
@ -187,7 +188,9 @@ class Node(ExpirableObject):
'_warning', '_warning_set_time'
)
def __init__(self, node, timeout, cid=0x31337, iid=0, version=1, commands=[], alert=False):
def __init__(
self, node, timeout, cid=0x31337, iid=0,
version=1, commands=[], alert=False):
super(Node, self).__init__(timeout)
self.node = node
self.cid = cid
@ -221,8 +224,10 @@ class Node(ExpirableObject):
self.commands.append(command)
def __repr__(self):
return '{{NODE:{:012X} IID:{} CID:{:08X} ALERT:{} COMMANDS:{}}}'.format(
self.node, self.iid, self.cid, self.alert, len(self.commands))
return \
'{{NODE:{:012X} IID:{} CID:{:08X} ALERT:{} COMMANDS:{}}}'.format(
self.node, self.iid, self.cid, self.alert, len(self.commands)
)
class Session(ExpirableObject):
@ -330,7 +335,8 @@ class DnsCommandServerHandler(BaseResolver):
ENCODER_V1 = 0
ENCODER_V2 = 1
def __init__(self, domain, key, recursor=None, timeout=None,
def __init__(
self, domain, key, recursor=None, timeout=None,
whitelist=None, edns=False, activation={}):
self.sessions = {}
@ -370,9 +376,10 @@ class DnsCommandServerHandler(BaseResolver):
# Calculate max packet size
# https://tools.ietf.org/html/rfc791
# All hosts must be prepared to accept datagrams of up to 576 octets (whether they
# arrive whole or in fragments). It is recommended that hosts only send datagrams
# larger than 576 octets if they have assurance that the destination is prepared to
# All hosts must be prepared to accept datagrams of up to 576
# octets (whether they arrive whole or in fragments). It is
# recommended that hosts only send datagrams larger than 576 octets
# if they have assurance that the destination is prepared to
# accept the larger datagrams.
# https://dnsflagday.net/
@ -381,7 +388,8 @@ class DnsCommandServerHandler(BaseResolver):
# Query header size = len(domain) + 2 + 4
# Answer header size = 2 {if name = .} + 10 + len(record)
# Default length limited to 256B to spend only 1 byte for info about the length
# Default length limited to 256B to spend only 1 byte for info about
# the length.
# Each payload has index field
# As request is dynamic, we don't count it here, but in encoder
@ -439,17 +447,20 @@ class DnsCommandServerHandler(BaseResolver):
def _nodes_by_nodeids(self, ids):
return [
node for (nodeid, iid),node in self.nodes.items() if nodeid in ids
node for (nodeid, iid), node in self.nodes.items()
if nodeid in ids
]
def _sessions_by_nodeids(self, ids):
return [
session for session in self.sessions if self.sessions[session].node in ids
session for session in self.sessions
if self.sessions[session].node in ids
]
def _nodeids_with_sessions(self, ids):
return set([
session.node for session in self.sessions if self.sessions[session].node in ids
session.node for session in self.sessions
if self.sessions[session].node in ids
])
@locked
@ -588,7 +599,7 @@ class DnsCommandServerHandler(BaseResolver):
if not (spi or node):
return [
session for session in self.sessions.values() \
session for session in self.sessions.values()
if session.system_info is not None
]
elif spi:
@ -597,20 +608,25 @@ class DnsCommandServerHandler(BaseResolver):
]
elif node:
return [
session for session in self.sessions.values() \
if session.cid == node or session.node == node or (
session.system_info and \
(session.system_info['node'] in set(node) or \
str(session.system_info['external_ip']) in set(node)))
session for session in self.sessions.values()
if session.cid == node or session.node == node or (
session.system_info and (
session.system_info['node'] in set(node) or
str(session.system_info['external_ip']) in set(node)
)
)
]
@locked
def set_policy(self, kex=True, timeout=None, interval=None, node=None):
if kex == self.kex and self.timeout == timeout and self.interval == self.interval:
if kex == self.kex and self.timeout == timeout and \
self.interval == self.interval:
return
if interval and interval < 30:
raise ValueError('Interval should not be less then 30s to avoid DNS storm')
raise ValueError(
'Interval should not be less then 30s to avoid DNS storm'
)
if node and (interval or timeout):
sessions = self.find_sessions(
@ -631,7 +647,10 @@ class DnsCommandServerHandler(BaseResolver):
else:
self.interval = interval or self.interval
self.timeout = max(timeout if timeout else self.timeout, self.interval*3)
self.timeout = max(
timeout if timeout else self.timeout,
self.interval*3
)
self.kex = kex if (kex is not None) else self.kex
interval = self.interval
@ -729,7 +748,8 @@ class DnsCommandServerHandler(BaseResolver):
response = []
for idx, part in enumerate([payload[i:i+3] for i in xrange(0, len(payload), 3)]):
for idx, part in enumerate([
payload[i:i+3] for i in xrange(0, len(payload), 3)]):
header = (random.randint(1, 3) << 30)
idx = idx << 25
bits = (struct.unpack('>I', b'\x00'+part+as_byte(
@ -739,7 +759,12 @@ class DnsCommandServerHandler(BaseResolver):
'{}.{}.{}.{}'.format(
*struct.unpack(
'!BBBB', struct.pack(
'>I', header | idx | bits | int(not bool(bits & 6))))))
'>I',
header | idx | bits | int(not bool(bits & 6))
)
)
)
)
return response
@ -765,7 +790,8 @@ class DnsCommandServerHandler(BaseResolver):
response = []
for idx, part in enumerate([payload[i:i+15] for i in xrange(0, len(payload), 15)]):
for idx, part in enumerate([
payload[i:i+15] for i in xrange(0, len(payload), 15)]):
packed = struct.pack('B', idx) + part
if len(packed) < 16:
packed = packed + b'\x00' * (16 - len(packed))
@ -796,7 +822,7 @@ class DnsCommandServerHandler(BaseResolver):
elif len(parts) == 1 and parts[0] in self.activation:
raise DnsActivationRequest(self.activator(parts[0]))
elif len(parts) not in (2,3):
elif len(parts) not in (2, 3):
raise DnsNoCommandServerException()
parts = [
@ -853,7 +879,8 @@ class DnsCommandServerHandler(BaseResolver):
if node_blob:
offset_node_blob = len(payload) - (1+4+2+6)
payload, node_blob = payload[:offset_node_blob], payload[offset_node_blob:]
payload = payload[:offset_node_blob]
node_blob = payload[offset_node_blob:]
version, cid, iid = struct.unpack_from('>BIH', node_blob)
@ -862,8 +889,10 @@ class DnsCommandServerHandler(BaseResolver):
nodeid = from_bytes(node_blob[1+4+2:1+4+2+6])
logger.debug('NONCE: %08x SPI: %08x NODE: %012x',
nonce, spi, nodeid if bool(node_blob) else 0)
logger.debug(
'NONCE: %08x SPI: %08x NODE: %012x',
nonce, spi, nodeid if bool(node_blob) else 0
)
return payload, session, nonce, nodeid, cid, iid, version
@ -893,7 +922,7 @@ class DnsCommandServerHandler(BaseResolver):
node = Node(
command.node, self.timeout,
commands = self.node_commands.get(nodeid),
commands=self.node_commands.get(nodeid),
iid=sid or 0
)
@ -905,13 +934,16 @@ class DnsCommandServerHandler(BaseResolver):
return node
def _cmd_processor(self, command, session, node, csum_gen, csum_check):
logger.debug('command=%s/%s session=%s / node commands=%s / node = %s / cid = %s / iid = %s',
logger.debug(
'command=%s/%s session=%s / node commands=%s '
'/ node = %s / cid = %s / iid = %s',
command, type(command).__name__,
'{:08x}'.format(session.spi) if session else None,
bool(self.node_commands),
'{:012x}'.format(node.node) if node else None,
'{:08x}'.format(node.cid) if node else None,
node.iid if node else None)
node.iid if node else None
)
if isinstance(command, Poll) and session is None:
if not self.kex or self._kex_is_disabled(node):
@ -928,8 +960,10 @@ class DnsCommandServerHandler(BaseResolver):
elif isinstance(command, Ack) and (session is None):
if node:
if len(node.commands) < command.amount:
logger.debug('ACK: invalid amount of commands: %d > %d',
command.amount, len(node.commands))
logger.debug(
'ACK: invalid amount of commands: %d > %d',
command.amount, len(node.commands)
)
node.commands = node.commands[command.amount:]
@ -944,9 +978,8 @@ class DnsCommandServerHandler(BaseResolver):
return [Exit()]
elif (
isinstance(command, Poll) or isinstance(command, SystemStatus)
) and (session is not None):
elif (isinstance(command, Poll) or
isinstance(command, SystemStatus)) and (session is not None):
if session.system_info:
self.on_keep_alive(session.system_info)
@ -954,7 +987,7 @@ class DnsCommandServerHandler(BaseResolver):
session.system_status = command.get_dict()
if session._users_cnt_reported is not None and \
session._users_cnt_reported != session.system_status['users']:
session._users_cnt_reported != session.system_status['users']:
if session._users_cnt_reported > session.system_status['users']:
self.on_users_decrement(session)
else:
@ -970,7 +1003,7 @@ class DnsCommandServerHandler(BaseResolver):
session._high_resource_usage_reported = False
if session._user_active_reported is not None and \
session._user_active_reported != session.system_status['idle']:
session._user_active_reported != session.system_status['idle']:
if session.system_status['idle']:
self.on_user_become_inactive(session)
else:
@ -998,8 +1031,10 @@ class DnsCommandServerHandler(BaseResolver):
node.bump()
logger.debug('SystemStatus + No session + node_commands: %s/%s in %s?',
node, extip, node.commands)
logger.debug(
'SystemStatus + No session + node_commands: %s/%s in %s?',
node, extip, node.commands
)
return node.commands
@ -1035,13 +1070,16 @@ class DnsCommandServerHandler(BaseResolver):
self.on_keep_alive(session.system_info)
if command.amount > len(session.commands):
logger.debug('ACK: invalid amount of commands: %d > %d',
command.amount, len(session.commands))
logger.debug(
'ACK: invalid amount of commands: %d > %d',
command.amount, len(session.commands)
)
session.commands = session.commands[command.amount:]
return [Ack(1)]
elif isinstance(command, (SystemInfo, SystemInfoEx)) and session is not None:
elif isinstance(command, (SystemInfo, SystemInfoEx)) \
and session is not None:
new_session = not bool(session.system_info)
session.system_info = command.get_dict()
if isinstance(command, SystemInfoEx):
@ -1050,7 +1088,9 @@ class DnsCommandServerHandler(BaseResolver):
if not node:
with self.lock:
if not (command.node, session.spi) in self.nodes:
node = self._new_node_from_systeminfo(command, session.spi)
node = self._new_node_from_systeminfo(
command, session.spi
)
else:
node = self.nodes[(command.node, session.spi)]
@ -1066,8 +1106,8 @@ class DnsCommandServerHandler(BaseResolver):
response = []
encoder_version = \
self.ENCODER_V1 if not node or node.version == 1 \
else self.ENCODER_V2
self.ENCODER_V1 if not node or node.version == 1 \
else self.ENCODER_V2
if command.spi not in self.sessions:
self.sessions[command.spi] = Session(
@ -1106,7 +1146,9 @@ class DnsCommandServerHandler(BaseResolver):
if request.q.qtype not in SUPPORTED_METHODS:
reply = request.reply()
reply.header.rcode = RCODE.NXDOMAIN
logger.debug('Request unknown qtype: %s', QTYPE.get(request.q.qtype))
logger.debug(
'Request unknown qtype: %s', QTYPE.get(request.q.qtype)
)
return reply
with self.lock:
@ -1166,25 +1208,32 @@ class DnsCommandServerHandler(BaseResolver):
if self.whitelist and node:
if not self.whitelist(nodeid, cid, version):
blocks_logger.warning('Prohibit communication with %s/%s version %s on %s',
iid, cid, version, nodeid)
blocks_logger.warning(
'Prohibit communication with %s/%s version %s on %s',
iid, cid, version, nodeid
)
node.alert = True
raise NodeBlocked()
if session and session.last_nonce and session.last_qname:
if nonce < session.last_nonce:
logger.info('Ignore nonce from past: %s < %s / %s',
nonce, session.last_nonce, session.node)
logger.info(
'Ignore nonce from past: %s < %s / %s',
nonce, session.last_nonce, session.node
)
if node:
node.warning = 'Nonce from the past ({} < {})'.format(
nonce, session.last_nonce)
return []
elif session.last_nonce == nonce and session.last_qname != qname:
logger.info('Last nonce but different qname: %s != %s',
session.last_qname, qname)
elif session.last_nonce == nonce and \
session.last_qname != qname:
logger.info(
'Last nonce but different qname: %s != %s',
session.last_qname, qname
)
if node:
node.warning = 'Different qname ({})'.format(qname)
@ -1297,8 +1346,10 @@ class DnsCommandServerHandler(BaseResolver):
try:
payload = Parcel(to_send).pack(nonce, gen_csum)
except PackError as e:
emsg = 'Could not create parcel from commands: {} (session={})'.format(
e, '{:08x}'.format(session.spi) if session else None)
emsg = 'Could not create parcel from commands: ' \
'{} (session={})'.format(
e, '{:08x}'.format(session.spi) if session else None
)
logger.error(emsg)
@ -1334,7 +1385,9 @@ class DnsCommandServerHandler(BaseResolver):
if not qname.matchSuffix(self.domain):
if self.recursor:
try:
return DNSRecord.parse(request.send(self.recursor, timeout=2))
return DNSRecord.parse(
request.send(self.recursor, timeout=2)
)
except socket.error:
pass
except Exception as e:
@ -1345,13 +1398,17 @@ class DnsCommandServerHandler(BaseResolver):
reply.header.rcode = RCODE.NXDOMAIN
return reply
answers = self.process(
qtype, qname.stripSuffix(self.domain).idna()[:-1]
)
answers = self.process(qtype, qname.stripSuffix(self.domain).idna()[:-1])
klass = SUPPORTED_METHODS[qtype]
if answers:
for answer in answers:
reply.add_answer(RR(qname, qtype, rdata=klass(answer), ttl=600))
reply.add_answer(
RR(qname, qtype, rdata=klass(answer), ttl=600)
)
if self.edns:
reply.add_ar(EDNS0(udp_len=512))
@ -1360,6 +1417,7 @@ class DnsCommandServerHandler(BaseResolver):
return reply
class DnsCommandServer(object):
def __init__(self, handler, port=5454, address='0.0.0.0'):
self.handler = handler
@ -1367,12 +1425,12 @@ class DnsCommandServer(object):
self.udp_server = socketserver.UDPServer((address, port), DNSHandler)
self.udp_server.allow_reuse_address = True
self.udp_server.resolver = handler
self.udp_server.logger = DNSLogger(log='log_error',prefix=False)
self.udp_server.logger = DNSLogger(log='log_error', prefix=False)
self.tcp_server = socketserver.TCPServer((address, port), DNSHandler)
self.tcp_server.allow_reuse_address = True
self.tcp_server.resolver = handler
self.tcp_server.logger = DNSLogger(log='log_error',prefix=False)
self.tcp_server.logger = DNSLogger(log='log_error', prefix=False)
self.udp_server_thread = Thread(
target=self.udp_server.serve_forever, kwargs={

View File

@ -50,7 +50,7 @@ __all__ = (
'Channel', 'Connection', 'Service', 'BaseNetref', 'AsyncResult',
'GenericException', 'AsyncResultTimeout',
'nowait', 'timed', 'buffiter', 'BgServingThread', 'restricted',
'classic',
'classic', 'byref',
'__version__'
)
@ -58,7 +58,7 @@ import sys
from network.lib.rpc.core import (
Channel, Connection, Service, BaseNetref, AsyncResult,
GenericException, AsyncResultTimeout
GenericException, AsyncResultTimeout, byref
)
from network.lib.rpc.utils.helpers import (

View File

@ -6,12 +6,12 @@ __all__ = (
'Channel', 'Connection', 'BaseNetref',
'AsyncResult', 'AsyncResultTimeout',
'Service', 'GenericException',
'Stream', 'ClosedFile', 'SocketStream'
'Stream', 'ClosedFile', 'SocketStream', 'byref'
)
from network.lib.rpc.core.channel import Channel
from network.lib.rpc.core.protocol import Connection
from network.lib.rpc.core.netref import BaseNetref
from network.lib.rpc.core.netref import BaseNetref, byref
from network.lib.rpc.core.nowait import AsyncResult, AsyncResultTimeout
from network.lib.rpc.core.service import Service
from network.lib.rpc.core.vinegar import GenericException

View File

@ -24,10 +24,13 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
from network.lib.compat import (
Struct, BytesIO, is_py3k
Struct, BytesIO, is_py3k, as_byte, xrange
)
# singletons
TAG_NONE = b'\x00'
TAG_EMPTY_STR = b'\x01'
@ -58,63 +61,115 @@ TAG_SLICE = b'\x19'
TAG_FSET = b'\x1a'
TAG_COMPLEX = b'\x1b'
if is_py3k:
IMM_INTS = dict((i, bytes([i + 0x50])) for i in range(-0x30, 0xa0))
xrange = range
else:
IMM_INTS = dict((i, chr(i + 0x50)) for i in range(-0x30, 0xa0))
# non-standard pupy RPC extensions
TAG_NAMED_TUPLE = b'\xf0'
TAG_IMMUTABLE_DICT = b'\xf1'
TAG_IMMUTABLE_SET = b'\xf2'
TAG_IMMUTABLE_LIST = b'\xf2'
REGISTERED_NAMED_TUPLES_PACK = {}
REGISTERED_NAMED_TUPLES_UNPACK = {}
MAX_REGISTERED_VERSION = 1
I1 = Struct("!B")
I4 = Struct("!L")
F8 = Struct("!d")
C16 = Struct("!dd")
_dump_registry = {}
_load_registry = {}
IMM_INTS_LOADER = dict((v, k) for k, v in IMM_INTS.items())
_dump_registry = tuple(
dict() for _ in xrange(MAX_REGISTERED_VERSION + 1)
)
def register(coll, key):
_load_registry = tuple(
([None]*256) for _ in xrange(MAX_REGISTERED_VERSION + 1)
)
def _dump_named_tuple(obj, stream, version):
obj_type = type(obj)
tuple_id = REGISTERED_NAMED_TUPLES_PACK.get(obj_type)
if tuple_id is None:
raise ValueError('Unregistered named tuple type %s', obj_type)
stream.append(TAG_NAMED_TUPLE)
stream.append(I4.pack(tuple_id))
tuple_dump = _dump_registry[0][tuple]
tuple_dump(obj, stream, version)
def register_named_tuple(code, ntype):
REGISTERED_NAMED_TUPLES_PACK[ntype] = code
REGISTERED_NAMED_TUPLES_UNPACK[code] = ntype
for ver in xrange(1, MAX_REGISTERED_VERSION + 1):
_dump_registry[ver][ntype] = _dump_named_tuple
def register(coll, key, min_version=0):
def deco(func):
coll[key] = func
for version in xrange(min_version, MAX_REGISTERED_VERSION + 1):
if coll is _dump_registry:
_dump_registry[version][key] = func
elif coll is _load_registry:
_load_registry[version][ord(key)] = func
else:
raise ValueError(
'Unknown registry %s' % (repr(coll),)
)
return func
return deco
#===============================================================================
# =============================================================================
# dumping
#===============================================================================
# =============================================================================
@register(_dump_registry, type(None))
def _dump_none(obj, stream):
def _dump_none(obj, stream, version):
stream.append(TAG_NONE)
@register(_dump_registry, type(NotImplemented))
def _dump_notimplemeted(obj, stream):
def _dump_notimplemeted(obj, stream, version):
stream.append(TAG_NOT_IMPLEMENTED)
@register(_dump_registry, type(Ellipsis))
def _dump_ellipsis(obj, stream):
def _dump_ellipsis(obj, stream, version):
stream.append(TAG_ELLIPSIS)
@register(_dump_registry, bool)
def _dump_bool(obj, stream):
def _dump_bool(obj, stream, version):
if obj:
stream.append(TAG_TRUE)
else:
stream.append(TAG_FALSE)
@register(_dump_registry, slice)
def _dump_slice(obj, stream):
def _dump_slice(obj, stream, version):
stream.append(TAG_SLICE)
_dump((obj.start, obj.stop, obj.step), stream)
_dump((obj.start, obj.stop, obj.step), stream, version)
@register(_dump_registry, frozenset)
def _dump_frozenset(obj, stream):
def _dump_frozenset(obj, stream, version):
stream.append(TAG_FSET)
_dump(tuple(obj), stream)
_dump(tuple(obj), stream, version)
@register(_dump_registry, int)
def _dump_int(obj, stream):
if obj in IMM_INTS:
stream.append(IMM_INTS[obj])
def _dump_int(obj, stream, version):
if obj >= -0x30 and obj < 0xa0:
stream.append(as_byte(obj + 0x50))
else:
obj = str(obj).encode('ascii')
obj_len = len(obj)
@ -123,17 +178,20 @@ def _dump_int(obj, stream):
else:
stream.append(TAG_INT_L4 + I4.pack(obj_len) + obj)
@register(_dump_registry, float)
def _dump_float(obj, stream):
def _dump_float(obj, stream, version):
stream.append(TAG_FLOAT + F8.pack(obj))
@register(_dump_registry, complex)
def _dump_complex(obj, stream):
def _dump_complex(obj, stream, version):
stream.append(TAG_COMPLEX + C16.pack(obj.real, obj.imag))
if is_py3k:
@register(_dump_registry, bytes)
def _dump_bytes(obj, stream):
def _dump_bytes(obj, stream, version):
obj_len = len(obj)
if obj_len == 0:
stream.append(TAG_EMPTY_STR)
@ -153,13 +211,31 @@ if is_py3k:
stream.append(obj)
@register(_dump_registry, str)
def _dump_str(obj, stream):
def _dump_str(obj, stream, version):
stream.append(TAG_UNICODE)
_dump_bytes(obj.encode('utf8'), stream)
_dump_bytes(obj.encode('utf8'), stream, version)
@register(_dump_registry, dict, 1)
def _dump_immutable_dict(obj, stream, version):
stream.append(TAG_IMMUTABLE_DICT)
items = len(obj)
stream.append(I4.pack(items))
for item in obj.items():
_dump(item, stream, version)
@register(_dump_registry, type({}.keys()), 1)
def _dump_immutable_dict_keys(obj, stream, version):
stream.append(TAG_IMMUTABLE_LIST)
items = len(obj)
stream.append(I4.pack(items))
for item in obj:
_dump(item, stream, version)
else:
@register(_dump_registry, str)
def _dump_str(obj, stream):
def _dump_str(obj, stream, version):
obj_len = len(obj)
if obj_len == 0:
stream.append(TAG_EMPTY_STR)
@ -179,18 +255,47 @@ else:
stream.append(obj)
@register(_dump_registry, unicode)
def _dump_unicode(obj, stream):
def _dump_unicode(obj, stream, version):
stream.append(TAG_UNICODE)
_dump_str(obj.encode('utf8'), stream)
_dump_str(obj.encode('utf8'), stream, version)
@register(_dump_registry, long)
def _dump_long(obj, stream):
def _dump_long(obj, stream, version):
stream.append(TAG_LONG)
_dump_int(obj, stream)
_dump_int(obj, stream, version)
@register(_dump_registry, dict, 1)
def _dump_immutable_dict(obj, stream, version):
stream.append(TAG_IMMUTABLE_DICT)
items = len(obj)
stream.append(I4.pack(items))
for item in obj.iteritems():
_dump(item, stream, version)
@register(_dump_registry, set, 1)
def _dump_immutable_set(obj, stream, version):
stream.append(TAG_IMMUTABLE_SET)
items = len(obj)
stream.append(I4.pack(items))
for item in obj:
_dump(item, stream, version)
@register(_dump_registry, list, 1)
def _dump_immutable_list(obj, stream, version):
stream.append(TAG_IMMUTABLE_LIST)
items = len(obj)
stream.append(I4.pack(items))
for item in obj:
_dump(item, stream, version)
@register(_dump_registry, tuple)
def _dump_tuple(obj, stream):
def _dump_tuple(obj, stream, version):
obj_len = len(obj)
if obj_len == 0:
stream.append(TAG_EMPTY_TUPLE)
@ -208,181 +313,248 @@ def _dump_tuple(obj, stream):
stream.append(TAG_TUP_L4 + I4.pack(obj_len))
for item in obj:
_dump(item, stream)
_dump(item, stream, version)
def _undumpable(obj, stream):
raise TypeError("cannot dump %r" % (obj,))
def _undumpable(obj, stream, version):
raise TypeError("cannot dump %r (%s) version=%s" % (
obj, type(obj), version)
)
def _dump(obj, stream):
_dump_registry.get(type(obj), _undumpable)(obj, stream)
def _dump(obj, stream, version=0):
dumper = _dump_registry[version].get(
type(obj), _undumpable
)
#===============================================================================
dumper(obj, stream, version)
# =============================================================================
# loading
#===============================================================================
# =============================================================================
@register(_load_registry, TAG_NAMED_TUPLE, 1)
def _load_named_tuple(stream, version):
tuple_id, = I4.unpack(stream.read(4))
obj_type = REGISTERED_NAMED_TUPLES_UNPACK.get(tuple_id)
if obj_type is None:
raise ValueError('Unregistered named tuple id %s', tuple_id)
tuple_data = _load(stream, version)
return obj_type(*tuple_data)
@register(_load_registry, TAG_IMMUTABLE_DICT, 1)
def _load_immutable_dict(stream, version):
items, = I4.unpack(stream.read(4))
dict_items = []
for _ in xrange(items):
dict_items.append(_load(stream, version))
return dict(dict_items)
@register(_load_registry, TAG_IMMUTABLE_SET, 1)
def _load_immutable_set(stream, version):
items, = I4.unpack(stream.read(4))
result = set()
for _ in xrange(items):
result.add(_load(stream, version))
return result
@register(_load_registry, TAG_IMMUTABLE_LIST, 1)
def _load_immutable_list(stream, version):
items, = I4.unpack(stream.read(4))
result = list()
for _ in xrange(items):
result.append(_load(stream, version))
return result
@register(_load_registry, TAG_NONE)
def _load_none(stream):
def _load_none(stream, version):
return None
@register(_load_registry, TAG_NOT_IMPLEMENTED)
def _load_nonimp(stream):
def _load_nonimp(stream, version):
return NotImplemented
@register(_load_registry, TAG_ELLIPSIS)
def _load_elipsis(stream):
def _load_elipsis(stream, version):
return Ellipsis
@register(_load_registry, TAG_TRUE)
def _load_true(stream):
def _load_true(stream, version):
return True
@register(_load_registry, TAG_FALSE)
def _load_false(stream):
def _load_false(stream, version):
return False
@register(_load_registry, TAG_EMPTY_TUPLE)
def _load_empty_tuple(stream):
def _load_empty_tuple(stream, version):
return ()
@register(_load_registry, TAG_EMPTY_STR)
def _load_empty_str(stream):
def _load_empty_str(stream, version):
return b''
if is_py3k:
@register(_load_registry, TAG_LONG)
def _load_long(stream):
obj = _load(stream)
def _load_long(stream, version):
obj = _load(stream, version)
return int(obj)
else:
@register(_load_registry, TAG_LONG)
def _load_long(stream):
obj = _load(stream)
def _load_long(stream, version):
obj = _load(stream, version)
return long(obj)
@register(_load_registry, TAG_FLOAT)
def _load_float(stream):
def _load_float(stream, version):
return F8.unpack(stream.read(8))[0]
@register(_load_registry, TAG_COMPLEX)
def _load_complex(stream):
def _load_complex(stream, version):
real, imag = C16.unpack(stream.read(16))
return complex(real, imag)
@register(_load_registry, TAG_STR1)
def _load_str1(stream):
def _load_str1(stream, version):
return stream.read(1)
@register(_load_registry, TAG_STR2)
def _load_str2(stream):
def _load_str2(stream, version):
return stream.read(2)
@register(_load_registry, TAG_STR3)
def _load_str3(stream):
def _load_str3(stream, version):
return stream.read(3)
@register(_load_registry, TAG_STR4)
def _load_str4(stream):
def _load_str4(stream, version):
return stream.read(4)
@register(_load_registry, TAG_STR_L1)
def _load_str_l1(stream):
def _load_str_l1(stream, version):
obj_len, = I1.unpack(stream.read(1))
return stream.read(obj_len)
@register(_load_registry, TAG_STR_L4)
def _load_str_l4(stream):
def _load_str_l4(stream, version):
obj_len, = I4.unpack(stream.read(4))
return stream.read(obj_len)
@register(_load_registry, TAG_UNICODE)
def _load_unicode(stream):
obj = _load(stream)
def _load_unicode(stream, version):
obj = _load(stream, version)
return obj.decode("utf-8")
@register(_load_registry, TAG_TUP1)
def _load_tup1(stream):
return (_load(stream),)
def _load_tup1(stream, version):
return (_load(stream, version),)
@register(_load_registry, TAG_TUP2)
def _load_tup2(stream):
return (_load(stream), _load(stream))
def _load_tup2(stream, version):
return (_load(stream, version), _load(stream, version))
@register(_load_registry, TAG_TUP3)
def _load_tup3(stream):
return (_load(stream), _load(stream), _load(stream))
def _load_tup3(stream, version):
return (
_load(stream, version), _load(stream, version),
_load(stream, version)
)
@register(_load_registry, TAG_TUP4)
def _load_tup4(stream):
return (_load(stream), _load(stream), _load(stream), _load(stream))
def _load_tup4(stream, version):
return (
_load(stream, version), _load(stream, version),
_load(stream, version), _load(stream, version)
)
@register(_load_registry, TAG_TUP_L1)
def _load_tup_l1(stream):
def _load_tup_l1(stream, version):
obj_len, = I1.unpack(stream.read(1))
return tuple(_load(stream) for _ in xrange(obj_len))
return tuple(_load(stream, version) for _ in xrange(obj_len))
@register(_load_registry, TAG_TUP_L4)
def _load_tup_l4(stream):
def _load_tup_l4(stream, version):
obj_len, = I4.unpack(stream.read(4))
return tuple(_load(stream) for _ in xrange(obj_len))
return tuple(_load(stream, version) for _ in xrange(obj_len))
@register(_load_registry, TAG_SLICE)
def _load_slice(stream):
start, stop, step = _load(stream)
def _load_slice(stream, version):
start, stop, step = _load(stream, version)
return slice(start, stop, step)
@register(_load_registry, TAG_FSET)
def _load_frozenset(stream):
return frozenset(_load(stream))
def _load_frozenset(stream, version):
return frozenset(_load(stream, version))
@register(_load_registry, TAG_INT_L1)
def _load_int_l1(stream):
def _load_int_l1(stream, version):
obj_len, = I1.unpack(stream.read(1))
return int(stream.read(obj_len))
@register(_load_registry, TAG_INT_L4)
def _load_int_l4(stream):
def _load_int_l4(stream, version):
obj_len, = I4.unpack(stream.read(4))
return int(stream.read(obj_len))
def _load(stream):
def _load(stream, version=0):
tag = stream.read(1)
if tag in IMM_INTS_LOADER:
return IMM_INTS_LOADER[tag]
return _load_registry.get(tag)(stream)
ival = ord(tag)
if ival >= 0x20 and ival < 0xf0:
return ival - 0x50
loader = _load_registry[version][ival]
if loader is None:
raise ValueError('Unknown tag 0x%02x' % (ival,))
return loader(stream, version)
#===============================================================================
# =============================================================================
# API
#===============================================================================
def dump(obj):
# =============================================================================
def dump(obj, version=0):
"""Converts (dumps) the given object to a byte-string representation
:param obj: any :func:`dumpable` object
@ -390,11 +562,11 @@ def dump(obj):
:returns: a byte-string representation of the object
"""
stream = []
_dump(obj, stream)
_dump(obj, stream, version)
return b''.join(stream)
def load(data):
def load(data, version=0):
"""Recreates (loads) an object from its byte-string representation
:param data: the byte-string representation of an object
@ -402,31 +574,55 @@ def load(data):
:returns: the dumped object
"""
stream = BytesIO(data)
return _load(stream)
return _load(stream, version)
if is_py3k:
simple_types = frozenset([type(None), int, bool, float, bytes, str, complex,
type(NotImplemented), type(Ellipsis)])
simple_types = frozenset([
type(None), int, bool, float, bytes, str, complex,
type(NotImplemented), type(Ellipsis)
])
else:
simple_types = frozenset([type(None), int, long, bool, float, str, unicode, complex,
type(NotImplemented), type(Ellipsis)])
simple_types = frozenset([
type(None), int, long, bool, float, str, unicode, complex,
type(NotImplemented), type(Ellipsis)
])
def dumpable(obj):
def dumpable(obj, version=0, log_deep=False, copy_mutable=True):
"""Indicates whether the given object is *dumpable* by brine
:returns: ``True`` if the object is dumpable (e.g., :func:`dump` would succeed),
:returns: ``True`` if the object is dumpable (e.g., :func:`dump`
would succeed),
``False`` otherwise
"""
if type(obj) in simple_types:
return True
if type(obj) in (tuple, frozenset):
return all(dumpable(item) for item in obj)
return all(dumpable(item, version, log_deep) for item in obj)
if type(obj) is slice:
return dumpable(obj.start) and dumpable(obj.stop) and dumpable(obj.step)
return \
dumpable(obj.start, version, log_deep) and \
dumpable(obj.stop, version, log_deep) and \
dumpable(obj.step, version, log_deep)
if type(obj) in _dump_registry[version]:
if isinstance(obj, tuple) or (type(obj) is set and copy_mutable):
return all(dumpable(item, version, True) for item in obj)
elif type(obj) is dict and copy_mutable:
return all(
dumpable(k, version, True) and dumpable(v, version, True)
for k, v in obj.items()
)
if __debug__:
if log_deep:
logging.debug(
'dumpable(deep): undumpable object type %s (%s)',
type(obj), repr(obj)
)
return False

View File

@ -17,6 +17,11 @@ LABEL_TUPLE = 2
LABEL_LOCAL_REF = 3
LABEL_REMOTE_REF = 4
LABEL_V1_SET = 5
LABEL_V1_DICT = 6
LABEL_V1_NAMED_TUPLE = 7
LABEL_V1_LIST = 8
# action handlers
HANDLE_PING = 1
HANDLE_CLOSE = 2
@ -36,6 +41,7 @@ HANDLE_DEL = 15
HANDLE_INSPECT = 16
HANDLE_BUFFITER = 17
HANDLE_OLDSLICING = 18
HANDLE_MAX = 19
# optimized exceptions
EXC_STOP_ITERATION = 1

View File

@ -20,19 +20,24 @@ from . import consts
_local_netref_attrs = frozenset([
'____conn__', '____oid__', '____refcount__', '__class__', '__cmp__', '__del__', '__delattr__',
'____conn__', '____oid__', '____refcount__', '___async_call__',
'___methods__',
'__class__', '__cmp__',
'__del__', '__delattr__',
'__dir__', '__doc__', '__getattr__', '__getattribute__', '__hash__',
'__init__', '__metaclass__', '__module__', '__new__', '__reduce__',
'__reduce_ex__', '__repr__', '__setattr__', '__slots__', '__str__',
'__weakref__', '__dict__', '__members__', '__methods__',
'__weakref__', '__dict__', '__members__'
])
"""the set of attributes that are local to the netref object"""
_builtin_types = [
type, object, bool, complex, dict, float, int, list, slice, str, tuple, set,
frozenset, Exception, type(None), types.BuiltinFunctionType, types.GeneratorType,
types.MethodType, types.CodeType, types.FrameType, types.TracebackType,
types.ModuleType, types.FunctionType,
type, object, bool, complex, dict, float, int, list,
slice, str, tuple, set, frozenset, Exception, type(None),
types.BuiltinFunctionType, types.GeneratorType,
types.MethodType, types.CodeType, types.FrameType,
types.TracebackType, types.ModuleType, types.FunctionType,
type(int.__add__), # wrapper_descriptor
type((1).__add__), # method-wrapper
@ -61,8 +66,17 @@ else:
types.InstanceType, type, types.DictProxyType,
])
_normalized_builtin_types = dict(((t.__name__, t.__module__), t)
for t in _builtin_types)
_normalized_builtin_types = dict(
((t.__name__, t.__module__), t)
for t in _builtin_types
)
class byref(object):
__slots__ = ('object',)
def __init__(self, obj):
self.object = obj
def syncreq(proxy, handler, *args):
@ -123,8 +137,8 @@ class BaseNetref(object):
defined in the :data:`_builtin_types`), and they are shared between all
connections.
The rest of the netref classes are created by :meth:`.core.protocl.Connection._unbox`,
and are private to the connection.
The rest of the netref classes are created by
:meth:`.core.protocl.Connection._unbox`, and are private to the connection.
Do not use this class directly; use :func:`class_factory` instead.
@ -135,15 +149,18 @@ class BaseNetref(object):
__metaclass__ = NetrefMetaclass
__slots__ = (
"____conn__", "____oid__", "__weakref__", "____refcount__"
'____conn__', '____oid__', '__weakref__', '____refcount__',
'___async_call__'
)
def __init__(self, conn, oid):
self.____conn__ = conn
self.____oid__ = oid
self.____refcount__ = 1
self.___async_call__ = None
def __del__(self):
# print('DEL', repr((type(self), repr(self))))
try:
asyncreq(self, consts.HANDLE_DEL, self.____refcount__)
except Exception:
@ -153,24 +170,43 @@ class BaseNetref(object):
pass
def __getattribute__(self, name):
methods = object.__getattribute__(self, '___methods__')
method = methods.get(name)
if method is not None:
bound_method = method.__get__(self)
setattr(
bound_method, '___async_call__',
method.___async_call__.__get__(self)
)
return bound_method
if name in _local_netref_attrs:
if name == "__class__":
cls = object.__getattribute__(self, "__class__")
if name == '__class__':
cls = object.__getattribute__(self, '__class__')
if cls is None:
cls = self.__getattr__("__class__")
cls = self.__getattr__('__class__')
return cls
elif name == "__doc__":
return self.__getattr__("__doc__")
elif name == "__members__": # for Python < 2.6
elif name == '__doc__':
return self.__getattr__('__doc__')
elif name == '__members__': # for Python < 2.6
return self.__dir__()
elif name == '___async_call__':
call = methods.get('__call__')
if call is None:
return None
return call.___async_call__.__get__(self)
else:
return object.__getattribute__(self, name)
elif name == "__call__": # IronPython issue #10
return object.__getattribute__(self, "__call__")
elif name == '__call__': # IronPython issue #10
return object.__getattribute__(self, '__call__')
else:
# print('GETATTRIBUTE', repr((type(self), repr(name), repr(self))))
return syncreq(self, consts.HANDLE_GETATTR, name)
def __getattr__(self, name):
# print('GETATTR', repr((type(self), name)))
return syncreq(
self, consts.HANDLE_GETATTR,
as_native_string(name)
@ -182,7 +218,8 @@ class BaseNetref(object):
self, as_native_string(name)
)
else:
syncreq(
# print('DELATTR', repr((type(self), name)))
asyncreq(
self, consts.HANDLE_DELATTR,
as_native_string(name)
)
@ -193,12 +230,14 @@ class BaseNetref(object):
self, as_native_string(name), value
)
else:
syncreq(
# print('SETATTR', repr((type(self), name)))
asyncreq(
self, consts.HANDLE_SETATTR,
as_native_string(name), value
)
def __dir__(self):
# print('DIR', repr((type(self))))
return list(
as_native_string(key) for key in syncreq(
self, consts.HANDLE_DIR
@ -207,21 +246,26 @@ class BaseNetref(object):
# support for metaclasses
def __hash__(self):
# print('HASH', repr((type(self))))
return syncreq(self, consts.HANDLE_HASH)
def __iter__(self):
# print('ITER', repr((type(self))))
return syncreq(
self, consts.HANDLE_CALLATTR, '__iter__'
)
def __cmp__(self, other):
# print('CMP', repr((type(self))))
return syncreq(self, consts.HANDLE_CMP, other)
def __repr__(self):
# print('REPR', repr((type(self))))
# __repr__ MUST return string
return as_native_string(syncreq(self, consts.HANDLE_REPR))
def __str__(self):
# print('STR', repr((type(self))))
# __str__ MUST return string
return as_native_string(syncreq(self, consts.HANDLE_STR))
@ -231,7 +275,7 @@ class BaseNetref(object):
if not isinstance(BaseNetref, NetrefMetaclass):
# python 2 and 3 compatible metaclass...
# python 2 and 3 compatibe metaclass...
ns = dict(BaseNetref.__dict__)
for slot in BaseNetref.__slots__:
ns.pop(slot)
@ -243,18 +287,28 @@ def _make_method(name, doc):
:func:`syncreq` on its `self` argument"""
slicers = {
"__getslice__": "__getitem__",
"__delslice__": "__delitem__",
"__setslice__": "__setitem__"
'__getslice__': '__getitem__',
'__delslice__': '__delitem__',
'__setslice__': '__setitem__'
}
name = as_native_string(name) # IronPython issue #10
if name == "__call__":
if name == '__call__':
def __call__(_self, *args, **kwargs):
kwargs = tuple(kwargs.items())
return syncreq(_self, consts.HANDLE_CALL, args, kwargs)
def ___async_call__(_self, *args, **kwargs):
kwargs = tuple(kwargs.items())
return asyncreq(_self, consts.HANDLE_CALL, args, kwargs)
__call__.__doc__ = doc
___async_call__.__doc__ = doc + ' (async)'
setattr(__call__, '___async_call__', ___async_call__)
setattr(___async_call__, '___async_call__', ___async_call__)
return __call__
elif name in slicers:
@ -267,8 +321,23 @@ def _make_method(name, doc):
name, start, stop, args
)
def async_method(self, start, stop, *args):
if stop == maxint:
stop = None
return asyncreq(
self, consts.HANDLE_OLDSLICING, slicers[name],
name, start, stop, args
)
method.__name__ = name
method.__doc__ = doc
async_method.__name__ = name
async_method.__doc__ = doc + ' (async)'
setattr(method, '___async_call__', async_method)
setattr(async_method, '___async_call__', async_method)
return method
else:
@ -278,8 +347,21 @@ def _make_method(name, doc):
_self, consts.HANDLE_CALLATTR, name, args, kwargs
)
def async_method(_self, *args, **kwargs):
# print("ASYNC METHOD CALL ARGS", _self, args, kwargs)
kwargs = tuple(kwargs.items())
return asyncreq(
_self, consts.HANDLE_CALLATTR, name, args, kwargs
)
method.__name__ = name
method.__doc__ = doc
async_method.__name__ = name
async_method.__doc__ = doc + ' (async)'
setattr(method, '___async_call__', async_method)
setattr(async_method, '___async_call__', async_method)
return method
@ -320,7 +402,10 @@ def class_factory(clsname, modname, methods):
clsname = as_native_string(clsname)
modname = as_native_string(modname)
ns = {"__slots__": ()}
ns = {
'__slots__': (),
'___methods__': {}
}
for name, doc in methods:
name = as_native_string(name)
@ -330,19 +415,21 @@ def class_factory(clsname, modname, methods):
ns['__next__'] = _make_method(name, doc)
if name not in _local_netref_attrs:
ns[name] = _make_method(name, doc)
wrapper = _make_method(name, doc)
ns[name] = wrapper
ns['___methods__'][name] = wrapper
ns["__module__"] = modname
ns['__module__'] = modname
if modname in sys.modules and hasattr(sys.modules[modname], clsname):
ns["__class__"] = getattr(sys.modules[modname], clsname)
ns['__class__'] = getattr(sys.modules[modname], clsname)
elif (clsname, modname) in _normalized_builtin_types:
ns["__class__"] = _normalized_builtin_types[clsname, modname]
ns['__class__'] = _normalized_builtin_types[clsname, modname]
else:
# to be resolved by the instance
ns["__class__"] = None
ns['__class__'] = None
return type(clsname, (BaseNetref,), ns)

View File

@ -21,56 +21,69 @@ from ..lib import get_methods
from . import consts, brine, vinegar, netref
from .nowait import AsyncResult
class PingError(Exception):
"""The exception raised should :func:`Connection.ping` fail"""
pass
DEFAULT_CONFIG = dict(
DEFAULT_CONFIG = {
# ATTRIBUTES
allow_safe_attrs = True,
allow_exposed_attrs = True,
allow_public_attrs = False,
allow_all_attrs = False,
safe_attrs = {
'allow_safe_attrs': True,
'allow_exposed_attrs': True,
'allow_public_attrs': False,
'allow_all_attrs': False,
'safe_attrs': {
'__abs__', '__add__', '__and__', '__bool__', '__cmp__', '__contains__',
'__delitem__', '__delslice__', '__div__', '__divmod__', '__doc__',
'__eq__', '__float__', '__floordiv__', '__ge__', '__getitem__',
'__getslice__', '__gt__', '__hash__', '__hex__', '__iadd__', '__iand__',
'__idiv__', '__ifloordiv__', '__ilshift__', '__imod__', '__imul__',
'__index__', '__int__', '__invert__', '__ior__', '__ipow__', '__irshift__',
'__isub__', '__iter__', '__itruediv__', '__ixor__', '__le__', '__len__',
'__long__', '__lshift__', '__lt__', '__mod__', '__mul__', '__ne__',
'__neg__', '__new__', '__nonzero__', '__oct__', '__or__', '__pos__',
'__pow__', '__radd__', '__rand__', '__rdiv__', '__rdivmod__', '__repr__',
'__rfloordiv__', '__rlshift__', '__rmod__', '__rmul__', '__ror__',
'__rpow__', '__rrshift__', '__rshift__', '__rsub__', '__rtruediv__',
'__rxor__', '__setitem__', '__setslice__', '__str__', '__sub__',
'__truediv__', '__xor__', 'next', '__length_hint__', '__enter__',
'__exit__', '__next__'
'__getslice__', '__gt__', '__hash__', '__hex__', '__iadd__',
'__iand__', '__idiv__', '__ifloordiv__', '__ilshift__', '__imod__',
'__imul__', '__index__', '__int__', '__invert__', '__ior__',
'__ipow__', '__irshift__', '__isub__', '__iter__', '__itruediv__',
'__ixor__', '__le__', '__len__', '__long__', '__lshift__', '__lt__',
'__mod__', '__mul__', '__ne__', '__neg__', '__new__', '__nonzero__',
'__oct__', '__or__', '__pos__', '__pow__', '__radd__', '__rand__',
'__rdiv__', '__rdivmod__', '__repr__', '__rfloordiv__', '__rlshift__',
'__rmod__', '__rmul__', '__ror__', '__rpow__', '__rrshift__',
'__rshift__', '__rsub__', '__rtruediv__', '__rxor__', '__setitem__',
'__setslice__', '__str__', '__sub__', '__truediv__', '__xor__', 'next',
'__length_hint__', '__enter__', '__exit__', '__next__'
},
exposed_prefix = "exposed_",
allow_getattr = True,
allow_setattr = False,
allow_delattr = False,
'exposed_prefix': 'exposed_',
'allow_getattr': True,
'allow_setattr': False,
'allow_delattr': False,
# EXCEPTIONS
include_local_traceback = True,
instantiate_custom_exceptions = False,
import_custom_exceptions = False,
instantiate_oldstyle_exceptions = False, # which don't derive from Exception
propagate_SystemExit_locally = False, # whether to propagate SystemExit locally or to the other party
propagate_KeyboardInterrupt_locally = True, # whether to propagate KeyboardInterrupt locally or to the other party
log_exceptions = True,
'include_local_traceback': True,
'instantiate_custom_exceptions': False,
'import_custom_exceptions': False,
# which don't derive from Exception
'instantiate_oldstyle_exceptions': False,
# whether to propagate SystemExit locally or to the other party
'propagate_SystemExit_locally': False,
# whether to propagate KeyboardInterrupt locally or to the other party
'propagate_KeyboardInterrupt_locally': True,
'log_exceptions': True,
# MISC
allow_pickle = False,
connid = None,
credentials = None,
endpoints = None,
logger = None,
sync_request_timeout = 30,
)
'allow_pickle': False,
'connid': None,
'credentials': None,
'endpoints': None,
'logger': None,
'sync_request_timeout': 30,
}
"""
The default configuration dictionary of the protocol. You can override these parameters
The default configuration dictionary of the protocol. You can override these
parameters
by passing a different configuration dict to the :class:`Connection` class.
.. note::
@ -128,9 +141,11 @@ Parameter Default value Description
======================================= ================ =====================================================
"""
list_keys_type = type({}.keys())
_connection_id_generator = itertools.count(1)
class Connection(object):
"""The *connection* (AKA *protocol*).
@ -143,12 +158,25 @@ class Connection(object):
need to call :func:`_init_service` manually later
"""
def __init__(self, service, channel, config = {}, _lazy = False):
__slots__ = (
'__weakref__',
'_closed', '_config', '_config', '_channel', '_seqcounter',
'_recvlock', '_sendlock', '_sync_replies', '_sync_lock',
'_sync_event', '_async_callbacks', '_local_objects',
'_last_traceback', '_proxy_cache', '_netref_classes_cache',
'_remote_root', '_send_queue', '_local_root', '_closed'
)
def __init__(self, service, channel, config={}, _lazy=False):
self._closed = True
self._config = DEFAULT_CONFIG.copy()
self._config.update(config)
if self._config["connid"] is None:
self._config["connid"] = "conn%d" % (next(_connection_id_generator),)
self._config["connid"] = "conn%d" % (
next(_connection_id_generator),
)
self._channel = channel
self._seqcounter = itertools.count()
@ -169,7 +197,6 @@ class Connection(object):
self._init_service()
self._closed = False
def _init_service(self):
self._local_root.on_connect()
@ -189,7 +216,7 @@ class Connection(object):
#
# IO
#
def _cleanup(self, _anyway = True):
def _cleanup(self, _anyway=True):
if self._closed and not _anyway:
return
self._closed = True
@ -203,10 +230,8 @@ class Connection(object):
self._last_traceback = None
self._remote_root = None
self._local_root = None
#self._seqcounter = None
#self._config.clear()
def close(self, _catchall = True):
def close(self, _catchall=True):
"""closes the connection, releasing all held resources"""
if self._closed:
return
@ -219,7 +244,7 @@ class Connection(object):
if not _catchall:
raise
finally:
self._cleanup(_anyway = True)
self._cleanup(_anyway=True)
@property
def closed(self):
@ -230,7 +255,7 @@ class Connection(object):
"""Returns the connectin's underlying file descriptor"""
return self._channel.fileno()
def ping(self, data = None, timeout = 3):
def ping(self, data=None, timeout=3):
"""
Asserts that the other party is functioning properly, by making sure
the *data* is echoed back before the *timeout* expires
@ -242,15 +267,17 @@ class Connection(object):
"""
if data is None:
data = b'abcdefghijklmnopqrstuvwxyz' * 20
res = self.async_request(consts.HANDLE_PING, data, timeout = timeout)
res = self.async_request(
consts.HANDLE_PING, data, timeout=timeout
)
if res.value != data:
raise PingError("echo mismatches sent data")
def _get_seq_id(self):
return next(self._seqcounter)
def _send(self, msg, seq, args):
data = brine.dump((msg, seq, args))
def _send(self, msg, seq, args, version):
data = brine.dump((msg, seq, args), version)
# GC might run while sending data
# if so, a BaseNetref.__del__ might be called
# BaseNetref.__del__ must call asyncreq,
@ -280,52 +307,105 @@ class Connection(object):
finally:
self._sendlock.release()
def _send_request(self, seq, handler, args):
self._send(consts.MSG_REQUEST, seq, (handler, self._box(args)))
def _send_request(self, seq, handler, args, version=0):
self._send(
consts.MSG_REQUEST, seq, (
handler, self._box(args, version)
), version
)
def _send_reply(self, seq, obj):
self._send(consts.MSG_REPLY, seq, self._box(obj))
def _send_reply(self, seq, obj, version=0):
self._send(
consts.MSG_REPLY, seq, self._box(obj, version),
version
)
def _send_exception(self, seq, exctype, excval, exctb):
exc = vinegar.dump(exctype, excval, exctb,
include_local_traceback = self._config["include_local_traceback"])
self._send(consts.MSG_EXCEPTION, seq, exc)
def _send_exception(self, seq, exctype, excval, exctb, version):
exc = vinegar.dump(
exctype, excval, exctb,
self._config["include_local_traceback"],
version
)
self._send(
consts.MSG_EXCEPTION, seq, exc, version
)
#
# boxing
#
def _box(self, obj):
def _box(self, obj, version=0, copy_mutable=True):
"""store a local object in such a way that it could be recreated on
the remote party either by-value or by-reference"""
if brine.dumpable(obj):
if isinstance(obj, netref.byref):
obj = obj.object
copy_mutable = False
if brine.dumpable(obj, version, copy_mutable=copy_mutable):
return consts.LABEL_VALUE, obj
if type(obj) is tuple:
return consts.LABEL_TUPLE, tuple(self._box(item) for item in obj)
elif type(obj) is tuple:
return consts.LABEL_TUPLE, tuple(
self._box(item, version, copy_mutable) for item in obj
)
elif isinstance(obj, netref.BaseNetref) and obj.____conn__() is self:
return consts.LABEL_LOCAL_REF, obj.____oid__
else:
self._local_objects.add(obj)
try:
cls = obj.__class__
except Exception:
# see issue #16
cls = type(obj)
if not isinstance(cls, type):
cls = type(obj)
return consts.LABEL_REMOTE_REF, (id(obj), cls.__name__, cls.__module__)
def _unbox(self, package):
elif version > 0:
if copy_mutable and type(obj) is set:
return consts.LABEL_V1_SET, tuple(
self._box(item, version, copy_mutable) for item in obj
)
elif (copy_mutable and type(obj) is list) or (
is_py3k and type(obj) is list_keys_type):
return consts.LABEL_V1_LIST, tuple(
self._box(item, version, copy_mutable) for item in obj
)
elif copy_mutable and type(obj) is dict:
return consts.LABEL_V1_DICT, tuple(
(
self._box(key, version, copy_mutable),
self._box(value, version, copy_mutable)
) for key, value in obj.items()
)
elif isinstance(obj, tuple) and \
type(obj) in brine.REGISTERED_NAMED_TUPLES_PACK:
return consts.LABEL_V1_NAMED_TUPLE, (
brine.REGISTERED_NAMED_TUPLES_PACK[type(obj)],
tuple(
self._box(item, version, copy_mutable) for item in obj
)
)
self._local_objects.add(obj)
try:
cls = obj.__class__
except Exception:
# see issue #16
cls = type(obj)
if not isinstance(cls, type):
cls = type(obj)
return consts.LABEL_REMOTE_REF, (
id(obj), cls.__name__, cls.__module__
)
def _unbox(self, package, version=0):
"""recreate a local object representation of the remote object: if the
object is passed by value, just return it; if the object is passed by
reference, create a netref to it"""
label, value = package
if label == consts.LABEL_VALUE:
return value
if label == consts.LABEL_TUPLE:
return tuple(self._unbox(item) for item in value)
if label == consts.LABEL_LOCAL_REF:
elif label == consts.LABEL_TUPLE:
return tuple(self._unbox(item, version) for item in value)
elif label == consts.LABEL_LOCAL_REF:
return self._local_objects[value]
if label == consts.LABEL_REMOTE_REF:
elif label == consts.LABEL_REMOTE_REF:
oid, clsname, modname = value
if oid in self._proxy_cache:
proxy = self._proxy_cache[oid]
@ -338,6 +418,29 @@ class Connection(object):
proxy = self._netref_factory(oid, clsname, modname)
self._proxy_cache[oid] = proxy
return proxy
elif version > 0:
if label == consts.LABEL_V1_SET:
return set(
self._unbox(item, version) for item in value
)
elif label == consts.LABEL_V1_LIST:
return list(
self._unbox(item, version) for item in value
)
elif label == consts.LABEL_V1_NAMED_TUPLE:
tuple_id, values = value
if tuple_id not in brine.REGISTERED_NAMED_TUPLES_UNPACK:
raise ValueError('Unexpected tuple id %s' % (tuple_id,))
tuple_type = brine.REGISTERED_NAMED_TUPLES_UNPACK[tuple_id]
return tuple_type(*tuple(
self._unbox(item, version) for item in values
))
elif label == consts.LABEL_V1_DICT:
return {
self._unbox(key, version): self._unbox(value, version)
for (key, value) in value
}
raise ValueError("invalid label %r" % (label,))
@ -356,38 +459,43 @@ class Connection(object):
#
# dispatching
#
def _dispatch_request(self, seq, raw_args):
def _dispatch_request(self, seq, raw_args, version):
try:
handler, args = raw_args
args = self._unbox(args)
args = self._unbox(args, version)
res = self._HANDLERS[handler](self, *args)
except:
except Exception:
# need to catch old style exceptions too
t, v, tb = sys.exc_info()
self._last_traceback = tb
logger = self._config["logger"]
if logger and t is not StopIteration:
logger.debug("Exception caught", exc_info=True)
if t is SystemExit and self._config["propagate_SystemExit_locally"]:
if t is SystemExit and self._config[
"propagate_SystemExit_locally"]:
raise
if t is KeyboardInterrupt and self._config["propagate_KeyboardInterrupt_locally"]:
if t is KeyboardInterrupt and self._config[
"propagate_KeyboardInterrupt_locally"]:
raise
self._send_exception(seq, t, v, tb)
self._send_exception(seq, t, v, tb, version)
else:
self._send_reply(seq, res)
self._send_reply(seq, res, version)
def _dispatch_reply(self, seq, raw):
obj = self._unbox(raw)
def _dispatch_reply(self, seq, raw, version):
obj = self._unbox(raw, version)
if seq in self._async_callbacks:
self._async_callbacks.pop(seq)(False, obj)
else:
self._sync_replies[seq] = (False, obj)
def _dispatch_exception(self, seq, raw):
obj = vinegar.load(raw,
import_custom_exceptions = self._config["import_custom_exceptions"],
instantiate_custom_exceptions = self._config["instantiate_custom_exceptions"],
instantiate_oldstyle_exceptions = self._config["instantiate_oldstyle_exceptions"])
def _dispatch_exception(self, seq, raw, version):
obj = vinegar.load(
raw,
import_custom_exceptions=self._config["import_custom_exceptions"],
instantiate_custom_exceptions=self._config[
"instantiate_custom_exceptions"],
instantiate_oldstyle_exceptions=self._config[
"instantiate_oldstyle_exceptions"])
if seq in self._async_callbacks:
self._async_callbacks.pop(seq)(True, obj)
else:
@ -411,14 +519,14 @@ class Connection(object):
self._recvlock.release()
return data
def _dispatch(self, data):
msg, seq, args = brine.load(data)
def _dispatch(self, data, version=0):
msg, seq, args = brine.load(data, version)
if msg == consts.MSG_REQUEST:
self._dispatch_request(seq, args)
self._dispatch_request(seq, args, version)
elif msg == consts.MSG_REPLY:
self._dispatch_reply(seq, args)
self._dispatch_reply(seq, args, version)
elif msg == consts.MSG_EXCEPTION:
self._dispatch_exception(seq, args)
self._dispatch_exception(seq, args, version)
else:
raise ValueError("invalid message type: %r" % (msg,))
@ -427,7 +535,7 @@ class Connection(object):
if self._sync_lock.acquire(False):
try:
self._sync_event.clear()
data = self._recv(timeout, wait_for_lock = False)
data = self._recv(timeout, wait_for_lock=False)
if not data:
return False
self._dispatch(data)
@ -438,7 +546,7 @@ class Connection(object):
else:
self._sync_event.wait()
def poll(self, timeout = 0):
def poll(self, timeout=0):
"""Serves a single transaction, should one arrives in the given
interval. Note that handling a request/reply may trigger nested
requests, which are all part of a single transaction.
@ -446,7 +554,7 @@ class Connection(object):
:returns: ``True`` if a transaction was served, ``False`` otherwise"""
return self.sync_recv_and_dispatch(timeout, wait_for_lock=False)
def serve(self, timeout = 1):
def serve(self, timeout=1):
"""Serves a single request or reply that arrives within the given
time frame (default is 1 sec). Note that the dispatching of a request
might trigger multiple (nested) requests, thus this function may be
@ -501,7 +609,8 @@ class Connection(object):
def poll_all(self, timeout=0):
"""Serves all requests and replies that arrive within the given interval.
:returns: ``True`` if at least a single transaction was served, ``False`` otherwise
:returns: ``True`` if at least a single transaction was served,
``False`` otherwise
"""
at_least_once = False
t0 = time.time()
@ -540,12 +649,12 @@ class Connection(object):
else:
return obj
def _async_request(self, handler, args = (), callback = (lambda a, b: None)):
def _async_request(self, handler, args=(), callback=(lambda a, b: None)):
seq = self._get_seq_id()
self._async_callbacks[seq] = callback
try:
self._send_request(seq, handler, args)
except:
except Exception:
if seq in self._async_callbacks:
del self._async_callbacks[seq]
raise
@ -558,7 +667,8 @@ class Connection(object):
"""
timeout = kwargs.pop("timeout", None)
if kwargs:
raise TypeError("got unexpected keyword argument(s) %s" % (list(kwargs.keys()),))
raise TypeError("got unexpected keyword argument(s) %s" % (
list(kwargs.keys()),))
res = AsyncResult(weakref.proxy(self))
self._async_request(handler, args, res)
if timeout is not None:
@ -589,15 +699,18 @@ class Connection(object):
if self._config["allow_all_attrs"]:
return name
if self._config["allow_safe_attrs"] and name in self._config["safe_attrs"]:
if self._config["allow_safe_attrs"] and \
name in self._config["safe_attrs"]:
return name
if self._config["allow_public_attrs"] and not name.startswith("_"):
if self._config["allow_public_attrs"] and \
not name.startswith("_"):
return name
return False
def _access_attr(self, oid, name, args, overrider, param, default):
name = as_native_string(name)
obj = self._local_objects[oid]
@ -652,7 +765,6 @@ class Connection(object):
kwargs = {
as_native_string(key): value for (key, value) in kwargs
}
return self._local_objects[oid](*args, **kwargs)
def _handle_dir(self, oid):
@ -693,13 +805,18 @@ class Connection(object):
as_native_string(key): value for (key, value) in kwargs
}
return self._handle_getattr(
oid, as_native_string(name)
)(*args, **kwargs)
function = self._access_attr(
oid,
as_native_string(name), (),
"_rpyc_getattr", "allow_getattr", getattr
)
return function(*args, **kwargs)
def _handle_pickle(self, oid, proto):
if not self._config["allow_pickle"]:
raise ValueError("pickling is disabled")
return pickle.dumps(self._local_objects[oid], proto)
def _handle_buffiter(self, oid, count):
@ -727,7 +844,7 @@ class Connection(object):
return getslice(start, stop, *args)
# collect handlers
_HANDLERS = {}
_HANDLERS = [None] * consts.HANDLE_MAX
for name, obj in dict(locals()).items():
if name.startswith("_handle_"):

View File

@ -37,7 +37,7 @@ except NameError:
# python 2.4 compatible
BaseException = Exception
def dump(typ, val, tb, include_local_traceback):
def dump(typ, val, tb, include_local_traceback, version):
"""Dumps the given exceptions info, as returned by ``sys.exc_info()``
:param typ: the exception's type (class)
@ -75,7 +75,7 @@ def dump(typ, val, tb, include_local_traceback):
for name in dir(val):
if name == "args":
for a in val.args:
if brine.dumpable(a):
if brine.dumpable(a, version):
args.append(a)
else:
args.append(repr(a))
@ -87,7 +87,7 @@ def dump(typ, val, tb, include_local_traceback):
except AttributeError:
# skip this attr. see issue #108
continue
if not brine.dumpable(attrval):
if not brine.dumpable(attrval, version):
attrval = repr(attrval)
attrs.append((name, attrval))

View File

@ -147,19 +147,33 @@ def nowait(proxy):
order**. In particular, multiple subsequent async requests may be
executed in reverse order.
"""
async_call = getattr(proxy, '___async_call__')
if async_call is not None:
return async_call
pid = id(proxy)
if pid in _async_proxies_cache:
return _async_proxies_cache[pid]
if not hasattr(proxy, "____conn__") or not hasattr(proxy, "____oid__"):
raise TypeError("'proxy' must be a Netref: %r", (proxy,))
if not callable(proxy):
raise TypeError("'proxy' must be callable: %r (%s)" % (proxy, type(proxy)))
raise TypeError(
"'proxy' must be callable: %r (%s)" % (proxy, type(proxy))
)
caller = _Async(proxy)
_async_proxies_cache[id(caller)] = _async_proxies_cache[pid] = caller
return caller
nowait.__doc__ = _Async.__doc__
class timed(object):
"""Creates a timed asynchronous proxy. Invoking the timed proxy will
run in the background and will raise an :class:`network.lib.rpc.core.async.AsyncResultTimeout`

View File

@ -5,6 +5,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
__all__ = (
'append_PKCS7_padding',
'strip_PKCS7_padding',
@ -13,25 +14,48 @@ __all__ = (
)
import logging
import sys
if sys.version_info.major > 2:
xrange = range
def to_byte(x):
return bytes((x,))
def from_byte(x):
return x
else:
def to_byte(x):
return chr(x)
def from_byte(x):
return ord(x)
AES_BLOCK_SIZE = 16
def append_PKCS7_padding(data):
pad = AES_BLOCK_SIZE - (len(data) % AES_BLOCK_SIZE)
return data + chr(pad)*pad
return data + to_byte(pad)*pad
def strip_PKCS7_padding(data):
if len(data) % AES_BLOCK_SIZE != 0:
raise ValueError("data is not padded !")
raise ValueError('data is not padded')
padlen = ord(data[-1])
padlen = from_byte(data[-1])
if padlen > AES_BLOCK_SIZE or padlen < 1:
raise ValueError("PKCS#7 invalid padding byte")
if data[-padlen:]!=chr(padlen)*padlen:
raise ValueError("PKCS#7 padding is invalid")
raise ValueError('PKCS#7 invalid padding byte')
if data[-padlen:] != to_byte(padlen) * padlen:
raise ValueError('PKCS#7 padding is invalid')
return data[:-padlen]
try:
from Crypto.Cipher import AES
from Crypto.Util import Counter
@ -54,7 +78,10 @@ try:
return AES.new(aes_key, mode, IV=iv)
except ImportError as e:
logging.warning('pycrypto not available, using pure python libraries for AES (slower): %s', e)
logging.warning(
'pycrypto not available, using pure python '
'libraries for AES (slower): %s', e
)
AES_MODE_CTR = 0
AES_MODE_CFB = 1
@ -85,8 +112,9 @@ except ImportError as e:
iv = long(iv.encode('hex'), 16)
self.iv = Counter(initial_value=iv)
self.cipher = AESModeOfOperationCTR(self.aes_key, counter=self.iv)
self.cipher = AESModeOfOperationCTR(
self.aes_key, counter=self.iv
)
def encrypt(self, data):
""" data has to be padded """
@ -95,8 +123,9 @@ except ImportError as e:
return self.cipher.encrypt(data)
encrypted = []
for i in range(0,len(data), AES_BLOCK_SIZE):
encrypted.append(self.cipher.encrypt(data[i:i+AES_BLOCK_SIZE]))
for i in xrange(0, len(data), AES_BLOCK_SIZE):
encrypted.append(
self.cipher.encrypt(data[i:i+AES_BLOCK_SIZE]))
return b''.join(encrypted)
@ -108,7 +137,8 @@ except ImportError as e:
cleartext = []
for i in range(0,len(data), AES_BLOCK_SIZE):
cleartext.append(self.cipher.decrypt(data[i:i+AES_BLOCK_SIZE]))
for i in xrange(0, len(data), AES_BLOCK_SIZE):
cleartext.append(
self.cipher.decrypt(data[i:i+AES_BLOCK_SIZE]))
return b''.join(cleartext)

View File

@ -16,17 +16,19 @@ import time
from io import open
from network.lib.convcompat import as_unicode_string_deep
from network.lib.convcompat import (
as_unicode_string, as_unicode_string_deep
)
families = {
v: k[3:] for k, v in socket.__dict__.items()
int(v): k[3:] for k, v in socket.__dict__.items()
if k.startswith('AF_')
}
try:
families.update({
psutil.AF_LINK: 'LINK'
int(psutil.AF_LINK): 'LINK'
})
except:
pass
@ -36,7 +38,7 @@ families.update(
)
socktypes = {
v: k[5:] for k, v in socket.__dict__.items()
int(v): k[5:] for k, v in socket.__dict__.items()
if k.startswith('SOCK_')
}
@ -138,22 +140,35 @@ def set_relations(infos):
infos['username'] = username
def _psutil_simplify(obj):
if hasattr(obj, 'value'):
return obj.value
else:
return obj
def _psiter(obj):
if hasattr(obj, '_fields'):
for field in obj._fields:
yield field, getattr(obj, field)
yield as_unicode_string(field) if isinstance(
field, bytes
) else field, _psutil_simplify(getattr(obj, field))
elif hasattr(obj, '__dict__'):
for k, v in iteritems(obj.__dict__):
yield k, v
yield as_unicode_string(k) if isinstance(
k, bytes
) else k, _psutil_simplify(v)
elif isinstance(obj, dict):
for k, v in iteritems(obj):
yield k, v
yield as_unicode_string(k) if isinstance(
k, bytes
) else k, _psutil_simplify(v)
else:
for v in obj:
yield v
yield _psutil_simplify(v)
def _is_iterable(obj):
@ -350,11 +365,10 @@ def connections():
for connection in net_connections:
obj = {
k: getattr(connection, k)
for k in (
'family', 'type', 'laddr', 'raddr', 'status'
)
as_unicode_string(k): tuple(v) if hasattr(v, '_fields') else v
for k, v in _psiter(connection)
}
try:
if connection.pid:
obj.update(
@ -389,7 +403,7 @@ def interfaces():
} for z in y
] for x, y in _psiter(if_addrs)
}
except:
except Exception:
addrs = None
try:
@ -400,7 +414,7 @@ def interfaces():
k: _tryint(v) for k, v in _psiter(y)
} for x, y in _psiter(if_stats)
}
except:
except Exception:
stats = None
return psutil_str({

View File

@ -256,7 +256,7 @@ class SafePopen(object):
close_cb()
return
queue = Queue.Queue()
queue = Queue()
self._reader = threading.Thread(
target=read_pipe,
args=(queue, self._pipe, self._bufsize)

View File

@ -2,47 +2,78 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
__all__ = ["PythonCompleter"]
class PythonCompleter:
__all__ = ('PythonCompleter',)
import re
import inspect
try:
import keyword
except ImportError:
keyword = None
from network.lib.convcompat import as_native_string
class PythonCompleter(object):
__slots__ = (
'local_ns', 'global_ns', 'matches'
)
def __init__(self, local_ns=None, global_ns=None):
if local_ns is not None:
self.local_ns=local_ns
self.local_ns = local_ns
else:
self.local_ns={}
self.local_ns = {}
if global_ns is not None:
self.global_ns=global_ns
self.global_ns = global_ns
else:
self.global_ns={}
self.global_ns = globals()
self.matches = ()
def complete(self, text, state):
text = as_native_string(text)
if state == 0:
if "." in text:
self.matches = self.attr_matches(text)
if '.' in text:
self.matches = tuple(self.attr_matches(text))
else:
self.matches = self.var_matches(text)
self.matches = tuple(self.var_matches(text))
try:
return self.matches[state]
except IndexError:
return None
def _callable_postfix(self, val, word):
if hasattr(val, '__call__'):
word = word + "("
return word
def var_matches(self, text):
import re
m = re.match(r"(\w*)", text)
m = re.match(r'\s*(\w+)', text)
if not m:
return []
words=[x for x in self.local_ns if x.startswith(m.group(1))]
if "__builtins__" in words:
words.remove("__builtins__")
text = m.group(1)
words = [
x for x in self.local_ns if x.startswith(text)
]
if keyword is not None:
words.extend(
x for x in keyword.kwlist if x.startswith(text)
)
if '__builtins__' in words:
words.remove('__builtins__')
return words
def attr_matches(self, text):
"""Compute matches when text contains a dot.
'''
Compute matches when text contains a dot.
Assuming the text is of the form NAME.NAME....[NAME], and is
evaluatable in self.namespace, it will be evaluated and its attributes
@ -51,57 +82,69 @@ class PythonCompleter:
WARNING: this can still invoke arbitrary C code, if an object
with a __getattr__ hook is evaluated.
"""
import re
bsw="[a-zA-Z0-9_\\(\\)\\[\\]\"']"
m = re.match(r"(\w+(\.\w+)*)\.(\w*)".replace(r"\w",bsw), text)
'''
bsw = "[a-zA-Z0-9_\\(\\)\\[\\]\"']"
m = re.match(r'(\w+(\.\w+)*)\.(\w*)'.replace(r'\w', bsw), text)
if not m:
return []
expr, attr = m.group(1, 3)
try:
try:
thisobject = eval(expr, self.global_ns, self.local_ns)
except NameError:
"""
print str(e)
chain = expr.split('.')
thisobject = None
while chain:
thisobject_name = chain.pop(0)
if thisobject_name is None:
break
if thisobject is None:
thisobject = self.local_ns.get(thisobject_name)
if thisobject is None:
return []
else:
try:
exec "import %s"%expr in global_ns, self.local_ns
thisobject = eval(expr, global_ns, self.local_ns)
except ImportError:
pass
"""
except:
thisobject = object.__getattribute__(
thisobject, thisobject_name
)
except AttributeError:
return []
if thisobject is None:
return []
# get the content of the object, except __builtins__
words = dir(thisobject)
if "__builtins__" in words:
words.remove("__builtins__")
words = [
name for name, value in inspect.getmembers(thisobject)
]
if hasattr(thisobject, '__class__'):
words.append('__class__')
words.extend(get_class_members(thisobject.__class__))
words=[x for x in words if not x.startswith("__")]
matches = []
n = len(attr)
for word in words:
if word[:n] == attr and hasattr(thisobject, word):
val = getattr(thisobject, word)
word = self._callable_postfix(val, "%s.%s" % (expr, word))
matches.append(word)
if attr and word[:n] != attr:
continue
value = object.__getattribute__(thisobject, word)
try:
object.__getattribute__(value, '__call__')
word += '('
except AttributeError:
pass
matches.append(expr + '.' + word)
return matches
def get_class_members(klass):
ret = dir(klass)
if hasattr(klass,'__bases__'):
for base in klass.__bases__:
ret = ret + get_class_members(base)
return ret
if __name__=="__main__":
import code
import readline
readline.set_completer(PythonCompleter().complete)
readline.parse_and_bind('tab: complete')
code.interact()
if hasattr(klass, '__bases__'):
for base in klass.__bases__:
ret.extend(get_class_members(base))
return ret

View File

@ -1,22 +1,36 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2015, Nicolas VERDIER (contact@n1nj4.eu)
# Pupy is under the BSD 3-Clause license. see the LICENSE file at the root of the project for the detailed licence terms
# Pupy is under the BSD 3-Clause license. see the LICENSE file
# at the root of the project for the detailed licence terms
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import code
from . import PythonCompleter
def new_exit(*args, **kwargs):
print("use ctrl+D to exit the interactive python interpreter.")
class PyShellController(object):
__slots__ = (
'_local_ns', 'console', 'completer'
)
def __init__(self):
local_ns={'exit':new_exit}
self.console=code.InteractiveConsole(local_ns)
self.completer=PythonCompleter.PythonCompleter(global_ns=globals(), local_ns=local_ns).complete
self._local_ns = {
'exit': new_exit
}
self.console = code.InteractiveConsole(self._local_ns)
self.completer = PythonCompleter.PythonCompleter(
global_ns=globals(), local_ns=self._local_ns
).complete
def write(self, line):
self.console.push(line)

View File

@ -771,6 +771,7 @@ def transfer_closure(callback, exclude=None, include=None, follow_symlinks=False
return _closure, _stop, t.terminate
if __name__ == '__main__':
import StringIO

View File

@ -71,13 +71,9 @@ for module in ('nt', 'posix'):
if module in sys.builtin_module_names:
os_ = __import__(module)
if sys.version_info.major < 3:
if sys.version_info.major > 2:
xrange = range
__all__ = tuple(
export.encode('ascii') for export in __all__
)
def _stub(*args, **kwargs):
raise NotImplementedError()

View File

@ -19,6 +19,8 @@ sys.path.extend((
root, os.path.join(root, 'library_patches')
))
sys.tracebacklimit = 50
import pupylib
assert(pupylib)

View File

@ -78,6 +78,14 @@ REVERSE_SLAVE_CONF = dict(
)
class Namespace(dict):
pass
class Cleanups(list):
pass
logger = pupy.get_logger('service')
_all = as_native_string('*')
@ -145,36 +153,40 @@ class ReverseSlaveService(Service):
)
def __init__(self, conn):
self.exposed_namespace = {}
self.exposed_cleanups = []
self.exposed_namespace = Namespace()
self.exposed_cleanups = Cleanups()
super(ReverseSlaveService, self).__init__(conn)
def on_connect(self):
self.exposed_namespace = {}
self.exposed_cleanups = []
self.exposed_namespace = Namespace()
self.exposed_cleanups = Cleanups()
self._conn._config.update(REVERSE_SLAVE_CONF)
infos_buffer = Buffer()
infos = self.exposed_get_infos()
pupyimporter = __import__('pupyimporter')
try:
umsgpack.dump(infos, infos_buffer, ext_handlers=MSG_TYPES_PACK)
except Exception as e:
pupy.remote_error('on_connect failed: {}; infos={}', e, infos)
self._conn.root.initialize_v1(
self._conn.root.initialize_v2(
1, (
sys.version_info.major,
sys.version_info.minor
),
self.exposed_namespace,
pupy.namespace,
__import__(builtin),
self.exposed_register_cleanup,
self.exposed_unregister_cleanup,
self.exposed_obtain_call,
self.exposed_exit,
self.exposed_eval,
self.exposed_execute,
__import__('pupyimporter'),
infos_buffer
self.exposed_get_infos(),
tuple(sys.modules),
tuple(pupyimporter.modules),
pupyimporter,
{
function: getattr(pupyimporter, function)
for function in dir(pupyimporter)
if hasattr(getattr(pupyimporter, function), '__call__')
}
)
def on_disconnect(self):

View File

@ -107,6 +107,9 @@ class PupyClient(object):
if self.conn.protocol_version is None:
# Legacy client
self._legacy_init()
else:
# Extended init
self._versioned_init(self.conn.protocol_version)
# To reuse impersonated handle in other modules
self.impersonated_dupHandle = None
@ -205,6 +208,7 @@ class PupyClient(object):
def remote(self, module, function=None, need_obtain=True):
remote_module = None
remote_function = None
need_obtain = need_obtain and self.obtain_call is not False
with self.remotes_lock:
if module in self.remotes:
@ -267,8 +271,47 @@ class PupyClient(object):
return remote_variable
def _versioned_init(self, version):
self.pupyimporter = self.conn.pupyimporter
register_package_request_hook = nowait(
self.conn.pupyimporter_funcs['register_package_request_hook']
)
register_package_error_hook = nowait(
self.conn.pupyimporter_funcs['register_package_error_hook']
)
self.conn.register_remote_cleanup(
self.conn.pupyimporter_funcs['unregister_package_request_hook']
)
register_package_request_hook(self.remote_load_package)
self.conn.register_remote_cleanup(
self.conn.pupyimporter_funcs['unregister_package_error_hook']
)
register_package_error_hook(self.remote_print_error)
self.pupy_load_dll = self.conn.pupyimporter_funcs['load_dll']
self.remote_add_package = nowait(
self.conn.pupyimporter_funcs['pupy_add_package']
)
self.remote_invalidate_package = nowait(
self.conn.pupyimporter_funcs['invalidate_module']
)
self.new_dlls = self.conn.pupyimporter_funcs['new_dlls']
self.new_modules = self.conn.pupyimporter_funcs['new_modules']
self.obtain_call = lambda func, *args, **kwargs: func(*args, **kwargs)
self.imported_modules = self.conn.remote_loaded_modules
self.cached_modules = self.conn.remote_cached_modules
def _legacy_init(self):
""" load pupyimporter in case it is not """
""" load pupyimporter in case it is not extended version """
if not self.conn.pupyimporter:
try:

View File

@ -17,6 +17,7 @@ import pupy
import encodings
def _as_unicode(x):
if isinstance(x, bytes):
try:
@ -29,22 +30,32 @@ def _as_unicode(x):
return x
# Restore write/stdout/stderr
if not hasattr(os, 'real_write'):
if type(os.write).__name__ == 'builtin_function_or_method':
os.real_write = os.write
setattr(os, 'real_write', os.write)
allowed_std = ('file', 'Blackhole', 'NoneType')
if not hasattr(sys, 'real_stdout') and type(sys.stdout).__name__ in allowed_std:
sys.real_stdout = sys.stdout
if not hasattr(sys, 'real_stderr') and type(sys.stderr).__name__ in allowed_std:
sys.real_stderr = sys.stderr
if not hasattr(sys, 'real_stdout') and type(
sys.stdout).__name__ in allowed_std:
setattr(sys, 'real_stdout', sys.stdout)
if not hasattr(sys, 'real_stderr') and type(
sys.stderr).__name__ in allowed_std:
setattr(sys, 'real_stderr', sys.stderr)
if not hasattr(sys, 'real_stdin') and type(
sys.stdin).__name__ in allowed_std:
setattr(sys, 'real_stdin', sys.stdin)
if not hasattr(sys, 'real_stdin') and type(sys.stdin).__name__ in allowed_std:
sys.real_stdin = sys.stdin
if not hasattr(os, 'stdout_write'):
def stdout_write(fd, s):
@ -55,42 +66,46 @@ if not hasattr(os, 'stdout_write'):
else:
return os.real_write(fd, s)
os.stdout_write = stdout_write
setattr(os, 'stdout_write', stdout_write)
# Remove IDNA module if it was not properly loaded
if hasattr(encodings, 'idna') and not hasattr(encodings.idna, 'getregentry'):
if 'encodings.idna' in sys.modules:
del sys.modules['encodings.idna']
if 'idna' in encodings._cache:
del encodings._cache['idna']
os_encoding = locale.getpreferredencoding() or "utf8"
if sys.platform == 'win32':
from _winreg import (
ConnectRegistry, HKEY_LOCAL_MACHINE, OpenKey, EnumValue
)
import ctypes
def redirect_stdo(stdout, stderr):
if not hasattr(sys, 'real_stdout'):
setattr(sys, 'real_stdout', sys.stdout)
if not hasattr(sys, 'real_stderr'):
setattr(sys, 'real_stderr', sys.stdout)
sys.stdout = stdout
sys.stderr = stderr
os.write = os.stdout_write
def redirect_stdio(stdin, stdout, stderr):
if not hasattr(sys, 'real_stdin'):
setattr(sys, 'real_stdin', sys.stdin)
sys.stdin = stdin
redirect_stdo(stdout, stderr)
def reset_stdo():
sys.stdout = sys.real_stdout
sys.stderr = sys.real_stderr
if hasattr(sys, 'real_stdout'):
sys.stdout = sys.real_stdout
if hasattr(sys, 'real_stderr'):
sys.stderr = sys.real_stderr
os.write = os.real_write
def reset_stdio():
sys.stdin = sys.real_stdin
if hasattr(sys, 'real_stdin'):
sys.stdout = sys.real_stdin
reset_stdo()
def get_integrity_level():
'''from http://www.programcreek.com/python/example/3211/ctypes.c_long'''
@ -100,6 +115,8 @@ def get_integrity_level():
else:
return "High"
import ctypes
mapping = {
0x0000: u'Untrusted',
0x1000: u'Low',
@ -129,42 +146,61 @@ def get_integrity_level():
TokenIntegrityLevel = ctypes.c_int(25)
ERROR_INSUFFICIENT_BUFFER = 122
ctypes.windll.kernel32.GetLastError.argtypes = ()
ctypes.windll.kernel32.GetLastError.restype = DWORD
ctypes.windll.kernel32.GetCurrentProcess.argtypes = ()
ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p
ctypes.windll.advapi32.OpenProcessToken.argtypes = (
kernel32 = ctypes.windll.WinDLL('kernel32')
advapi32 = ctypes.windll.WinDLL('advapi32')
GetLastError = kernel32.GetLastError
GetLastError.argtypes = ()
GetLastError.restype = DWORD
CloseHandle = kernel32.CloseHandle
CloseHandle.argtypes = (HANDLE,)
GetCurrentProcess = kernel32.GetCurrentProcess
GetCurrentProcess.argtypes = ()
GetCurrentProcess.restype = ctypes.c_void_p
OpenProcessToken = advapi32.OpenProcessToken
OpenProcessToken.argtypes = (
HANDLE, DWORD, ctypes.POINTER(HANDLE))
ctypes.windll.advapi32.OpenProcessToken.restype = BOOL
ctypes.windll.advapi32.GetTokenInformation.argtypes = (
HANDLE, ctypes.c_long, ctypes.c_void_p, DWORD, ctypes.POINTER(DWORD))
ctypes.windll.advapi32.GetTokenInformation.restype = BOOL
ctypes.windll.advapi32.GetSidSubAuthorityCount.argtypes = [ctypes.c_void_p]
ctypes.windll.advapi32.GetSidSubAuthorityCount.restype = ctypes.POINTER(
ctypes.c_ubyte)
ctypes.windll.advapi32.GetSidSubAuthority.argtypes = (ctypes.c_void_p, DWORD)
ctypes.windll.advapi32.GetSidSubAuthority.restype = ctypes.POINTER(DWORD)
OpenProcessToken.restype = BOOL
GetTokenInformation = advapi32.GetTokenInformation
GetTokenInformation.argtypes = (
HANDLE, ctypes.c_long, ctypes.c_void_p,
DWORD, ctypes.POINTER(DWORD)
)
GetTokenInformation.restype = BOOL
GetSidSubAuthorityCount = advapi32.GetSidSubAuthorityCount
GetSidSubAuthorityCount.argtypes = (
ctypes.c_void_p,
)
GetSidSubAuthorityCount.restype = ctypes.POINTER(ctypes.c_ubyte)
GetSidSubAuthority = advapi32.GetSidSubAuthority
GetSidSubAuthority.argtypes = (
ctypes.c_void_p, DWORD
)
GetSidSubAuthority.restype = ctypes.POINTER(DWORD)
token = ctypes.c_void_p()
proc_handle = ctypes.windll.kernel32.GetCurrentProcess()
if not ctypes.windll.advapi32.OpenProcessToken(
proc_handle,
TOKEN_READ,
ctypes.byref(token)):
proc_handle = GetCurrentProcess()
if not OpenProcessToken(
proc_handle, TOKEN_READ, ctypes.byref(token)):
logging.error('Failed to get process token')
return None
if token.value == 0:
logging.error('Got a NULL token')
return None
try:
info_size = DWORD()
if ctypes.windll.advapi32.GetTokenInformation(
token,
TokenIntegrityLevel,
ctypes.c_void_p(),
info_size,
ctypes.byref(info_size)):
if GetTokenInformation(
token, TokenIntegrityLevel, ctypes.c_void_p(),
info_size, ctypes.byref(info_size)):
logging.error('GetTokenInformation() failed expectation')
return None
@ -172,71 +208,104 @@ def get_integrity_level():
logging.error('GetTokenInformation() returned size 0')
return None
if ctypes.windll.kernel32.GetLastError() != ERROR_INSUFFICIENT_BUFFER:
dwLastError = GetLastError()
if dwLastError != ERROR_INSUFFICIENT_BUFFER:
logging.error(
'GetTokenInformation(): Unknown error: %d',
ctypes.windll.kernel32.GetLastError())
'GetTokenInformation(): Unknown error: %d',
dwLastError
)
return None
token_info = TOKEN_MANDATORY_LABEL()
ctypes.resize(token_info, info_size.value)
if not ctypes.windll.advapi32.GetTokenInformation(
token,
TokenIntegrityLevel,
ctypes.byref(token_info),
info_size,
ctypes.byref(info_size)):
if not GetTokenInformation(
token, TokenIntegrityLevel, ctypes.byref(token_info),
info_size, ctypes.byref(info_size)):
logging.error(
'GetTokenInformation(): Unknown error with buffer size %d: %d',
info_size.value,
ctypes.windll.kernel32.GetLastError())
'GetTokenInformation(): Unknown error with buffer size %d: %d',
info_size.value, GetLastError()
)
return None
p_sid_size = ctypes.windll.advapi32.GetSidSubAuthorityCount(
token_info.Label.Sid)
res = ctypes.windll.advapi32.GetSidSubAuthority(
token_info.Label.Sid, p_sid_size.contents.value - 1)
p_sid_size = GetSidSubAuthorityCount(token_info.Label.Sid)
res = GetSidSubAuthority(
token_info.Label.Sid, p_sid_size.contents.value - 1
)
value = res.contents.value
return mapping.get(value) or u'0x%04x' % value
finally:
ctypes.windll.kernel32.CloseHandle(token)
CloseHandle(token)
def getUACLevel():
if sys.platform != 'win32':
return 'N/A'
i, consentPromptBehaviorAdmin, enableLUA, promptOnSecureDesktop = 0, None, None, None
from _winreg import (
ConnectRegistry, HKEY_LOCAL_MACHINE, OpenKey,
EnumValue, CloseKey
)
consentPromptBehaviorAdmin = None
enableLUA = None
promptOnSecureDesktop = None
try:
Registry = ConnectRegistry(None, HKEY_LOCAL_MACHINE)
RawKey = OpenKey(Registry, r'SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System')
except:
return "?"
while True:
try:
name, value, type = EnumValue(RawKey, i)
if name == "ConsentPromptBehaviorAdmin":
consentPromptBehaviorAdmin = value
elif name == "EnableLUA":
enableLUA = value
elif name == "PromptOnSecureDesktop":
promptOnSecureDesktop = value
i+=1
except WindowsError:
break
RawKey = OpenKey(
Registry,
'SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Policies\\System'
)
if consentPromptBehaviorAdmin == 2 and enableLUA == 1 and promptOnSecureDesktop == 1:
i = 0
while True:
try:
name, value, type = EnumValue(RawKey, i)
if name == "ConsentPromptBehaviorAdmin":
consentPromptBehaviorAdmin = value
elif name == "EnableLUA":
enableLUA = value
elif name == "PromptOnSecureDesktop":
promptOnSecureDesktop = value
i += 1
except WindowsError:
break
except Exception:
return "?"
finally:
CloseKey(RawKey)
if consentPromptBehaviorAdmin == 2 and enableLUA == 1 and \
promptOnSecureDesktop == 1:
return "3/3"
elif consentPromptBehaviorAdmin == 5 and enableLUA == 1 and promptOnSecureDesktop == 1:
elif consentPromptBehaviorAdmin == 5 and enableLUA == 1 and \
promptOnSecureDesktop == 1:
return "2/3"
elif consentPromptBehaviorAdmin == 5 and enableLUA == 1 and promptOnSecureDesktop == 0:
elif consentPromptBehaviorAdmin == 5 and enableLUA == 1 and \
promptOnSecureDesktop == 0:
return "1/3"
elif enableLUA == 0:
return "0/3"
else:
return "?"
def GetUserName():
from ctypes import windll, WinError, create_unicode_buffer, byref, c_uint32, GetLastError
from ctypes import (
windll, WinError, create_unicode_buffer,
byref, c_uint32, GetLastError
)
DWORD = c_uint32
nSize = DWORD(0)
@ -257,6 +326,7 @@ def GetUserName():
return lpBuffer.value
def get_uuid():
user = None
hostname = None

View File

@ -52,6 +52,7 @@ else:
from itertools import ifilterfalse as filterfalse
from netaddr import IPAddress
from netaddr.core import AddrFormatError
from random import randint
from tempfile import NamedTemporaryFile
from inspect import isclass
@ -69,7 +70,9 @@ from pupylib.PupyCompile import pupycompile
from pupylib.PupyOutput import Error, Line, Color
from pupylib.PupyModule import QA_STABLE, IgnoreModule, PupyModule
from pupylib.PupyDnsCnc import PupyDnsCnc
from pupylib.PupyTriggers import event, event_to_string, register_event_id, CUSTOM
from pupylib.PupyTriggers import (
event, event_to_string, register_event_id, CUSTOM
)
from pupylib.PupyTriggers import ON_CONNECT, ON_DISCONNECT, ON_START, ON_EXIT
from pupylib.PupyTriggers import RegistrationNotAllowed, UnregisteredEventId
from pupylib.PupyWeb import PupyWebServer
@ -86,7 +89,9 @@ from network.conf import transports
from network.transports.ssl.conf import PupySSLAuthenticator
from network.lib.connection import PupyConnectionThread
from network.lib.servers import PupyTCPServer
from network.lib.streams.PupySocketStream import PupySocketStream, PupyUDPSocketStream
from network.lib.streams.PupySocketStream import (
PupySocketStream, PupyUDPSocketStream
)
from network.lib.streams.PupyVirtualStream import PupyVirtualStream
from network.lib.utils import parse_transports_args
@ -115,7 +120,10 @@ class PupyKCPSocketStream(PupySocketStream):
class Listener(Thread):
def __init__(self, pupsrv, name, args, httpd=False, igd=False, local=None, external=None, pproxy=None):
def __init__(
self, pupsrv, name, args, httpd=False, igd=False,
local=None, external=None, pproxy=None):
Thread.__init__(self)
self.daemon = True
self.name = 'Listener({})'.format(name)
@ -126,7 +134,7 @@ class Listener(Thread):
self.name = name.lower().strip()
self.transport = transports[self.name]()
self.authenticator = self.transport.authenticator() if \
self.transport.authenticator else None
self.transport.authenticator else None
self.pupsrv = pupsrv
self.config = pupsrv.config
@ -183,7 +191,9 @@ class Listener(Thread):
self.ipv6 = (address.version == 6) or default_ipv6
except Exception as e:
raise ListenerException('Invalid IP: {} ({})'.format(ip, e))
raise ListenerException(
'Invalid IP: {} ({})'.format(ip, e)
)
else:
port = args[0]
@ -213,25 +223,31 @@ class Listener(Thread):
try:
self.external = str(IPAddress(extip))
except:
except AddrFormatError:
self.external = '127.0.0.1'
if '=' in port:
port = [x.strip() for x in port.split('=', 1)]
try:
self.external_port = int(port[0])
except:
raise ListenerException("Invalid external port: {}".format(port[0]))
except ValueError:
raise ListenerException(
"Invalid external port: {}".format(port[0])
)
try:
self.port = int(port[1])
except:
raise ListenerException("Invalid local port: {}".format(port[1]))
except ValueError:
raise ListenerException(
"Invalid local port: {}".format(port[1])
)
else:
try:
self.port = int(port)
except:
raise ListenerException("Invalid local port: {}".format(port[1]))
except ValueError:
raise ListenerException(
"Invalid local port: {}".format(port[1])
)
self.external_port = self.port
@ -335,7 +351,7 @@ class Listener(Thread):
logger.error(
"Couldn't delete IGD Mapping: {}".format(e.description)
)
except:
except Exception:
pass
if self.server:
@ -377,14 +393,14 @@ class Listener(Thread):
result += ' ' + ' '.join(
'{}={}'.format(
k, v if k != 'password' else '*'*len(v)
) for k,v in self.kwargs.items())
) for k, v in self.kwargs.items())
return '{}: {}'.format(self.name, result)
class PupyServer(object):
SUFFIXES = tuple([
suffix for suffix, _, rtype in imp.get_suffixes() \
suffix for suffix, _, rtype in imp.get_suffixes()
if rtype == imp.PY_SOURCE
])
@ -434,7 +450,8 @@ class PupyServer(object):
pproxy_dnscnc = None
if pproxy and ca and key and cert and (pproxy_listener_required or pproxy_dnscnc_required):
if pproxy and ca and key and cert and (
pproxy_listener_required or pproxy_dnscnc_required):
try:
pproxy_manager = PupyOffloadManager(
pproxy, ca, key, cert, via)
@ -452,11 +469,15 @@ class PupyServer(object):
' via {}'.format(via) if via else ''))
except (socket.error, OffloadProxyCommonError) as e:
self.motd['fail'].append('Offload proxy unavailable: {}'.format(e))
self.motd['fail'].append(
'Offload proxy unavailable: {}'.format(e)
)
except Exception as e:
logger.exception(e)
self.motd['fail'].append('Using Pupy Offload Proxy: Failed: {}'.format(e))
self.motd['fail'].append(
'Using Pupy Offload Proxy: Failed: {}'.format(e)
)
if self.config.getboolean('pupyd', 'httpd'):
self.httpd = True
@ -485,7 +506,8 @@ class PupyServer(object):
self.listeners = {}
dnscnc = self.config.get('pupyd', 'dnscnc')
if dnscnc and not dnscnc.lower() in ('no', 'false', 'stop', 'n', 'disable'):
if dnscnc and dnscnc.lower() not in (
'no', 'false', 'stop', 'n', 'disable'):
try:
self.dnscnc = PupyDnsCnc(
igd=self.igd,
@ -548,17 +570,29 @@ class PupyServer(object):
with self.clients_lock:
if isinstance(dst_id, int):
dst_client = [x for x in self.clients if x.desc['id'] == dst_id]
dst_client = [
x for x in self.clients if x.desc['id'] == dst_id
]
if not dst_client:
raise ValueError('Client with id {} not found'.format(dst_id))
raise ValueError(
'Client with id {} not found'.format(dst_id)
)
dst_client = dst_client[0]
else:
dst_client = dst_id
if isinstance(src_id, int):
src_client = [x for x in self.clients if x.desc['id'] == src_id]
src_client = [
x for x in self.clients if x.desc['id'] == src_id
]
if not src_client:
raise ValueError('Client with id {} not found'.format(src_id))
raise ValueError(
'Client with id {} not found'.format(src_id)
)
src_client = src_client[0]
else:
src_client = src_id
@ -577,7 +611,11 @@ class PupyServer(object):
logger.debug('Id not found in current_id list: %s', id)
def register_handler(self, instance):
""" register the handler instance, typically a PupyCmd, and PupyWeb in the futur"""
"""
register the handler instance, typically a PupyCmd,
and PupyWeb in the future
"""
self.handler = instance
if self.dnscnc:
@ -607,7 +645,9 @@ class PupyServer(object):
return True
return False
return nodeid in set([x.strip().lower() for x in allowed_nodes.split(',')])
return nodeid in set([
x.strip().lower() for x in allowed_nodes.split(',')
])
def add_client(self, conn):
client = None
@ -661,8 +701,8 @@ class PupyServer(object):
try:
if type(conn_id) is list:
address = conn_id[0]
address = conn_id.rsplit(':',1)[0]
except:
address = conn_id.rsplit(':', 1)[0]
except Exception:
address = str(address)
client_info.update({
@ -684,7 +724,7 @@ class PupyServer(object):
if self.handler:
try:
client_ip, client_port = conn_id.rsplit(':', 1)
except:
except Exception:
client_ip, client_port = '0.0.0.0', 0
if ':' in client_ip:
@ -724,8 +764,12 @@ class PupyServer(object):
self.info('Session {} closed'.format(client.desc['id']))
def get_clients(self, search_criteria):
""" return a list of clients corresponding to the search criteria. ex: platform:*win* """
#if the criteria is a simple id we return the good client
"""
return a list of clients corresponding to the search criteria.
ex: platform:*win*
"""
# if the criteria is a simple id we return the good client
if not search_criteria:
return self.clients
@ -746,33 +790,39 @@ class PupyServer(object):
clients = set([])
if search_criteria=="*":
if search_criteria == '*':
return self.clients
for c in self.clients:
take = False
tags = self.config.tags(c.node())
for sc in search_criteria.split():
tab = sc.split(":",1)
#if the field is specified we search for the value in this field
if len(tab)==2 and tab[0] in c.desc:
take=True
if not tab[1].lower() in str(c.desc[tab[0]]).lower():
take=False
break
elif len(tab)==2 and tab[0] == 'tag' and tab[1] in tags:
tab = sc.split(':', 1)
# if the field is specified we search for
# the value in this field
if len(tab) == 2 and tab[0] in c.desc:
take = True
elif len(tab)==2 and tab[0] == 'tags':
if not tab[1].lower() in str(c.desc[tab[0]]).lower():
take = False
break
elif len(tab) == 2 and tab[0] == 'tag' and tab[1] in tags:
take = True
elif len(tab) == 2 and tab[0] == 'tags':
if '&' in tab[1]:
take = all(x in tags for x in tab[1].split('&') if x)
else:
take = any(x in tags for x in tab[1].split(',') if x)
elif len(tab)!=2:#if there is no field specified we search in every field for at least one match
take=False
elif len(tab) != 2:
# if there is no field specified we search in every field
# for at least one match
take = False
if tab[0] in tags:
take = True
else:
for k,v in c.desc.items():
for k, v in c.desc.items():
if isinstance(v, basestring):
if tab[0].lower() in v.decode('utf8').lower():
take = True
@ -819,19 +869,29 @@ class PupyServer(object):
yield module
def get_module_name_from_category(self, modpath):
""" take a category virtual path and return the module's name or the path untouched if not found """
"""
take a category virtual path and return the module's
name or the path untouched if not found
"""
mod = self.categories.get_module_from_path(modpath)
if mod:
return mod.get_name()
else:
return modpath
def get_aliased_modules(self):
""" return a list of aliased module names that have to be displayed as commands """
"""
return a list of aliased module names that have to be
displayed as commands
"""
modules = []
for m in self.iter_modules():
if not m.is_module:
modules.append((m.get_name(), m.__doc__))
return modules
def _refresh_modules(self, force=False):
@ -870,7 +930,7 @@ class PupyServer(object):
current_stats = stat(modpath)
if not force and modname in self.modules and \
self._modules_stats[modname] == current_stats.st_mtime:
self._modules_stats[modname] == current_stats.st_mtime:
continue
try:
@ -889,12 +949,16 @@ class PupyServer(object):
Error('Invalid module:'),
Color(modname, 'yellow'),
'at ({}): {}. Traceback:\n{}'.format(
modpath, e, tb))
modpath, e, tb
)
)
self.info(error, error=True)
def get_module(self, name):
enable_dangerous_modules = self.config.getboolean('pupyd', 'enable_dangerous_modules')
enable_dangerous_modules = self.config.getboolean(
'pupyd', 'enable_dangerous_modules'
)
if name not in self.modules:
self._refresh_modules(force=True)
@ -923,22 +987,27 @@ class PupyServer(object):
'but it is already registered as "%s"',
name, event_name, registered_event_name)
raise PupyModuleDisabled('Modules with errors are disabled.')
raise PupyModuleDisabled(
'Modules with errors are disabled.'
)
except UnregisteredEventId:
try:
register_event_id(event_id, event_name)
except RegistrationNotAllowed:
logger.error(
'script "%s" registers event_id 0x%08x which is not allowed, '
'eventid should be >0x%08x',
name, event_id, CUSTOM)
'script "%s" registers event_id 0x%08x '
'which is not allowed, eventid should be >0x%08x',
name, event_id, CUSTOM
)
raise PupyModuleDisabled('Modules with errors are disabled.')
raise PupyModuleDisabled(
'Modules with errors are disabled.'
)
if not class_name:
#TODO automatically search the class name in the file
exit("Error : no __class_name__ for module %s"%module)
# TODO automatically search the class name in the file
exit('Error : no __class_name__ for module %s' % module)
module_class = getattr(module, class_name)
@ -949,9 +1018,13 @@ class PupyServer(object):
return module_class
def module_parse_args(self, module_name, args):
""" This method is used by the PupyCmd class to verify validity of arguments passed to a specific module """
module=self.get_module(module_name)
ps=module(None,None)
"""
This method is used by the PupyCmd class to verify validity
of arguments passed to a specific module
"""
module = self.get_module(module_name)
ps = module(None, None)
if ps.known_args:
return ps.arg_parser.parse_known_args(args)
else:
@ -959,22 +1032,23 @@ class PupyServer(object):
def del_job(self, job_id):
if job_id is not None:
job_id=int(job_id)
job_id = int(job_id)
if job_id in self.jobs:
del self.jobs[job_id]
def add_job(self, job):
job.id=self.jobs_id
self.jobs[self.jobs_id]=job
self.jobs_id+=1
job.id = self.jobs_id
self.jobs[self.jobs_id] = job
self.jobs_id += 1
def get_job(self, job_id):
try:
job_id=int(job_id)
job_id = int(job_id)
except ValueError:
raise PupyModuleError("job id must be an integer !")
if job_id not in self.jobs:
raise PupyModuleError("%s: no such job !"%job_id)
raise PupyModuleError("%s: no such job !" % job_id)
return self.jobs[job_id]
def create_virtual_connection(self, transport, peer):
@ -987,8 +1061,10 @@ class PupyServer(object):
transport_conf = transports.get(transport)
transport_class = transport_conf().server_transport
logger.debug('create_virtual_connection(%s, %s) - transport - %s / %s',
transport, peer, transport_conf, transport_class)
logger.debug(
'create_virtual_connection(%s, %s) - transport - %s / %s',
transport, peer, transport_conf, transport_class
)
stream = PupyVirtualStream(transport_class)
@ -1002,18 +1078,24 @@ class PupyServer(object):
})
def activate(peername, on_receive):
logger.debug('VirtualStream (%s, %s) - activating',
stream, peername)
logger.debug(
'VirtualStream (%s, %s) - activating',
stream, peername
)
stream.activate(peername, on_receive)
logger.debug('VirtualStream (%s, %s) - starting thread',
stream, peername)
logger.debug(
'VirtualStream (%s, %s) - starting thread',
stream, peername
)
vc.start()
logger.debug('VirstualStream (%s, %s) - activated',
stream, peername)
logger.debug(
'VirstualStream (%s, %s) - activated',
stream, peername
)
return activate, stream.submit, stream.close
@ -1070,11 +1152,14 @@ class PupyServer(object):
def add_listener(self, name, config=None, motd=False, ignore_pproxy=False):
if self.listeners and name in self.listeners:
self.handler.display_warning('Listener {} already registered'.format(name))
self.handler.display_warning(
'Listener {} already registered'.format(name)
)
return
if name not in transports:
error = 'Transport {} is not registered. To show available: listen -L'.format(repr(name))
error = 'Transport {} is not registered. ' \
'To show available: listen -L'.format(repr(name))
if motd:
self.motd['fail'].append(error)
@ -1085,8 +1170,8 @@ class PupyServer(object):
listener_config = config or self.config.get('listeners', name)
if not listener_config:
error = 'Transport {} does not have default settings. Specfiy args (at least port)'.format(
repr(name))
error = 'Transport {} does not have default settings. ' \
'Specfiy args (at least port)'.format(repr(name))
if motd:
self.motd['fail'].append(error)
@ -1131,12 +1216,15 @@ class PupyServer(object):
except socket.error as e:
if e.errno == errno.EACCES:
error = 'Listen: {}: Insufficient privileges to bind'.format(listener)
error = 'Listen: {}: ' \
'Insufficient privileges to bind'.format(listener)
elif e.errno == errno.EADDRINUSE:
error = 'Listen: {}: Address/Port already used'.format(listener)
error = 'Listen: {}: ' \
'Address/Port already used'.format(listener)
elif e.errno == errno.EADDRNOTAVAIL:
error = 'Listen: {}: No network interface with addresss {}'.format(
listener, listener.address)
error = 'Listen: {}: ' \
'No network interface with addresss {}'.format(
listener, listener.address)
else:
error = 'Listen: {}: {}'.format(listener, e)

View File

@ -47,6 +47,7 @@ from pupylib.PupyCredentials import Credentials
from network.lib.msgtypes import msgpack_exthook
from network.lib.rpc import Service, timed, nowait
from network.lib.convcompat import as_native_string
from . import getLogger
logger = getLogger('service')
@ -69,22 +70,22 @@ class PupyService(Service):
self.eval = None
self.execute = None
self.pupyimporter = None
self.pupyimporter_funcs = None
self.infos = None
self.get_infos = None
self.protocol_version = None
self.remote_version = (2, 7)
self.remote_arch = None
self.events_receiver = None
self.remote_loaded_modules = None
self.remote_cached_modules = None
def exposed_on_connect(self):
if sys.version_info.major == 3:
# Deprecated API
self._conn.activate_3to2()
# raise NotImplementedError(
# 'Too old RPC version - python3 to python2 is not supported'
# )
self._conn._config.update({
'allow_safe_attrs': False,
@ -114,6 +115,9 @@ class PupyService(Service):
infos, *args
):
if __debug__:
logger.debug('Initialize legacy V1 connection.')
if sys.version_info.major == 3:
# Deprecated API
self._conn.activate_3to2()
@ -136,6 +140,61 @@ class PupyService(Service):
self.pupy_srv.add_client(self)
def exposed_initialize_v2(
self,
protocol_version, remote_version,
namespace, modules, builtin,
register_cleanup, unregister_cleanup,
remote_exit, remote_eval, remote_execute,
infos, loaded_modules, cached_modules,
pupyimporter, pupyimporter_funcs, *args):
if __debug__:
logger.debug(
'Initialize V2 connection. Remote proto: %s Python: %s',
protocol_version, remote_version
)
self.protocol_version = protocol_version
self.remote_version = remote_version
if sys.version_info.major == 3 and \
self.remote_version[0] == 2:
if __debug__:
logger.debug(
'Enable python3 to python2 communication hacks'
)
self._conn.activate_3to2()
self.namespace = namespace
self.modules = modules
self.builtin = self.builtins = builtin
self.register_remote_cleanup = nowait(register_cleanup)
self.unregister_remote_cleanup = nowait(unregister_cleanup)
self.obtain_call = False
self.exit = timed(remote_exit, 1)
self.eval = remote_eval
self.execute = remote_execute
self.pupyimporter = pupyimporter
self.pupyimporter_funcs = {
as_native_string(func): ref
for func, ref in pupyimporter_funcs.items()
}
self.infos = infos
self.get_infos = lambda: self.infos
self.remote_loaded_modules = set(
as_native_string(module) for module in loaded_modules
)
self.remote_cached_modules = set(
as_native_string(module) for module in cached_modules
)
self.pupy_srv.add_client(self)
def register_local_cleanup(self, cleanup):
self._local_cleanups.append(cleanup)
@ -160,6 +219,9 @@ class PupyService(Service):
# Compatibility call
def exposed_set_modules(self, modules):
if __debug__:
logger.debug('Initialize legacy V0 connection.')
try:
self.modules = modules
self.builtin = modules.__builtin__
@ -174,19 +236,19 @@ class PupyService(Service):
try:
self.register_remote_cleanup = \
self._conn.root.register_cleanup
except:
except Exception:
self.register_remote_cleanup = None
if self.register_remote_cleanup:
try:
self.unregister_remote_cleanup = \
self._conn.root.unregister_cleanup
except:
except Exception:
self.unregister_remote_cleanup = None
try:
self.obtain_call = self._conn.root.obtain_call
except:
except Exception:
pass
self.exit = self._conn.root.exit
@ -199,7 +261,7 @@ class PupyService(Service):
logger.error(traceback.format_exc())
try:
self._conn.close()
except:
except Exception:
pass
def exposed_msgpack_dumps(self, js, compressed=False):

View File

@ -38,11 +38,16 @@ def safe_obtain(proxy):
ptype = type(proxy)
if type(proxy) in (tuple, list, set):
objs = list(safe_obtain(x) for x in proxy)
return ptype(objs)
return ptype([
safe_obtain(x) for x in proxy
])
return proxy
if conn.is_extended():
# No need to call obtain
return proxy
if not hasattr(conn, 'obtain'):
try:
setattr(conn, 'obtain', conn.root.msgpack_dumps)