# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import logging
import os
import pathlib
import re
import shutil
import tarfile
import tempfile
import urllib.request
from datetime import datetime
from distutils.version import LooseVersion
from importlib.util import module_from_spec, spec_from_file_location
from itertools import chain, groupby
from types import ModuleType
from typing import List

from pkg_resources import parse_requirements

_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
_PACKAGE_MAPPING = {"pytorch": "pytorch_lightning", "app": "lightning_app"}

# TODO: remove this once lightning-ui package is ready as a dependency
_LIGHTNING_FRONTEND_RELEASE_URL = "https://storage.googleapis.com/grid-packages/lightning-ui/v0.0.0/build.tar.gz"


def _load_py_module(name: str, location: str) -> ModuleType:
    spec = spec_from_file_location(name, location)
    assert spec, f"Failed to load module {name} from {location}"
    py = module_from_spec(spec)
    assert spec.loader, f"ModuleSpec.loader is None for {name} from {location}"
    spec.loader.exec_module(py)
    return py
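

# A minimal usage sketch (hypothetical path, not part of the original file):
#   about = _load_py_module("about", os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "__about__.py"))
#   print(about.__version__)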


def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: str = "all") -> str:
    """Adjust the upper version constraints.

    Args:
        ln: raw line from requirement
        comment_char: character marking a comment
        unfreeze: Enum or "all"|"major"|""

    Returns:
        adjusted requirement

    >>> _augment_requirement("arrow>=1.2.0, <=1.2.2  # anything", unfreeze="")
    'arrow>=1.2.0, <=1.2.2'
    >>> _augment_requirement("arrow>=1.2.0, <=1.2.2  # strict", unfreeze="")
    'arrow>=1.2.0, <=1.2.2  # strict'
    >>> _augment_requirement("arrow>=1.2.0, <=1.2.2  # my name", unfreeze="all")
    'arrow>=1.2.0'
    >>> _augment_requirement("arrow>=1.2.0, <=1.2.2  # strict", unfreeze="all")
    'arrow>=1.2.0, <=1.2.2  # strict'
    >>> _augment_requirement("arrow", unfreeze="all")
    'arrow'
    >>> _augment_requirement("arrow>=1.2.0, <=1.2.2  # cool", unfreeze="major")
    'arrow>=1.2.0, <2.0  # strict'
    >>> _augment_requirement("arrow>=1.2.0, <=1.2.2  # strict", unfreeze="major")
    'arrow>=1.2.0, <=1.2.2  # strict'
    >>> _augment_requirement("arrow>=1.2.0", unfreeze="major")
    'arrow>=1.2.0, <2.0  # strict'
    >>> _augment_requirement("arrow", unfreeze="major")
    'arrow'
    """
    # filter all comments
    if comment_char in ln:
        comment = ln[ln.index(comment_char) :]
        ln = ln[: ln.index(comment_char)]
        is_strict = "strict" in comment
    else:
        is_strict = False
    req = ln.strip()
    # skip directly installed dependencies
    if not req or req.startswith("http") or "@http" in req:
        return ""
    # extract the major version from all listed versions
    if unfreeze == "major":
        req_ = list(parse_requirements([req]))[0]
        vers = [LooseVersion(v) for s, v in req_.specs if s not in ("==", "~=")]
        ver_major = sorted(vers)[-1].version[0] if vers else None
    else:
        ver_major = None
    # remove version restrictions unless they are strict
    if unfreeze and "<" in req and not is_strict:
        req = re.sub(r",? *<=? *[\d\.\*]+", "", req).strip()
    if ver_major is not None and not is_strict:
        # add a comma only if there are already some versions
        req += f"{',' if any(c in req for c in '<=>') else ''} <{int(ver_major) + 1}.0"
    # add the strict flag back to the comment
    if is_strict or ver_major is not None:
        req += "  # strict"
    return req


def load_requirements(
    path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: str = "all"
) -> List[str]:
    """Loading requirements from a file.

    >>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
    >>> load_requirements(path_req, unfreeze="major")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    ['pytorch_lightning...', 'lightning_app...']
    """
    with open(os.path.join(path_dir, file_name)) as file:
        lines = [ln.strip() for ln in file.readlines()]
    reqs = []
    for ln in lines:
        reqs.append(_augment_requirement(ln, comment_char=comment_char, unfreeze=unfreeze))
    # filter empty lines
    return [str(req) for req in reqs if req]
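

# A hand-traced sketch (hypothetical requirements file, not one of the real ones):
# given a base.txt containing
#   arrow>=1.2.0, <=1.2.2
#   fsspec>=2021.05.0, <=2022.5.0  # strict
# loading it with unfreeze="all" yields
#   ['arrow>=1.2.0', 'fsspec>=2021.05.0, <=2022.5.0  # strict']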


def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
    """Load readme as the package description.

    >>> load_readme_description(_PROJECT_ROOT, "", "")  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    '...'
    """
    path_readme = os.path.join(path_dir, "README.md")
    text = open(path_readme, encoding="utf-8").read()

    # drop images from readme
    text = text.replace("![PT to PL](docs/source/_static/images/general/pl_quick_start_full_compressed.gif)", "")

    # https://github.com/Lightning-AI/lightning/raw/master/docs/source/_static/images/lightning_module/pt_to_pl.png
    github_source_url = os.path.join(homepage, "raw", version)
    # replace relative repository path with absolute link to the release
    #  do not replace all "docs" since in the readme we refer to other sources with a particular path to docs
    text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}")

    # readthedocs badge
    text = text.replace("badge/?version=stable", f"badge/?version={version}")
    text = text.replace("pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{version}")
    # codecov badge
    text = text.replace("/branch/master/graph/badge.svg", f"/release/{version}/graph/badge.svg")
    # github actions badge
    text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={version}")
    # azure pipelines badge
    text = text.replace("?branchName=master", f"?branchName=refs%2Ftags%2F{version}")

    skip_begin = r"<!-- following section will be skipped from PyPI description -->"
    skip_end = r"<!-- end skipping PyPI description -->"
    # todo: wrap content as commented description
    text = re.sub(rf"{skip_begin}.+?{skip_end}", "<!-- -->", text, flags=re.IGNORECASE + re.DOTALL)

    # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png
    # github_release_url = os.path.join(homepage, "releases", "download", version)
    # # download badge and replace url with local file
    # text = _parse_for_badge(text, github_release_url)
    return text


def replace_block_with_imports(lines: List[str], import_path: str, kword: str = "class") -> List[str]:
    """Parse a file and replace implementation bodies of functions or classes.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "logger.py")
    >>> import_path = ".".join(["pytorch_lightning", "loggers", "logger"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = replace_block_with_imports(lines, import_path, "class")
    >>> lines = replace_block_with_imports(lines, import_path, "def")
    """
    body, tracking, skip_offset = [], False, 0
    for i, ln in enumerate(lines):
        # support for defining a class under this condition
        conditional_class_definitions = ("if TYPE_CHECKING", "if typing.TYPE_CHECKING", "if torch.", "if _TORCH_")
        if (
            any(ln.startswith(pattern) for pattern in conditional_class_definitions)
            # avoid bug in CI for the <1.7 meta code
            and "pytorch_lightning.utilities.meta" not in import_path
        ):
            # dedent the next line
            lines[i + 1] = lines[i + 1].lstrip()
            continue
        offset = len(ln) - len(ln.lstrip())
        # in case of matching, the class args may span multiple lines
        if tracking and ln and offset <= skip_offset and not any(ln.lstrip().startswith(c) for c in ")]"):
            tracking = False
        if ln.lstrip().startswith(f"{kword} ") and not tracking:
            name = ln.replace(f"{kword} ", "").strip()
            idxs = [name.index(c) for c in ":(" if c in name]
            name = name[: min(idxs)]
            # skip private, TODO: consider skipping even protected
            if not name.startswith("__"):
                body.append(f"{' ' * offset}from {import_path} import {name}  # noqa: F401")
            tracking, skip_offset = True, offset
            continue
        if not tracking:
            body.append(ln)
    return body
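

# A hand-traced sketch (hypothetical module content, not from the repo):
#   >>> replace_block_with_imports(["class Foo(Base):", "    def run(self):", "        pass"], "pkg.mod", "class")
#   ['from pkg.mod import Foo  # noqa: F401']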


def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
    """Parse a file and replace variable assignments with cross-imports.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "utilities", "imports.py")
    >>> import_path = ".".join(["pytorch_lightning", "utilities", "imports"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = replace_vars_with_imports(lines, import_path)
    """
    copied = []
    body, tracking, skip_offset = [], False, 0
    for ln in lines:
        offset = len(ln) - len(ln.lstrip())
        # in case of matching, the assigned value may span multiple lines
        if tracking and ln and offset <= skip_offset and not any(ln.lstrip().startswith(c) for c in ")]}"):
            tracking = False
        var = re.match(r"^([\w_\d]+)[: [\w\., \[\]]*]? = ", ln.lstrip())
        if var:
            name = var.groups()[0]
            # skip private vars or apply white-list for allowed vars
            if name not in copied and (not name.startswith("__") or name in ("__all__",)):
                body.append(f"{' ' * offset}from {import_path} import {name}  # noqa: F401")
                copied.append(name)
            tracking, skip_offset = True, offset
            continue
        if not tracking:
            body.append(ln)
    return body
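

# A hand-traced sketch (hypothetical module content):
#   >>> replace_vars_with_imports(["_TORCH_AVAILABLE = True"], "pkg.mod")
#   ['from pkg.mod import _TORCH_AVAILABLE  # noqa: F401']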


def prune_imports_callables(lines: List[str]) -> List[str]:
    """Prune imports and function calls from a file, even multi-line ones.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "utilities", "cli.py")
    >>> import_path = ".".join(["pytorch_lightning", "utilities", "cli"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = prune_imports_callables(lines)
    """
    body, tracking, skip_offset = [], False, 0
    for ln in lines:
        if ln.lstrip().startswith("import "):
            continue
        offset = len(ln) - len(ln.lstrip())
        # in case of matching, the call args may span multiple lines
        if tracking and ln and offset <= skip_offset and not any(ln.lstrip().startswith(c) for c in ")]}"):
            tracking = False
        # catching a callable
        call = re.match(r"^[\w_\d\.]+\(", ln.lstrip())
        if (ln.lstrip().startswith("from ") and " import " in ln) or call:
            tracking, skip_offset = True, offset
            continue
        if not tracking:
            body.append(ln)
    return body
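

# A hand-traced sketch (hypothetical module content):
#   >>> prune_imports_callables(["import sys", "from os import path", "print('hello')", "CONST = 1"])
#   ['CONST = 1']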


def prune_func_calls(lines: List[str]) -> List[str]:
    """Prune function calls from a file, even multi-line ones.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "__init__.py")
    >>> import_path = ".".join(["pytorch_lightning", "loggers"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = prune_func_calls(lines)
    """
    body, tracking, score = [], False, 0
    for ln in lines:
        # catching a callable
        calling = re.match(r"^@?[\w_\d\.]+ *\(", ln.lstrip())
        if calling and " import " not in ln:
            tracking = True
            score = 0
        if tracking:
            # balance parentheses to find the end of the (possibly multi-line) call
            score += ln.count("(") - ln.count(")")
            if score == 0:
                tracking = False
        else:
            body.append(ln)
    return body
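

# A hand-traced sketch (hypothetical module content):
#   >>> prune_func_calls(["seed_everything(42)", "CONST = 1"])
#   ['CONST = 1']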


def prune_empty_statements(lines: List[str]) -> List[str]:
    """Prune empty if/else and try/except blocks.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "utilities", "cli.py")
    >>> import_path = ".".join(["pytorch_lightning", "utilities", "cli"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = prune_imports_callables(lines)
    >>> lines = prune_empty_statements(lines)
    """
    kwords_pairs = ("with", "if ", "elif ", "else", "try", "except")
    body, tracking, skip_offset, last_count = [], False, 0, 0
    # todo: consider some more complex logic, as for example only some leaves of the if/else tree may be empty
    for i, ln in enumerate(lines):
        offset = len(ln) - len(ln.lstrip())
        # skip all decorators
        if ln.lstrip().startswith("@"):
            # consider also multi-line decorators
            if "(" in ln and ")" not in ln:
                tracking, skip_offset = True, offset
            continue
        # in case of matching, the statement head may span multiple lines
        if tracking and ln and offset <= skip_offset and not any(ln.lstrip().startswith(c) for c in ")]}"):
            tracking = False
        starts = [k for k in kwords_pairs if ln.lstrip().startswith(k)]
        if starts:
            start, count = starts[0], -1
            # look forward whether this statement has a body
            for ln_ in lines[i:]:
                offset_ = len(ln_) - len(ln_.lstrip())
                if count == -1 and ln_.rstrip().endswith(":"):
                    count = 0
                elif ln_ and offset_ <= offset:
                    break
                # skip all until the end of the statement
                elif ln_.lstrip():
                    # count non-empty body lines
                    count += 1
            # cache the last keyword's body count, as the dependent branch cannot exist without it
            if start.strip() in ("if", "elif", "try"):
                last_count = count
            if count <= 0 or (start.strip() in ("else", "except") and last_count <= 0):
                tracking, skip_offset = True, offset
        if not tracking:
            body.append(ln)
    return body
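

# A hand-traced sketch (hypothetical module content): a try/except whose try body was already
# pruned away is dropped entirely:
#   >>> prune_empty_statements(["try:", "except ImportError:", "CONST = 1"])
#   ['CONST = 1']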


def prune_comments_docstrings(lines: List[str]) -> List[str]:
    """Prune all comments and docstrings with triple " notation.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "csv_logs.py")
    >>> import_path = ".".join(["pytorch_lightning", "loggers", "csv_logs"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = prune_comments_docstrings(lines)
    """
    body, tracking = [], False
    for ln in lines:
        if "#" in ln and "noqa:" not in ln:
            ln = ln[: ln.index("#")]
        if not tracking and any(ln.lstrip().startswith(s) for s in ['"""', 'r"""']):
            # skip one-liners directly
            if len(ln.strip()) >= 6 and ln.rstrip().endswith('"""'):
                continue
            tracking = True
        elif ln.rstrip().endswith('"""'):
            tracking = False
            continue
        if not tracking:
            body.append(ln.rstrip())
    return body
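

# A hand-traced sketch (hypothetical module content):
#   >>> prune_comments_docstrings(['"""Module docstring."""', "CONST = 1  # some comment"])
#   ['CONST = 1']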


def wrap_try_except(body: List[str], pkg: str, ver: str) -> List[str]:
    """Wrap the file with try/except for better traceability of import misalignment."""
    not_empty = sum(1 for ln in body if ln)
    if not_empty == 0:
        return body
    body = ["try:"] + [f"    {ln}" if ln else "" for ln in body]
    body += [
        "",
        "except ImportError as err:",
        "",
        "    from os import linesep",
        f"    from {pkg} import __version__",
        f"    msg = f'Your `lightning` package was built for `{pkg}=={ver}`," + " but you are running {__version__}'",
        "    raise type(err)(str(err) + linesep + msg)",
    ]
    return body
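

# A hand-traced sketch (hypothetical package and version): wrap_try_except(["import os"], "pkg", "1.7.0") returns
#   ['try:',
#    '    import os',
#    '',
#    'except ImportError as err:',
#    '',
#    '    from os import linesep',
#    '    from pkg import __version__',
#    "    msg = f'Your `lightning` package was built for `pkg==1.7.0`, but you are running {__version__}'",
#    '    raise type(err)(str(err) + linesep + msg)']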


def parse_version_from_file(pkg_root: str) -> str:
    """Load the package version from file."""
    file_ver = os.path.join(pkg_root, "__version__.py")
    file_about = os.path.join(pkg_root, "__about__.py")
    if os.path.isfile(file_ver):
        ver = _load_py_module("version", file_ver).version
    elif os.path.isfile(file_about):
        ver = _load_py_module("about", file_about).__version__
    else:  # this covers the case where only the meta-package was built, so no additional source files are present
        ver = ""
    return ver


def prune_duplicate_lines(body):
    body_ = []
    # drop duplicated lines
    for ln in body:
        if ln.lstrip() not in body_ or ln.lstrip() in (")", ""):
            body_.append(ln)
    return body_
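

# A hand-traced sketch: duplicates are dropped, except closing brackets and empty lines:
#   >>> prune_duplicate_lines(["import os", "import os", ")", ")"])
#   ['import os', ')', ')']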


def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", lit_name: str = "pytorch"):
    """Parse the real python package and for each module create a mirror version, replacing all function and class
    implementations with cross-imports to the true package.

    As validation, run in a terminal: `flake8 src/lightning/ --ignore E402,F401,E501`

    >>> create_meta_package(os.path.join(_PROJECT_ROOT, "src"))
    """
    package_dir = os.path.join(src_folder, pkg_name)
    pkg_ver = parse_version_from_file(package_dir)
    # shutil.rmtree(os.path.join(src_folder, "lightning", lit_name))
    py_files = glob.glob(os.path.join(src_folder, pkg_name, "**", "*.py"), recursive=True)
    for py_file in py_files:
        local_path = py_file.replace(package_dir + os.path.sep, "")
        fname = os.path.basename(py_file)
        if "-" in local_path:
            continue
        with open(py_file, encoding="utf-8") as fp:
            lines = [ln.rstrip() for ln in fp.readlines()]
        import_path = pkg_name + "." + local_path.replace(".py", "").replace(os.path.sep, ".")
        import_path = import_path.replace(".__init__", "")

        if fname in ("__about__.py", "__version__.py"):
            body = lines
        else:
            if fname.startswith("_") and fname not in ("__init__.py", "__main__.py"):
                logging.warning(f"unsupported file: {local_path}")
                continue
            # TODO: perform some smarter parsing - preserve Constants, lambdas, etc
            body = prune_comments_docstrings([ln.rstrip() for ln in lines])
            if fname not in ("__init__.py", "__main__.py"):
                body = prune_imports_callables(body)
            for key_word in ("class", "def", "async def"):
                body = replace_block_with_imports(body, import_path, key_word)
            # TODO: fix re-importing, which is an artefact of replacing var assignments with imports;
            #  after fixing, update CI by removing F811 from CI/check pkg
            body = replace_vars_with_imports(body, import_path)
            if fname not in ("__main__.py",):
                body = prune_func_calls(body)
            body_len = -1
            # in case of several in-depth statements
            while body_len != len(body):
                body_len = len(body)
                body = prune_duplicate_lines(body)
                body = prune_empty_statements(body)
            # add a try/except wrapper for the whole body,
            #  so when an import fails it tells you for which package version this meta package was generated...
            body = wrap_try_except(body, pkg_name, pkg_ver)

        # todo: apply pre-commit formatting
        # clean too many empty lines
        body = [ln for ln, _group in groupby(body)]
        # drop duplicated lines
        body = prune_duplicate_lines(body)
        # compose the target file name
        new_file = os.path.join(src_folder, "lightning", lit_name, local_path)
        os.makedirs(os.path.dirname(new_file), exist_ok=True)
        with open(new_file, "w", encoding="utf-8") as fp:
            fp.writelines([ln + os.linesep for ln in body])


def set_version_today(fpath: str) -> None:
    """Replace the template date with today."""
    with open(fpath) as fp:
        lines = fp.readlines()

    def _replace_today(ln):
        today = datetime.now()
        return ln.replace("YYYY.-M.-D", f"{today.year}.{today.month}.{today.day}")

    lines = list(map(_replace_today, lines))
    with open(fpath, "w") as fp:
        fp.writelines(lines)
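

# For example (illustrative date only): a template line 'version = "YYYY.-M.-D"' written on
# 2022-08-10 becomes 'version = "2022.8.10"'.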


def _download_frontend(root: str = _PROJECT_ROOT):
    """Download an archive file for a specific release of the Lightning frontend and extract it to the correct
    directory."""
    try:
        frontend_dir = pathlib.Path(root, "src", "lightning_app", "ui")
        download_dir = tempfile.mkdtemp()

        shutil.rmtree(frontend_dir, ignore_errors=True)

        response = urllib.request.urlopen(_LIGHTNING_FRONTEND_RELEASE_URL)
        file = tarfile.open(fileobj=response, mode="r|gz")
        file.extractall(path=download_dir)

        shutil.move(os.path.join(download_dir, "build"), frontend_dir)
        print("The Lightning UI has successfully been downloaded!")

    # If installing from source without an internet connection, we don't want to break the installation
    except Exception:
        print("Downloading the Lightning UI has failed!")


def _relax_require_versions(source_dir: str = "src", req_dir: str = "requirements") -> None:
    """Parse the base requirements and append a version adjustment if needed, e.g. `pkg>=X1.Y1.Z1, ==X2.Y2.*`.

    >>> _relax_require_versions("../src", "../requirements")
    """
    reqs = load_requirements(req_dir, file_name="base.txt")
    for i, req in enumerate(parse_requirements(reqs)):
        ver_ = parse_version_from_file(os.path.join(source_dir, req.name))
        if not ver_:
            continue
        ver2 = ".".join(ver_.split(".")[:2] + ["*"])
        reqs[i] = f"{req}, =={ver2}"

    with open(os.path.join(req_dir, "base.txt"), "w") as fp:
        fp.writelines([ln + os.linesep for ln in reqs])
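

# A hand-traced sketch (hypothetical version numbers): if base.txt contains `pytorch_lightning>=1.6.0`
# and `src/pytorch_lightning/__version__.py` carries version = "1.7.0", the line is rewritten as
#   pytorch_lightning>=1.6.0, ==1.7.*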


def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requirements: bool = False) -> None:
    """Load all base requirements from all particular packages and prune duplicates."""
    requires = [
        load_requirements(d, file_name="base.txt", unfreeze="" if freeze_requirements else "all")
        for d in glob.glob(os.path.join(req_dir, "*"))
        if os.path.isdir(d)
    ]
    if not requires:
        return None
    # TODO: add some smarter version aggregation per each package
    requires = list(chain(*requires))
    with open(os.path.join(req_dir, "base.txt"), "w") as fp:
        fp.writelines([ln + os.linesep for ln in requires])