From dd84577a98614260340f374813963142650914e6 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Tue, 25 Aug 2020 11:54:53 +0200 Subject: [PATCH] Update CLI utils, project.yml schema and add test --- spacy/cli/_util.py | 6 +++--- spacy/schemas.py | 19 +++++++++++++++++-- spacy/tests/test_cli.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index f1ab2effc..4ed316219 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -195,7 +195,7 @@ def get_checksum(path: Union[Path, str]) -> str: for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()): dir_checksum.update(sub_file.read_bytes()) return dir_checksum.hexdigest() - raise ValueError(f"Can't get checksum for {path}: not a file or directory") + msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1) @contextmanager @@ -320,9 +320,9 @@ def git_sparse_checkout( repo: str, subpath: str, dest: Path, *, branch: Optional[str] = None ): if dest.exists(): - raise IOError("Destination of checkout must not exist") + msg.fail("Destination of checkout must not exist", exits=1) if not dest.parent.exists(): - raise IOError("Parent of destination of checkout must exist") + msg.fail("Parent of destination of checkout must exist", exits=1) # We're using Git and sparse checkout to only clone the files we need with make_tempdir() as tmp_dir: cmd = ( diff --git a/spacy/schemas.py b/spacy/schemas.py index 15282e98e..069bd7c0a 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -283,7 +283,15 @@ class ConfigSchema(BaseModel): # Project config Schema -class ProjectConfigAsset(BaseModel): +class ProjectConfigAssetGitItem(BaseModel): + # fmt: off + repo: StrictStr = Field(..., title="URL of Git repo to download from") + path: StrictStr = Field(..., title="File path or sub-directory to download (used for sparse checkout)") + branch: StrictStr = Field("master", title="Branch to clone from") + # fmt: on + + +class ProjectConfigAssetURL(BaseModel): # fmt: off dest: StrictStr = Field(..., title="Destination of downloaded asset") url: Optional[StrictStr] = Field(None, title="URL of asset") @@ -291,6 +299,13 @@ class ProjectConfigAsset(BaseModel): # fmt: on +class ProjectConfigAssetGit(BaseModel): + # fmt: off + git: ProjectConfigAssetGitItem = Field(..., title="Git repo information") + checksum: str = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})") + # fmt: on + + class ProjectConfigCommand(BaseModel): # fmt: off name: StrictStr = Field(..., title="Name of command") @@ -310,7 +325,7 @@ class ProjectConfigCommand(BaseModel): class ProjectConfigSchema(BaseModel): # fmt: off vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands") - assets: List[ProjectConfigAsset] = Field([], title="Data assets") + assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets") workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order") commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts") # fmt: on diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 104c7c516..448aaf202 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -270,6 +270,41 @@ def test_pretrain_make_docs(): assert skip_count == 0 +def test_project_config_validation_full(): + config = { + "vars": {"some_var": 20}, + "directories": ["assets", "configs", "corpus", "scripts", "training"], + "assets": [ + { + "dest": "x", + "url": "https://example.com", + "checksum": "63373dd656daa1fd3043ce166a59474c", + }, + { + "dest": "y", + "git": { + "repo": "https://github.com/example/repo", + "branch": "develop", + "path": "y", + }, + }, + ], + "commands": [ + { + "name": "train", + "help": "Train a model", + "script": ["python -m spacy train config.cfg -o training"], + "deps": ["config.cfg", "corpus/training.spcy"], + "outputs": ["training/model-best"], + }, + {"name": "test", "script": ["pytest", "custom.py"], "no_skip": True}, + ], + "workflows": {"all": ["train", "test"], "train": ["train"]}, + } + errors = validate(ProjectConfigSchema, config) + assert not errors + + @pytest.mark.parametrize( "config", [