# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import logging
import os
import pathlib
import re
import shutil
import tarfile
import tempfile
import urllib.request
from datetime import datetime
from importlib.util import module_from_spec, spec_from_file_location
from itertools import chain, groupby
from types import ModuleType
from typing import List
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
_PACKAGE_MAPPING = {"pytorch": "pytorch_lightning", "app": "lightning_app"}
# TODO: remove this once lightning-ui package is ready as a dependency
_LIGHTNING_FRONTEND_RELEASE_URL = "https://storage.googleapis.com/grid-packages/lightning-ui/v0.0.0/build.tar.gz"
def _load_py_module(name: str, location: str) -> ModuleType:
spec = spec_from_file_location(name, location)
assert spec, f"Failed to load module {name} from {location}"
py = module_from_spec(spec)
assert spec.loader, f"ModuleSpec.loader is None for {name} from {location}"
spec.loader.exec_module(py)
return py
def load_requirements(
    path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: bool = True
) -> List[str]:
    """Load requirement specifiers from a requirements file.

    Inline comments are removed. Empty lines and direct URL dependencies (``http...`` or ``...@http...``) are
    skipped. With ``unfreeze`` enabled, upper-bound pins (``<`` / ``<=``) are dropped unless the line's comment
    contains ``strict``.

    Args:
        path_dir: folder containing the requirements file
        file_name: requirements file name inside ``path_dir``
        comment_char: marker starting an inline comment
        unfreeze: whether to drop non-strict upper-bound version restrictions

    Returns:
        The requirement strings in file order.

    >>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
    >>> load_requirements(path_req)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    ['numpy...', 'torch...', ...]
    """
    with open(os.path.join(path_dir, file_name)) as file:
        stripped_lines = [raw.strip() for raw in file.readlines()]

    requirements = []
    for line in stripped_lines:
        # split off the inline comment; it is still needed to look for the "strict" marker
        cut = line.find(comment_char)
        if cut >= 0:
            comment, line = line[cut:], line[:cut]
        else:
            comment = ""
        req = line.strip()
        # skip empty lines and directly installed dependencies
        if not req or req.startswith("http") or "@http" in req:
            continue
        keep_upper_bound = "strict" in comment
        if unfreeze and "<" in req and not keep_upper_bound:
            req = re.sub(r",? *<=? *[\d\.\*]+", "", req).strip()
        requirements.append(req)
    return requirements
def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
    """Load the README as the package description, with links and badges pointed at a specific release.

    Args:
        path_dir: folder containing ``README.md``
        homepage: repository URL used as the base for absolute links to static assets
        version: release tag substituted into readthedocs, codecov, GitHub and Azure badge URLs

    Returns:
        The adjusted README text.

    >>> load_readme_description(_PROJECT_ROOT, "", "")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    '...'
    """
    path_readme = os.path.join(path_dir, "README.md")
    with open(path_readme, encoding="utf-8") as fp:
        text = fp.read()

    # drop images from readme
    text = text.replace("![PT to PL](docs/source/_static/images/general/pl_quick_start_full_compressed.gif)", "")

    # https://github.com/Lightning-AI/lightning/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png
    github_source_url = os.path.join(homepage, "raw", version)
    # replace relative repository path to absolute link to the release
    #  do not replace all "docs" as in the readme we refer to some other sources with particular path to docs
    text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}")

    # readthedocs badge
    text = text.replace("badge/?version=stable", f"badge/?version={version}")
    # keep the trailing "/" so the rest of the docs path is not glued onto the version
    text = text.replace("pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{version}/")
    # codecov badge
    text = text.replace("/branch/master/graph/badge.svg", f"/release/{version}/graph/badge.svg")
    # replace github badges for release ones
    text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={version}")
    # Azure...
    text = text.replace("?branchName=master", f"?branchName=refs%2Ftags%2F{version}")
    text = re.sub(r"\?definitionId=\d+&branchName=master", f"?definitionId=2&branchName=refs%2Ftags%2F{version}", text)

    skip_begin = r""
    skip_end = r""
    # todo: wrap content as commented description
    # with empty markers the pattern degenerates to `.+?` (DOTALL) and would erase the whole text,
    #  so only strip the marked section when both markers are set
    if skip_begin and skip_end:
        text = re.sub(rf"{skip_begin}.+?{skip_end}", "", text, flags=re.IGNORECASE + re.DOTALL)
    return text
