meta pkg: wrap imports for traceability (#13924)

This commit is contained in:
Jirka Borovec 2022-07-29 12:50:26 +02:00 committed by GitHub
parent caaf35689c
commit c019fc633d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 35 additions and 2 deletions

View File

@ -270,6 +270,37 @@ def prune_comments_docstrings(lines: List[str]) -> List[str]:
return body
def wrap_try_except(body: List[str], pkg: str, ver: str) -> List[str]:
"""Wrap the file with try/except for better traceability of import misalignment."""
not_empty = sum(1 for ln in body if ln)
if not_empty == 0:
return body
body = ["try:"] + [f" {ln}" if ln else "" for ln in body]
body += [
"",
"except ImportError as err:",
"",
" from os import linesep",
f" from {pkg} import __version__",
f" msg = f'Your `lightning` package was built for `{pkg}=={ver}`," + " but you are running {__version__}'",
" raise type(err)(str(err) + linesep + msg)",
]
return body
def parse_version_from_file(pkg_root: str) -> str:
"""Loading the package version from file."""
file_ver = os.path.join(pkg_root, "__version__.py")
file_about = os.path.join(pkg_root, "__about__.py")
if os.path.isfile(file_ver):
ver = _load_py_module("version", file_ver).version
elif os.path.isfile(file_about):
ver = _load_py_module("about", file_about).__version__
else: # this covers case you have build only meta-package so not additional source files are present
ver = ""
return ver
def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", lit_name: str = "pytorch"):
"""Parse the real python package and for each module create a mirroe version with repalcing all function and
class implementations by cross-imports to the true package.
@ -279,6 +310,7 @@ def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", li
>>> create_meta_package(os.path.join(_PROJECT_ROOT, "src"))
"""
package_dir = os.path.join(src_folder, pkg_name)
pkg_ver = parse_version_from_file(package_dir)
# shutil.rmtree(os.path.join(src_folder, "lightning", lit_name))
py_files = glob.glob(os.path.join(src_folder, pkg_name, "**", "*.py"), recursive=True)
for py_file in py_files:
@ -310,15 +342,16 @@ def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", li
while body_len != len(body):
body_len = len(body)
body = prune_empty_statements(body)
# TODO: add try/catch wrapper for whole body,
# add try/catch wrapper for whole body,
# so when import fails it tells you what is the package version this meta package was generated for...
body = wrap_try_except(body, pkg_name, pkg_ver)
# todo: apply pre-commit formatting
body = [ln for ln, _group in groupby(body)]
lines = []
# drop duplicated lines
for ln in body:
if ln + os.linesep not in lines or ln in (")", ""):
if ln + os.linesep not in lines or ln.lstrip() in (")", ""):
lines.append(ln + os.linesep)
# compose the target file name
new_file = os.path.join(src_folder, "lightning", lit_name, local_path)