Add init labels command

This commit is contained in:
Matthew Honnibal 2020-09-29 16:22:37 +02:00
parent 58c8d4b414
commit 45daf5c9fe
2 changed files with 44 additions and 0 deletions

View File

@ -16,6 +16,7 @@ from .debug_model import debug_model # noqa: F401
from .evaluate import evaluate # noqa: F401 from .evaluate import evaluate # noqa: F401
from .convert import convert # noqa: F401 from .convert import convert # noqa: F401
from .init_pipeline import init_pipeline_cli # noqa: F401 from .init_pipeline import init_pipeline_cli # noqa: F401
from .init_labels import init_labels_cli # noqa: F401
from .init_config import init_config, fill_config # noqa: F401 from .init_config import init_config, fill_config # noqa: F401
from .validate import validate # noqa: F401 from .validate import validate # noqa: F401
from .project.clone import project_clone # noqa: F401 from .project.clone import project_clone # noqa: F401

43
spacy/cli/init_labels.py Normal file
View File

@ -0,0 +1,43 @@
from typing import Optional
import logging
from pathlib import Path
from wasabi import msg
import typer
import srsly
from .. import util
from ..training.initialize import init_nlp, convert_vectors
from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
from ._util import import_code, setup_gpu
@init_cli.command(
"labels",
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
def init_labels_cli(
# fmt: off
ctx: typer.Context, # This is only used to read additional arguments
config_path: Path = Arg(..., help="Path to config file", exists=True),
output_path: Path = Arg(..., help="Output directory for the labels"),
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
use_gpu: int = Opt(-1, "--gpu-id", "-g", help="GPU ID or -1 for CPU")
# fmt: on
):
if not output_path.exists():
output_path.mkdir()
util.logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
overrides = parse_config_overrides(ctx.args)
import_code(code_path)
setup_gpu(use_gpu)
with show_validation_error(config_path):
config = util.load_config(config_path, overrides=overrides)
with show_validation_error(hint_fill=False):
nlp = init_nlp(config, use_gpu=use_gpu, silent=False)
for name, component in nlp.pipeline:
if getattr(component, "label_data", None) is not None:
srsly.write_json(output_path / f"{name}.json", component.label_data)
msg.good(f"Saving {name} labels to {output_path}/{name}.json")
else:
msg.info(f"No labels found for {name}")