def replace_block_with_imports(lines: List[str], import_path: str, kword: str = "class") -> List[str]:
    """Parse a file and replace implementation bodies of functions or classes with imports.

    Every line starting with ``kword`` (e.g. ``class``/``def``/``async def``) is replaced by
    ``from <import_path> import <name>`` at the same indentation. The lines of its body are dropped. Names starting
    with ``__`` are dropped without an import. A line that starts with the keyword but has neither ``:`` nor ``(``
    is not a definition header and is kept as-is.

    Args:
        lines: source lines of the module
        import_path: dotted module path the names are imported from
        kword: definition keyword to look for

    Returns:
        The rewritten lines.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "logger.py")
    >>> import_path = ".".join(["pytorch_lightning", "loggers", "logger"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = replace_block_with_imports(lines, import_path, "class")
    >>> lines = replace_block_with_imports(lines, import_path, "def")
    """
    body, tracking, skip_offset = [], False, 0
    for ln in lines:
        offset = len(ln) - len(ln.lstrip())
        # a non-empty line at or above the definition's indentation ends the skipped body,
        #  unless it only closes a multi-line signature
        if tracking and ln and offset <= skip_offset and not any(ln.lstrip().startswith(c) for c in ")]"):
            tracking = False
        if ln.lstrip().startswith(f"{kword} ") and not tracking:
            name = ln.replace(f"{kword} ", "").strip()
            idxs = [name.index(c) for c in ":(" if c in name]
            # without ":" or "(" this is not a definition header, so fall through and keep it as a plain line
            if idxs:
                name = name[: min(idxs)]
                # skip private, TODO: consider skip even protected
                if not name.startswith("__"):
                    body.append(f"{' ' * offset}from {import_path} import {name} # noqa: F401")
                tracking, skip_offset = True, offset
                continue
        if not tracking:
            body.append(ln)
    return body
def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
    """Parse a file and replace variable assignments with imports from ``import_path``.

    Each assignment line (``name = ...`` or ``name: <annotation> = ...``) becomes
    ``from <import_path> import <name>`` at the same indentation. Continuation lines of the assignment are
    dropped. A name is imported only once. Names starting with ``__`` are dropped without an import, except
    ``__all__``.

    Args:
        lines: source lines of the module
        import_path: dotted module path the names are imported from

    Returns:
        The rewritten lines.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "utilities", "imports.py")
    >>> import_path = ".".join(["pytorch_lightning", "utilities", "imports"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = replace_vars_with_imports(lines, import_path)
    """
    already_imported = set()
    out, skipping, anchor = [], False, 0
    for line in lines:
        stripped = line.lstrip()
        indent = len(line) - len(stripped)
        # a non-empty line at or above the assignment's indentation ends a multi-line value,
        #  unless it only closes a bracket
        if skipping and line and indent <= anchor and stripped[:1] not in (")", "]", "}"):
            skipping = False
        assignment = re.match(r"^([\w_\d]+)[: [\w\., \[\]]*]? = ", stripped)
        if assignment is None:
            if not skipping:
                out.append(line)
            continue
        var_name = assignment.group(1)
        exportable = var_name == "__all__" or not var_name.startswith("__")
        if exportable and var_name not in already_imported:
            out.append(f"{' ' * indent}from {import_path} import {var_name} # noqa: F401")
            already_imported.add(var_name)
        skipping, anchor = True, indent
    return out
