meta pkg: wrap imports for traceability (#13924)
This commit is contained in:
parent
caaf35689c
commit
c019fc633d
|
@ -270,6 +270,37 @@ def prune_comments_docstrings(lines: List[str]) -> List[str]:
|
|||
return body
|
||||
|
||||
|
||||
def wrap_try_except(body: List[str], pkg: str, ver: str) -> List[str]:
|
||||
"""Wrap the file with try/except for better traceability of import misalignment."""
|
||||
not_empty = sum(1 for ln in body if ln)
|
||||
if not_empty == 0:
|
||||
return body
|
||||
body = ["try:"] + [f" {ln}" if ln else "" for ln in body]
|
||||
body += [
|
||||
"",
|
||||
"except ImportError as err:",
|
||||
"",
|
||||
" from os import linesep",
|
||||
f" from {pkg} import __version__",
|
||||
f" msg = f'Your `lightning` package was built for `{pkg}=={ver}`," + " but you are running {__version__}'",
|
||||
" raise type(err)(str(err) + linesep + msg)",
|
||||
]
|
||||
return body
|
||||
|
||||
|
||||
def parse_version_from_file(pkg_root: str) -> str:
|
||||
"""Loading the package version from file."""
|
||||
file_ver = os.path.join(pkg_root, "__version__.py")
|
||||
file_about = os.path.join(pkg_root, "__about__.py")
|
||||
if os.path.isfile(file_ver):
|
||||
ver = _load_py_module("version", file_ver).version
|
||||
elif os.path.isfile(file_about):
|
||||
ver = _load_py_module("about", file_about).__version__
|
||||
else: # this covers case you have build only meta-package so not additional source files are present
|
||||
ver = ""
|
||||
return ver
|
||||
|
||||
|
||||
def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", lit_name: str = "pytorch"):
|
||||
"""Parse the real python package and for each module create a mirroe version with repalcing all function and
|
||||
class implementations by cross-imports to the true package.
|
||||
|
@ -279,6 +310,7 @@ def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", li
|
|||
>>> create_meta_package(os.path.join(_PROJECT_ROOT, "src"))
|
||||
"""
|
||||
package_dir = os.path.join(src_folder, pkg_name)
|
||||
pkg_ver = parse_version_from_file(package_dir)
|
||||
# shutil.rmtree(os.path.join(src_folder, "lightning", lit_name))
|
||||
py_files = glob.glob(os.path.join(src_folder, pkg_name, "**", "*.py"), recursive=True)
|
||||
for py_file in py_files:
|
||||
|
@ -310,15 +342,16 @@ def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", li
|
|||
while body_len != len(body):
|
||||
body_len = len(body)
|
||||
body = prune_empty_statements(body)
|
||||
# TODO: add try/catch wrapper for whole body,
|
||||
# add try/catch wrapper for whole body,
|
||||
# so when import fails it tells you what is the package version this meta package was generated for...
|
||||
body = wrap_try_except(body, pkg_name, pkg_ver)
|
||||
|
||||
# todo: apply pre-commit formatting
|
||||
body = [ln for ln, _group in groupby(body)]
|
||||
lines = []
|
||||
# drop duplicated lines
|
||||
for ln in body:
|
||||
if ln + os.linesep not in lines or ln in (")", ""):
|
||||
if ln + os.linesep not in lines or ln.lstrip() in (")", ""):
|
||||
lines.append(ln + os.linesep)
|
||||
# compose the target file name
|
||||
new_file = os.path.join(src_folder, "lightning", lit_name, local_path)
|
||||
|
|
Loading…
Reference in New Issue