From 1396c8c07a69f17f71fe8c3522b0b5a7ec18ad5d Mon Sep 17 00:00:00 2001 From: Hood Chatham Date: Wed, 24 Jan 2024 15:19:02 -0800 Subject: [PATCH] Make pyimport able to return module attributes (#4395) Before this PR, `pyimport` can be used like: `pyimport("package")` or `pyimport("package.module")` but `pyimport("package.attribute")` fails. This updates `pyimport` to also work to get package attributes. I also updated the docs for pyimport. --- conftest.py | 10 +++++----- docs/project/changelog.md | 4 ++++ src/js/api.ts | 40 ++++++++++++++++++--------------------- src/js/types.ts | 1 + src/py/_pyodide/_base.py | 8 ++++++++ src/tests/test_pyodide.py | 12 ++++++++++-- 6 files changed, 46 insertions(+), 29 deletions(-) diff --git a/conftest.py b/conftest.py index 136a0bbc8..384f68513 100644 --- a/conftest.py +++ b/conftest.py @@ -35,18 +35,18 @@ pytest_pyodide.runner.NODE_FLAGS.extend(["--experimental-wasm-stack-switching"]) # We need to go through and touch them all once to keep everything okay. pytest_pyodide.runner.INITIALIZE_SCRIPT = """ pyodide.globals.get; + pyodide.runPython("import pyodide_js._api; del pyodide_js"); + pyodide._api.importlib.invalidate_caches; + pyodide._api.package_loader.unpack_buffer; + pyodide._api.package_loader.get_dynlibs; pyodide._api.pyodide_code.eval_code; pyodide._api.pyodide_code.eval_code_async; pyodide._api.pyodide_code.find_imports; pyodide._api.pyodide_ffi.register_js_module; pyodide._api.pyodide_ffi.unregister_js_module; - pyodide._api.importlib.invalidate_caches; - pyodide._api.package_loader.unpack_buffer; - pyodide._api.package_loader.get_dynlibs; - pyodide.runPython(""); pyodide.pyimport("pyodide.ffi.wrappers").destroy(); pyodide.pyimport("pyodide.http").destroy(); - pyodide.pyimport("pyodide_js._api") + pyodide.pyimport("pyodide_js._api"); """ only_node = pytest.mark.xfail_browsers( diff --git a/docs/project/changelog.md b/docs/project/changelog.md index c4fcee287..0c6b288ac 100644 --- a/docs/project/changelog.md +++ b/docs/project/changelog.md @@ -21,11 +21,15 @@ myst: - {{ Breaking }} `pyodide-build` entrypoint is removed in favor of `pyodide`. This entrypoint was deprecated since 0.22.0. + {pr}`4368` - {{ Enhancement }} Added apis to discard extra arguments when calling Python functions. {pr}`4392` +- {{ Enhancement }} Updated `pyimport` to support `pyimport("module.attribute")`. + {pr}`4395` + ### Packages - Upgraded scikit-learn to 1.4.0 {pr}`4409` diff --git a/src/js/api.ts b/src/js/api.ts index 0a8cc6cef..094305dad 100644 --- a/src/js/api.ts +++ b/src/js/api.ts @@ -451,31 +451,26 @@ export class PyodideAPI { /** * Imports a module and returns it. * - * .. admonition:: Warning - * :class: warning - * - * This function has a completely different behavior than the old removed pyimport function! - * - * ``pyimport`` is roughly equivalent to: - * - * .. code-block:: js - * - * pyodide.runPython(`import ${pkgname}; ${pkgname}`); - * - * except that the global namespace will not change. - * - * Example: - * - * .. code-block:: js - * - * let sysmodule = pyodide.pyimport("sys"); - * let recursionLimit = sysmodule.getrecursionlimit(); + * If `name` has no dot in it, then `pyimport(name)` is approximately + * equivalent to: + * ```js + * pyodide.runPython(`import ${name}; ${name}`) + * ``` + * except that `name` is not introduced into the Python global namespace. If + * the name has one or more dots in it, say it is of the form `path.name` + * where `name` has no dots but path may have zero or more dots. Then it is + * approximately the same as: + * ```js + * pyodide.runPython(`from ${path} import ${name}; ${name}`); + * ``` * * @param mod_name The name of the module to import - * @returns A PyProxy for the imported module + * + * @example + * pyodide.pyimport("math.comb")(4, 2) // returns 4 choose 2 = 6 */ - static pyimport(mod_name: string): PyProxy { - return API.importlib.import_module(mod_name); + static pyimport(mod_name: string): any { + return API.pyodide_base.pyimport_impl(mod_name); } /** @@ -749,6 +744,7 @@ API.finalizeBootstrap = function (): PyodideInterface { API.pyodide_code = import_module("pyodide.code"); API.pyodide_ffi = import_module("pyodide.ffi"); API.package_loader = import_module("pyodide._package_loader"); + API.pyodide_base = import_module("_pyodide._base"); API.sitepackages = API.package_loader.SITE_PACKAGES.__str__(); API.dsodir = API.package_loader.DSO_DIR.__str__(); diff --git a/src/js/types.ts b/src/js/types.ts index d82cc2c77..c812d28bf 100644 --- a/src/js/types.ts +++ b/src/js/types.ts @@ -309,6 +309,7 @@ export interface API { pyodide_py: any; pyodide_code: any; pyodide_ffi: any; + pyodide_base: any; globals: PyProxy; rawRun: (code: string) => [number, string]; runPythonInternal: (code: string) => any; diff --git a/src/py/_pyodide/_base.py b/src/py/_pyodide/_base.py index 7acf43366..c002f7fe9 100644 --- a/src/py/_pyodide/_base.py +++ b/src/py/_pyodide/_base.py @@ -621,3 +621,11 @@ def find_imports(source: str) -> list[str]: continue imports.add(module_name.split(".")[0]) return list(sorted(imports)) + + +def pyimport_impl(path: str) -> Any: + [stem, *fromlist] = path.rsplit(".", 1) + res = __import__(stem, fromlist=fromlist) + if fromlist: + res = getattr(res, fromlist[0]) + return res diff --git a/src/tests/test_pyodide.py b/src/tests/test_pyodide.py index c55830675..e34597bf7 100644 --- a/src/tests/test_pyodide.py +++ b/src/tests/test_pyodide.py @@ -41,7 +41,15 @@ def test_ffi_import_star(): exec("from pyodide.ffi import *", {}) -def test_pyimport(selenium): +def test_pyimport1(): + from _pyodide._base import pyimport_impl + + assert pyimport_impl("pyodide").__name__ == "pyodide" + assert pyimport_impl("pyodide.console").__name__ == "pyodide.console" + assert pyimport_impl("pyodide.console.BANNER").startswith("Python ") + + +def test_pyimport2(selenium): selenium.run_js( """ let platform = pyodide.pyimport("platform"); @@ -1155,7 +1163,7 @@ def test_js_stackframes(selenium): ["", "c1"], ["test.html", "b"], ["pyodide.asm.js", "pyimport"], - ["importlib/__init__.py", "import_module"], + ["_pyodide/_base.py", "pyimport_impl"], ] assert normalize_tb(res[: len(frames)]) == frames