def prune_imports_callables(lines: List[str]) -> List[str]:
    """Prune imports and calling functions from a file, even multi-line.

    ``import x`` lines are dropped. ``from x import y`` statements and statements starting with a call
    (``name(`` / ``a.b(``) are dropped together with their more-indented continuation lines and closing brackets.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "utilities", "cli.py")
    >>> import_path = ".".join(["pytorch_lightning", "utilities", "cli"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = prune_imports_callables(lines)
    """
    kept = []
    skipping, anchor_indent = False, 0
    for line in lines:
        stripped = line.lstrip()
        if stripped.startswith("import "):
            continue
        indent = len(line) - len(stripped)
        # a non-empty line at or above the anchor indentation ends the skipped statement,
        #  unless it only closes a bracket
        if skipping and line and indent <= anchor_indent and stripped[:1] not in (")", "]", "}"):
            skipping = False
        starts_with_call = re.match(r"^[\w_\d\.]+\(", stripped) is not None
        is_from_import = stripped.startswith("from ") and " import " in line
        if is_from_import or starts_with_call:
            skipping, anchor_indent = True, indent
            continue
        if not skipping:
            kept.append(line)
    return kept
def prune_func_calls(lines: List[str]) -> List[str]:
    """Prune calling functions from a file, even multi-line.

    A line starting with a (possibly decorated, ``@name(``) call that is not an import starts a skipped region.
    The region lasts until the count of ``(`` minus ``)`` seen since then returns to zero.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "__init__.py")
    >>> import_path = ".".join(["pytorch_lightning", "loggers"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = prune_func_calls(lines)
    """
    kept = []
    in_call, depth = False, 0
    for line in lines:
        if re.match(r"^@?[\w_\d\.]+ *\(", line.lstrip()) and " import " not in line:
            in_call, depth = True, 0
        if not in_call:
            kept.append(line)
            continue
        depth += line.count("(") - line.count(")")
        in_call = depth != 0
    return kept
def prune_empty_statements(lines: List[str]) -> List[str]:
    """Prune empty if/else and try/except.

    For each line opening a ``with``/``if``/``elif``/``else``/``try``/``except`` statement, the following lines are
    scanned and the non-empty lines of its body are counted. Some statements are dropped together with the lines
    indented below them:

    - a statement whose body is empty;
    - a statement whose first line does not end with ``:``;
    - an ``else``/``except`` whose last preceding ``if``/``elif``/``try`` body was empty.

    Decorator lines (including multi-line ones) are always dropped.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "utilities", "cli.py")
    >>> import_path = ".".join(["pytorch_lightning", "utilities", "cli"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = prune_imports_callables(lines)
    >>> lines = prune_empty_statements(lines)
    """
    # match whole keywords only, so identifiers like `without` or `trying` are not taken for statements
    kword_pattern = re.compile(r"(with|if|elif|else|try|except)\b")
    body, tracking, skip_offset, last_count = [], False, 0, 0
    # todo: consider some more complex logic as for example only some leaves of if/else tree are empty
    for i, ln in enumerate(lines):
        offset = len(ln) - len(ln.lstrip())
        # skip all decorators
        if ln.lstrip().startswith("@"):
            # consider also multi-line decorators
            if "(" in ln and ")" not in ln:
                tracking, skip_offset = True, offset
            continue
        # a non-empty line at or above the skipped statement's indentation ends it, unless it only closes brackets
        if tracking and ln and offset <= skip_offset and not any(ln.lstrip().startswith(c) for c in ")]}"):
            tracking = False
        found = kword_pattern.match(ln.lstrip())
        if found:
            start, count = found.group(1), -1
            # look forward if this statement has a body
            for ln_ in lines[i:]:
                offset_ = len(ln_) - len(ln_.lstrip())
                if count == -1 and ln_.rstrip().endswith(":"):
                    count = 0
                elif ln_ and offset_ <= offset:
                    break
                # skip all till end of statement
                elif ln_.lstrip():
                    # count non-zero body lines
                    count += 1
            # remember the body size of the last leading branch, as a following else/except cannot exist without it
            #  (keywords are matched without trailing spaces, so "if"/"elif" are recorded here too)
            if start in ("if", "elif", "try"):
                last_count = count
            if count <= 0 or (start in ("else", "except") and last_count <= 0):
                tracking, skip_offset = True, offset
        if not tracking:
            body.append(ln)
    return body
