diff --git a/src/pyodide-py/pyodide/__init__.py b/src/pyodide-py/pyodide/__init__.py index 141627da3..2e3fb6dcf 100644 --- a/src/pyodide-py/pyodide/__init__.py +++ b/src/pyodide-py/pyodide/__init__.py @@ -1,15 +1,15 @@ from ._base import open_url, eval_code, eval_code_async, find_imports from ._core import JsException, create_once_callable, create_proxy # type: ignore -from ._importhooks import JsFinder +from ._importhooks import jsfinder from .webloop import WebLoopPolicy +from . import _state # type: ignore # noqa import asyncio import sys import platform -jsfinder = JsFinder() +sys.meta_path.append(jsfinder) # type: ignore register_js_module = jsfinder.register_js_module unregister_js_module = jsfinder.unregister_js_module -sys.meta_path.append(jsfinder) # type: ignore if platform.system() == "Emscripten": asyncio.set_event_loop_policy(WebLoopPolicy()) diff --git a/src/pyodide-py/pyodide/_importhooks.py b/src/pyodide-py/pyodide/_importhooks.py index 1d5273d50..c30b09600 100644 --- a/src/pyodide-py/pyodide/_importhooks.py +++ b/src/pyodide-py/pyodide/_importhooks.py @@ -96,3 +96,6 @@ class JsLoader(Loader): # used by importlib.util.spec_from_loader def is_package(self, fullname): return True + + +jsfinder = JsFinder() diff --git a/src/pyodide-py/pyodide/_state.py b/src/pyodide-py/pyodide/_state.py new file mode 100644 index 000000000..4e1acfbd2 --- /dev/null +++ b/src/pyodide-py/pyodide/_state.py @@ -0,0 +1,48 @@ +import __main__ +import sys +import gc + +from ._core import JsProxy +from ._importhooks import jsfinder + + +def save_state() -> dict: + """Record the current global state. + + This includes which Javascript packages are loaded and the global scope in + ``__main__.__dict__``. Many loaded modules might have global state, but + there is no general way to track it and we don't try to. + """ + loaded_js_modules = {} + for [key, value] in sys.modules.items(): + if isinstance(value, JsProxy): + loaded_js_modules[key] = value + + return dict( + globals=dict(__main__.__dict__), + js_modules=dict(jsfinder.jsproxies), + loaded_js_modules=loaded_js_modules, + ) + + +def restore_state(state: dict): + """Restore the global state to a snapshot. The argument ``state`` should + come from ``save_state``""" + __main__.__dict__.clear() + __main__.__dict__.update(state["globals"]) + + jsfinder.jsproxies = state["js_modules"] + loaded_js_modules = state["loaded_js_modules"] + for [key, value] in list(sys.modules.items()): + if isinstance(value, JsProxy) and key not in loaded_js_modules: + del sys.modules[key] + sys.modules.update(loaded_js_modules) + + sys.last_type = None + sys.last_value = None + sys.last_traceback = None + + return gc.collect(2) + + +__all__ = ["save_state", "restore_state"] diff --git a/src/pyodide.js b/src/pyodide.js index a4de0e73d..47b564a31 100644 --- a/src/pyodide.js +++ b/src/pyodide.js @@ -567,7 +567,6 @@ globalThis.languagePluginLoader = (async () => { * @param {string} name Name of js module to add * @param {object} module Javascript object backing the module */ - // clang-format off Module.registerJsModule = function(name, module) { Module.pyodide_py.register_js_module(name, module); }; @@ -709,6 +708,11 @@ def temp(Module): Module.builtins = builtins.__dict__ Module.pyodide_py = pyodide `); + + Module.saveState = () => Module.pyodide_py._state.save_state(); + Module.restoreState = (state) => + Module.pyodide_py._state.restore_state(state); + Module.init_dict.get("temp")(Module); // Wrap "globals" in a special Proxy that allows `pyodide.globals.x` access. diff --git a/src/tests/test_pyodide.py b/src/tests/test_pyodide.py index 1a49c7c44..8521fd1ec 100644 --- a/src/tests/test_pyodide.py +++ b/src/tests/test_pyodide.py @@ -339,3 +339,39 @@ def test_create_proxy(selenium): `); """ ) + + +def test_restore_state(selenium): + selenium.run_js( + """ + pyodide.registerJsModule("a", {somefield : 82}); + pyodide.registerJsModule("b", { otherfield : 3 }); + pyodide.runPython("x = 7; from a import somefield"); + let state = pyodide._module.saveState(); + + pyodide.registerJsModule("c", { thirdfield : 9 }); + pyodide.runPython("y = 77; from b import otherfield; import c;"); + pyodide._module.restoreState(state); + state.destroy(); + """ + ) + + selenium.run( + """ + from unittest import TestCase + raises = TestCase().assertRaises + import sys + + assert x == 7 + assert "a" in sys.modules + assert somefield == 82 + with raises(NameError): + y + with raises(NameError): + otherfield + assert "b" not in sys.modules + import b + with raises(ModuleNotFoundError): + import c + """ + )