import os
import re
from itertools import chain
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Optional, Sequence, Tuple

import pkg_resources

REQUIREMENT_FILES = {
    "pytorch": (
        "requirements/pytorch/base.txt",
        "requirements/pytorch/extra.txt",
        "requirements/pytorch/strategies.txt",
        "requirements/pytorch/examples.txt",
    ),
    "app": (
        "requirements/app/base.txt",
        "requirements/app/ui.txt",
        "requirements/app/cloud.txt",
    ),
    "lite": (
        "requirements/lite/base.txt",
        "requirements/lite/strategies.txt",
    ),
}
REQUIREMENT_FILES_ALL = list(chain(*REQUIREMENT_FILES.values()))


def _retrieve_files(directory: str, *ext: str) -> List[str]:
    """Collect all files under ``directory``, optionally keeping only the given file extensions."""
    all_files = []
    for root, _, files in os.walk(directory):
        for fname in files:
            if not ext or any(os.path.split(fname)[1].lower().endswith(e) for e in ext):
                all_files.append(os.path.join(root, fname))

    return all_files
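
# Illustrative usage (hypothetical paths): collect all Python files under "src/";
# with no extensions given, every file in the tree is returned:
#   _retrieve_files("src", ".py")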


def _replace_imports(lines: List[str], mapping: List[Tuple[str, str]]) -> List[str]:
    """Replace imports of the standalone packages with the unified ``lightning`` package.

    >>> lns = [
    ...     "lightning_app",
    ...     "delete_cloud_lightning_apps",
    ...     "from lightning_app import",
    ...     "lightning_apps = []",
    ...     "lightning_app and pytorch_lightning are ours",
    ...     "def _lightning_app():",
    ...     ":class:`~lightning_app.core.flow.LightningFlow`"
    ... ]
    >>> mapping = [("lightning_app", "lightning.app"), ("pytorch_lightning", "lightning.pytorch")]
    >>> _replace_imports(lns, mapping)  # doctest: +NORMALIZE_WHITESPACE
    ['lightning.app', 'delete_cloud_lightning_apps', 'from lightning.app import', 'lightning_apps = []',\
     'lightning.app and lightning.pytorch are ours', 'def _lightning_app():',\
     ':class:`~lightning.app.core.flow.LightningFlow`']
    """
    out = lines[:]
    for source_import, target_import in mapping:
        for i, ln in enumerate(out):
            # only replace occurrences that are not part of a longer identifier, e.g. keep `lightning_apps`
            out[i] = re.sub(rf"([^_]|^){source_import}([^_\w]|$)", rf"\1{target_import}\2", ln)
    return out


def copy_replace_imports(
    source_dir: str, source_imports: List[str], target_imports: List[str], target_dir: Optional[str] = None
) -> None:
    print(f"Replacing imports: {locals()}")
    assert len(source_imports) == len(target_imports), (
        "source and target imports must have the same length, "
        f"source: {len(source_imports)}, target: {len(target_imports)}"
    )
    if target_dir is None:
        target_dir = source_dir

    ls = _retrieve_files(source_dir)
    for fp in ls:
        # skip compiled byte-code; any other file is treated as text
        if not fp.endswith(".pyc"):
            with open(fp, encoding="utf-8") as fo:
                try:
                    lines = fo.readlines()
                except UnicodeDecodeError:
                    # a binary file, skip
                    print(f"Skipped replacing imports for {fp}")
                    continue
            lines = _replace_imports(lines, list(zip(source_imports, target_imports)))
            fp_new = fp.replace(source_dir, target_dir)
            os.makedirs(os.path.dirname(fp_new), exist_ok=True)
            with open(fp_new, "w", encoding="utf-8") as fo:
                fo.writelines(lines)
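
# A minimal usage sketch (hypothetical paths): rewrite `pytorch_lightning` imports found under
# "src/pytorch_lightning" into `lightning.pytorch`, writing the results to "src/lightning/pytorch":
#   copy_replace_imports(
#       source_dir="src/pytorch_lightning",
#       source_imports=["pytorch_lightning"],
#       target_imports=["lightning.pytorch"],
#       target_dir="src/lightning/pytorch",
#   )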


def create_mirror_package(source_dir: str, package_mapping: Dict[str, str]) -> None:
    # replace imports and copy the code
    mapping = package_mapping.copy()
    mapping.pop("lightning", None)  # pop this key to avoid replacing `lightning` with `lightning.lightning`
    for new, previous in mapping.items():
        copy_replace_imports(
            source_dir=os.path.join(source_dir, previous),
            # pytorch_lightning uses lightning_lite, so we need to replace all imports for all directories
            source_imports=list(mapping.values()),
            target_imports=[f"lightning.{pkg}" for pkg in mapping],
            target_dir=os.path.join(source_dir, "lightning", new),
        )
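
# A minimal usage sketch (hypothetical mapping): mirror each standalone package into the unified
# `lightning` namespace; a "lightning" key, if present, is dropped so `lightning` itself is not rewritten:
#   create_mirror_package(
#       "src",
#       {"pytorch": "pytorch_lightning", "app": "lightning_app", "lite": "lightning_lite"},
#   )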


class AssistantCLI:
    @staticmethod
    def requirements_prune_pkgs(packages: Sequence[str], req_files: Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
        """Remove some packages from the given requirement files."""
        if isinstance(req_files, str):
            req_files = [req_files]
        for req in req_files:
            AssistantCLI._prune_packages(req, packages)

    @staticmethod
    def _prune_packages(req_file: str, packages: Sequence[str]) -> None:
        """Remove some packages from the given requirement file."""
        path = Path(req_file)
        assert path.exists()
        text = path.read_text()
        lines = text.splitlines()
        final = []
        for line in lines:
            ln_ = line.strip()
            # keep blank lines and comments untouched
            if not ln_ or ln_.startswith("#"):
                final.append(line)
                continue
            req = list(pkg_resources.parse_requirements(ln_))[0]
            if req.name not in packages:
                final.append(line)
        pprint(final)
        path.write_text("\n".join(final))

    @staticmethod
    def _replace_min(fname: str) -> None:
        with open(fname, encoding="utf-8") as fo:
            req = fo.read().replace(">=", "==")
        with open(fname, "w", encoding="utf-8") as fw:
            fw.write(req)

    @staticmethod
    def replace_oldest_ver(requirement_fnames: Sequence[str] = REQUIREMENT_FILES_ALL) -> None:
        """Pin the minimum package versions as exact requirements."""
        for fname in requirement_fnames:
            AssistantCLI._replace_min(fname)

    @staticmethod
    def copy_replace_imports(
        source_dir: str, source_import: str, target_import: str, target_dir: Optional[str] = None
    ) -> None:
        """Recursively replace imports in the given folder.

        Source and target imports are passed as comma-separated strings of equal length.
        """
        source_imports = source_import.strip().split(",")
        target_imports = target_import.strip().split(",")
        copy_replace_imports(source_dir, source_imports, target_imports, target_dir=target_dir)


if __name__ == "__main__":
    import jsonargparse

    jsonargparse.CLI(AssistantCLI, as_positional=False)
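
# Illustrative invocations, assuming this file is saved as `assistant.py` (hypothetical name) and
# jsonargparse is installed; jsonargparse exposes the public static methods of `AssistantCLI` as
# subcommands with keyword flags derived from the parameter names:
#   python assistant.py replace_oldest_ver
#   python assistant.py requirements_prune_pkgs --packages="[horovod]" --req_files="[requirements/pytorch/strategies.txt]"
#   python assistant.py copy_replace_imports --source_dir=./src --source_import="pytorch_lightning" --target_import="lightning.pytorch"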