def prune_comments_docstrings(lines: List[str]) -> List[str]:
    """Prune all docstrings with triple " notation and strip trailing ``#`` comments.

    A docstring is a line starting with ``\"\"\"`` or ``r\"\"\"``. One-line docstrings are dropped directly.
    Multi-line ones are dropped up to and including the line that ends with the closing quotes. Docstrings are
    detected on the raw line, so a ``#`` inside docstring text (URL anchors, issue numbers) cannot hide the closing
    quotes. Outside docstrings, everything from the first ``#`` is removed unless the line has a ``noqa:`` marker.
    Other lines ending with ``\"\"\"`` (e.g. ``x = \"\"\"a\"\"\"``) are kept.

    Args:
        lines: source lines of the module

    Returns:
        The remaining lines, right-stripped.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "csv_logs.py")
    >>> import_path = ".".join(["pytorch_lightning", "loggers", "csv_logs"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = prune_comments_docstrings(lines)
    """

    def _is_oneliner(text: str) -> bool:
        # opening and closing quotes on the same line, e.g. `"""Summary."""`
        stripped = text.strip()
        return len(stripped) >= 6 and stripped.endswith('"""')

    body, tracking = [], False
    for ln in lines:
        # the same line with any trailing `#` comment removed; used to spot quotes followed by a comment
        code = ln[: ln.index("#")] if "#" in ln else ln
        if tracking:
            # inside a multi-line docstring: drop lines until the closing quotes
            if ln.rstrip().endswith('"""') or code.rstrip().endswith('"""'):
                tracking = False
            continue
        if ln.lstrip().startswith(('"""', 'r"""')):
            # oneliners skip directly, otherwise skip everything until the closing quotes
            if not (_is_oneliner(ln) or _is_oneliner(code)):
                tracking = True
            continue
        if "#" in ln and "noqa:" not in ln:
            ln = code
        body.append(ln.rstrip())
    return body
def wrap_try_except(body: List[str], pkg: str, ver: str) -> List[str]:
    """Wrap the file with try/except for better traceability of import misalignment.

    Args:
        body: generated source lines
        pkg: name of the real package the generated code imports from
        ver: version of ``pkg`` the code was generated against

    Returns:
        ``body`` unchanged when it contains no non-empty line. Otherwise, a new list with the body indented under
        ``try:``, followed by an ``except ImportError`` handler that re-raises the error with a note on the expected
        ``pkg`` version.
    """
    if not any(body):
        return body
    wrapped = ["try:"]
    wrapped.extend(f" {ln}" if ln else "" for ln in body)
    handler = [
        "",
        "except ImportError as err:",
        "",
        " from os import linesep",
        f" from {pkg} import __version__",
        f" msg = f'Your `lightning` package was built for `{pkg}=={ver}`," + " but you are running {__version__}'",
        " raise type(err)(str(err) + linesep + msg)",
    ]
    return wrapped + handler
def parse_version_from_file(pkg_root: str) -> str:
    """Load the package version from a file inside ``pkg_root``.

    ``__version__.py`` (attribute ``version``) takes precedence over ``__about__.py`` (attribute ``__version__``).
    An empty string is returned when neither file exists. This covers building only the meta-package, when no
    additional source files are present.
    """
    candidates = (
        ("version", os.path.join(pkg_root, "__version__.py"), "version"),
        ("about", os.path.join(pkg_root, "__about__.py"), "__version__"),
    )
    for module_name, module_file, attribute in candidates:
        if os.path.isfile(module_file):
            return getattr(_load_py_module(module_name, module_file), attribute)
    return ""
def prune_duplicate_lines(body):
    """Drop repeated lines, keeping the first occurrence; ``)`` and blank lines are always kept.

    NOTE(review): the left-stripped line is compared against the already-kept lines *as stored*, i.e. unstripped.
    So an indented line is dropped only when an identical unindented line was kept before. Repeated indented lines
    are all kept. Changing this would also drop repeated lines in sibling branches (e.g. try/except), so confirm
    before "fixing".
    """
    kept = []
    for line in body:
        key = line.lstrip()
        if key in (")", "") or key not in kept:
            kept.append(line)
    return kept
