From c019fc633d0b4e1c5b7e384216aab13149c90c9e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 29 Jul 2022 12:50:26 +0200 Subject: [PATCH] meta pkg: wrap imports for traceability (#13924) --- .actions/setup_tools.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/.actions/setup_tools.py b/.actions/setup_tools.py index 3a105f508f..2aff3bdf9a 100644 --- a/.actions/setup_tools.py +++ b/.actions/setup_tools.py @@ -270,6 +270,37 @@ def prune_comments_docstrings(lines: List[str]) -> List[str]: return body +def wrap_try_except(body: List[str], pkg: str, ver: str) -> List[str]: + """Wrap the file with try/except for better traceability of import misalignment.""" + not_empty = sum(1 for ln in body if ln) + if not_empty == 0: + return body + body = ["try:"] + [f" {ln}" if ln else "" for ln in body] + body += [ + "", + "except ImportError as err:", + "", + " from os import linesep", + f" from {pkg} import __version__", + f" msg = f'Your `lightning` package was built for `{pkg}=={ver}`," + " but you are running {__version__}'", + " raise type(err)(str(err) + linesep + msg)", + ] + return body + + +def parse_version_from_file(pkg_root: str) -> str: + """Loading the package version from file.""" + file_ver = os.path.join(pkg_root, "__version__.py") + file_about = os.path.join(pkg_root, "__about__.py") + if os.path.isfile(file_ver): + ver = _load_py_module("version", file_ver).version + elif os.path.isfile(file_about): + ver = _load_py_module("about", file_about).__version__ + else: # this covers case you have build only meta-package so not additional source files are present + ver = "" + return ver + + def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", lit_name: str = "pytorch"): """Parse the real python package and for each module create a mirroe version with repalcing all function and class implementations by cross-imports to the true package. @@ -279,6 +310,7 @@ def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", li >>> create_meta_package(os.path.join(_PROJECT_ROOT, "src")) """ package_dir = os.path.join(src_folder, pkg_name) + pkg_ver = parse_version_from_file(package_dir) # shutil.rmtree(os.path.join(src_folder, "lightning", lit_name)) py_files = glob.glob(os.path.join(src_folder, pkg_name, "**", "*.py"), recursive=True) for py_file in py_files: @@ -310,15 +342,16 @@ def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", li while body_len != len(body): body_len = len(body) body = prune_empty_statements(body) - # TODO: add try/catch wrapper for whole body, + # add try/catch wrapper for whole body, # so when import fails it tells you what is the package version this meta package was generated for... + body = wrap_try_except(body, pkg_name, pkg_ver) # todo: apply pre-commit formatting body = [ln for ln, _group in groupby(body)] lines = [] # drop duplicated lines for ln in body: - if ln + os.linesep not in lines or ln in (")", ""): + if ln + os.linesep not in lines or ln.lstrip() in (")", ""): lines.append(ln + os.linesep) # compose the target file name new_file = os.path.join(src_folder, "lightning", lit_name, local_path)