mirror of https://github.com/explosion/spaCy.git
218 lines
8.0 KiB
Python
218 lines
8.0 KiB
Python
import os
|
|
import re
|
|
import shutil
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional
|
|
|
|
import requests
|
|
import typer
|
|
from wasabi import msg
|
|
|
|
from ...util import ensure_path, working_dir
|
|
from .._util import (
|
|
PROJECT_FILE,
|
|
Arg,
|
|
Opt,
|
|
SimpleFrozenDict,
|
|
download_file,
|
|
get_checksum,
|
|
get_git_version,
|
|
git_checkout,
|
|
load_project_config,
|
|
parse_config_overrides,
|
|
project_cli,
|
|
)
|
|
|
|
# Whether assets are extra if `extra` is not set.
|
|
EXTRA_DEFAULT = False
|
|
|
|
|
|
@project_cli.command(
|
|
"assets",
|
|
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
|
)
|
|
def project_assets_cli(
|
|
# fmt: off
|
|
ctx: typer.Context, # This is only used to read additional arguments
|
|
project_dir: Path = Arg(Path.cwd(), help="Path to cloned project. Defaults to current working directory.", exists=True, file_okay=False),
|
|
sparse_checkout: bool = Opt(False, "--sparse", "-S", help="Use sparse checkout for assets provided via Git, to only check out and clone the files needed. Requires Git v22.2+."),
|
|
extra: bool = Opt(False, "--extra", "-e", help="Download all assets, including those marked as 'extra'.")
|
|
# fmt: on
|
|
):
|
|
"""Fetch project assets like datasets and pretrained weights. Assets are
|
|
defined in the "assets" section of the project.yml. If a checksum is
|
|
provided in the project.yml, the file is only downloaded if no local file
|
|
with the same checksum exists.
|
|
|
|
DOCS: https://spacy.io/api/cli#project-assets
|
|
"""
|
|
overrides = parse_config_overrides(ctx.args)
|
|
project_assets(
|
|
project_dir,
|
|
overrides=overrides,
|
|
sparse_checkout=sparse_checkout,
|
|
extra=extra,
|
|
)
|
|
|
|
|
|
def project_assets(
|
|
project_dir: Path,
|
|
*,
|
|
overrides: Dict[str, Any] = SimpleFrozenDict(),
|
|
sparse_checkout: bool = False,
|
|
extra: bool = False,
|
|
) -> None:
|
|
"""Fetch assets for a project using DVC if possible.
|
|
|
|
project_dir (Path): Path to project directory.
|
|
sparse_checkout (bool): Use sparse checkout for assets provided via Git, to only check out and clone the files
|
|
needed.
|
|
extra (bool): Whether to download all assets, including those marked as 'extra'.
|
|
"""
|
|
project_path = ensure_path(project_dir)
|
|
config = load_project_config(project_path, overrides=overrides)
|
|
assets = [
|
|
asset
|
|
for asset in config.get("assets", [])
|
|
if extra or not asset.get("extra", EXTRA_DEFAULT)
|
|
]
|
|
if not assets:
|
|
msg.warn(
|
|
f"No assets specified in {PROJECT_FILE} (if assets are marked as extra, download them with --extra)",
|
|
exits=0,
|
|
)
|
|
msg.info(f"Fetching {len(assets)} asset(s)")
|
|
|
|
for asset in assets:
|
|
dest = (project_dir / asset["dest"]).resolve()
|
|
checksum = asset.get("checksum")
|
|
if "git" in asset:
|
|
git_err = (
|
|
f"Cloning spaCy project templates requires Git and the 'git' command. "
|
|
f"Make sure it's installed and that the executable is available."
|
|
)
|
|
get_git_version(error=git_err)
|
|
if dest.exists():
|
|
# If there's already a file, check for checksum
|
|
if checksum and checksum == get_checksum(dest):
|
|
msg.good(
|
|
f"Skipping download with matching checksum: {asset['dest']}"
|
|
)
|
|
continue
|
|
else:
|
|
if dest.is_dir():
|
|
shutil.rmtree(dest)
|
|
else:
|
|
dest.unlink()
|
|
if "repo" not in asset["git"] or asset["git"]["repo"] is None:
|
|
msg.fail(
|
|
"A git asset must include 'repo', the repository address.", exits=1
|
|
)
|
|
if "path" not in asset["git"] or asset["git"]["path"] is None:
|
|
msg.fail(
|
|
"A git asset must include 'path' - use \"\" to get the entire repository.",
|
|
exits=1,
|
|
)
|
|
git_checkout(
|
|
asset["git"]["repo"],
|
|
asset["git"]["path"],
|
|
dest,
|
|
branch=asset["git"].get("branch"),
|
|
sparse=sparse_checkout,
|
|
)
|
|
msg.good(f"Downloaded asset {dest}")
|
|
else:
|
|
url = asset.get("url")
|
|
if not url:
|
|
# project.yml defines asset without URL that the user has to place
|
|
check_private_asset(dest, checksum)
|
|
continue
|
|
fetch_asset(project_path, url, dest, checksum)
|
|
|
|
|
|
def check_private_asset(dest: Path, checksum: Optional[str] = None) -> None:
|
|
"""Check and validate assets without a URL (private assets that the user
|
|
has to provide themselves) and give feedback about the checksum.
|
|
|
|
dest (Path): Destination path of the asset.
|
|
checksum (Optional[str]): Optional checksum of the expected file.
|
|
"""
|
|
if not Path(dest).exists():
|
|
err = f"No URL provided for asset. You need to add this file yourself: {dest}"
|
|
msg.warn(err)
|
|
else:
|
|
if not checksum:
|
|
msg.good(f"Asset already exists: {dest}")
|
|
elif checksum == get_checksum(dest):
|
|
msg.good(f"Asset exists with matching checksum: {dest}")
|
|
else:
|
|
msg.fail(f"Asset available but with incorrect checksum: {dest}")
|
|
|
|
|
|
def fetch_asset(
|
|
project_path: Path, url: str, dest: Path, checksum: Optional[str] = None
|
|
) -> None:
|
|
"""Fetch an asset from a given URL or path. If a checksum is provided and a
|
|
local file exists, it's only re-downloaded if the checksum doesn't match.
|
|
|
|
project_path (Path): Path to project directory.
|
|
url (str): URL or path to asset.
|
|
checksum (Optional[str]): Optional expected checksum of local file.
|
|
RETURNS (Optional[Path]): The path to the fetched asset or None if fetching
|
|
the asset failed.
|
|
"""
|
|
dest_path = (project_path / dest).resolve()
|
|
if dest_path.exists():
|
|
# If there's already a file, check for checksum
|
|
if checksum:
|
|
if checksum == get_checksum(dest_path):
|
|
msg.good(f"Skipping download with matching checksum: {dest}")
|
|
return
|
|
else:
|
|
# If there's not a checksum, make sure the file is a possibly valid size
|
|
if os.path.getsize(dest_path) == 0:
|
|
msg.warn(f"Asset exists but with size of 0 bytes, deleting: {dest}")
|
|
os.remove(dest_path)
|
|
# We might as well support the user here and create parent directories in
|
|
# case the asset dir isn't listed as a dir to create in the project.yml
|
|
if not dest_path.parent.exists():
|
|
dest_path.parent.mkdir(parents=True)
|
|
with working_dir(project_path):
|
|
url = convert_asset_url(url)
|
|
try:
|
|
download_file(url, dest_path)
|
|
msg.good(f"Downloaded asset {dest}")
|
|
except requests.exceptions.RequestException as e:
|
|
if Path(url).exists() and Path(url).is_file():
|
|
# If it's a local file, copy to destination
|
|
shutil.copy(url, str(dest_path))
|
|
msg.good(f"Copied local asset {dest}")
|
|
else:
|
|
msg.fail(f"Download failed: {dest}", e)
|
|
if checksum and checksum != get_checksum(dest_path):
|
|
msg.fail(f"Checksum doesn't match value defined in {PROJECT_FILE}: {dest}")
|
|
|
|
|
|
def convert_asset_url(url: str) -> str:
|
|
"""Check and convert the asset URL if needed.
|
|
|
|
url (str): The asset URL.
|
|
RETURNS (str): The converted URL.
|
|
"""
|
|
# If the asset URL is a regular GitHub URL it's likely a mistake
|
|
if (
|
|
re.match(r"(http(s?)):\/\/github.com", url)
|
|
and "releases/download" not in url
|
|
and "/raw/" not in url
|
|
):
|
|
converted = url.replace("github.com", "raw.githubusercontent.com")
|
|
converted = re.sub(r"/(tree|blob)/", "/", converted)
|
|
msg.warn(
|
|
"Downloading from a regular GitHub URL. This will only download "
|
|
"the source of the page, not the actual file. Converting the URL "
|
|
"to a raw URL.",
|
|
converted,
|
|
)
|
|
return converted
|
|
return url
|