def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", lit_name: str = "pytorch"):
    """Parse the real python package and, for each module, create a mirror version where all function and class
    implementations are replaced by cross-imports to the true package.

    The mirrored modules are written to ``<src_folder>/lightning/<lit_name>/`` with the same relative paths.

    As validation run in terminal: `flake8 src/lightning/ --ignore E402,F401,E501`

    Args:
        src_folder: folder containing the real package ``pkg_name`` and the ``lightning`` meta-package
        pkg_name: name of the real package folder to mirror
        lit_name: sub-folder of ``lightning`` the mirrored modules are written to

    >>> create_meta_package(os.path.join(_PROJECT_ROOT, "src"))
    """
    package_dir = os.path.join(src_folder, pkg_name)
    pkg_ver = parse_version_from_file(package_dir)
    # shutil.rmtree(os.path.join(src_folder, "lightning", lit_name))
    py_files = glob.glob(os.path.join(src_folder, pkg_name, "**", "*.py"), recursive=True)
    for py_file in py_files:
        local_path = os.path.relpath(py_file, package_dir)
        fname = os.path.basename(py_file)
        # paths with "-" cannot form a valid dotted import path
        if "-" in local_path:
            continue
        with open(py_file, encoding="utf-8") as fp:
            lines = [ln.rstrip() for ln in fp.readlines()]
        import_path = pkg_name + "." + os.path.splitext(local_path)[0].replace(os.path.sep, ".")
        import_path = import_path.replace(".__init__", "")

        if fname in ("__about__.py", "__version__.py"):
            # version/about files are copied verbatim
            body = lines
        else:
            if fname.startswith("_") and fname not in ("__init__.py", "__main__.py"):
                logging.warning(f"unsupported file: {local_path}")
                continue
            # ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
            body = prune_comments_docstrings([ln.rstrip() for ln in lines])
            if fname not in ("__init__.py", "__main__.py"):
                body = prune_imports_callables(body)
            for key_word in ("class", "def", "async def"):
                body = replace_block_with_imports(body, import_path, key_word)
            # TODO: fix reimporting which is artefact after replacing var assignment with import;
            #  after fixing , update CI by remove F811 from CI/check pkg
            body = replace_vars_with_imports(body, import_path)
            if fname not in ("__main__.py",):
                body = prune_func_calls(body)
            body_len = -1
            # repeat until stable, in case of several in-depth statements
            while body_len != len(body):
                body_len = len(body)
                body = prune_duplicate_lines(body)
                body = prune_empty_statements(body)
            # add try/catch wrapper for whole body,
            #  so when import fails it tells you what is the package version this meta package was generated for...
            body = wrap_try_except(body, pkg_name, pkg_ver)

        # todo: apply pre-commit formatting
        # collapse consecutive identical lines (e.g. too many empty lines)
        body = [ln for ln, _group in groupby(body)]
        # drop duplicated lines
        body = prune_duplicate_lines(body)
        # compose the target file name
        new_file = os.path.join(src_folder, "lightning", lit_name, local_path)
        os.makedirs(os.path.dirname(new_file), exist_ok=True)
        with open(new_file, "w", encoding="utf-8") as fp:
            # plain "\n": text mode already translates it to the platform separator (os.linesep would double CR)
            fp.writelines([ln + "\n" for ln in body])
def set_version_today(fpath: str) -> None:
    """Replace the ``YYYY.-M.-D`` template date in a file with today's date (``year.month.day``, no padding).

    Args:
        fpath: path of the file, rewritten in place
    """
    # resolve the date once so every occurrence gets the same value, even when running across midnight
    today = datetime.now()
    stamp = f"{today.year}.{today.month}.{today.day}"
    with open(fpath) as fp:
        lines = fp.readlines()
    lines = [ln.replace("YYYY.-M.-D", stamp) for ln in lines]
    with open(fpath, "w") as fp:
        fp.writelines(lines)
def _download_frontend(root: str = _PROJECT_ROOT):
    """Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
    directory.

    The archive from ``_LIGHTNING_FRONTEND_RELEASE_URL`` is extracted to a temporary folder. Its ``build`` folder
    then replaces ``<root>/src/lightning_app/ui``. The existing UI folder is removed only after the new build has
    been extracted, so a failed download leaves it untouched. Any failure is reported with a message instead of
    raising.
    """

    def _members_inside(archive: tarfile.TarFile, dest: str):
        # yield archive members, refusing any whose path would resolve outside `dest` (e.g. "../x" or "/x")
        dest_real = os.path.realpath(dest)
        for member in archive:
            target = os.path.realpath(os.path.join(dest_real, member.name))
            if os.path.commonpath([dest_real, target]) != dest_real:
                raise ValueError(f"Refusing to extract {member.name!r} outside of {dest_real}")
            yield member

    try:
        frontend_dir = pathlib.Path(root, "src", "lightning_app", "ui")
        with tempfile.TemporaryDirectory() as download_dir:
            with urllib.request.urlopen(_LIGHTNING_FRONTEND_RELEASE_URL) as response, tarfile.open(
                fileobj=response, mode="r|gz"
            ) as archive:
                archive.extractall(path=download_dir, members=_members_inside(archive, download_dir))
            # replace the previous UI only once the new build is fully available locally
            shutil.rmtree(frontend_dir, ignore_errors=True)
            shutil.move(os.path.join(download_dir, "build"), frontend_dir)
        print("The Lightning UI has successfully been downloaded!")
    # If installing from source without internet connection, we don't want to break the installation
    except Exception:
        print("The Lightning UI downloading has failed!")
def _adjust_require_versions(source_dir: str = "src", req_dir: str = "requirements") -> None:
    """Parse the base requirements and append as version adjustments if needed `pkg>=X1.Y1.Z1,==X2.Y2.*`.

    A requirement gets a ``, ==X.Y.*`` suffix when ``<source_dir>/<pkg>`` has a version readable by
    ``parse_version_from_file``. ``<req_dir>/base.txt`` is rewritten with the result.

    Args:
        source_dir: folder containing the package sources
        req_dir: folder containing ``base.txt``
    """
    reqs = load_requirements(req_dir, file_name="base.txt")
    for i, req in enumerate(reqs):
        # the package name ends at the first version-specifier character; a bare name (e.g. "numpy") has none
        cut = min((req.index(c) for c in "<>=!~" if c in req), default=len(req))
        pkg_name = req[:cut].strip()
        ver_ = parse_version_from_file(os.path.join(source_dir, pkg_name))
        if not ver_:
            continue
        ver2 = ".".join(ver_.split(".")[:2] + ["*"])
        reqs[i] = f"{req}, =={ver2}"

    with open(os.path.join(req_dir, "base.txt"), "w") as fp:
        # plain "\n": text mode already translates it to the platform separator (os.linesep would double CR)
        fp.writelines([ln + "\n" for ln in reqs])
def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requirements: bool = False) -> None:
    """Load all base requirements from all particular packages and prune duplicates.

    Each sub-folder of ``req_dir`` is read via its ``base.txt`` in sorted folder order, for a deterministic result.
    Identical requirement lines are kept once, in first-seen order. The result is written to ``<req_dir>/base.txt``.
    Nothing is written when there is no sub-folder.

    Args:
        req_dir: folder whose sub-folders each hold a ``base.txt``
        freeze_requirements: keep non-strict upper-bound pins instead of dropping them
    """
    requires = [
        load_requirements(d, file_name="base.txt", unfreeze=not freeze_requirements)
        for d in sorted(glob.glob(os.path.join(req_dir, "*")))
        if os.path.isdir(d)
    ]
    if not requires:
        return None
    # TODO: add some smarter version aggregation per each package
    # drop exact duplicates while preserving the first occurrence order
    requires = list(dict.fromkeys(chain(*requires)))
    with open(os.path.join(req_dir, "base.txt"), "w") as fp:
        # plain "\n": text mode already translates it to the platform separator (os.linesep would double CR)
        fp.writelines([ln + "\n" for ln in requires])