diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py index a18e51623..fcc7fbf9b 100644 --- a/spacy/cli/evaluate.py +++ b/spacy/cli/evaluate.py @@ -2,6 +2,8 @@ from typing import Optional, List from timeit import default_timer as timer from wasabi import Printer from pathlib import Path +import re +import srsly from ..gold import Corpus from ..tokens import Doc @@ -16,13 +18,12 @@ def evaluate_cli( # fmt: off model: str = Arg(..., help="Model name or path"), data_path: Path = Arg(..., help="Location of JSON-formatted evaluation data", exists=True), + output: Optional[Path] = Opt(None, "--output", "-o", help="Output JSON file for metrics", dir_okay=False), gpu_id: int = Opt(-1, "--gpu-id", "-g", help="Use GPU"), gold_preproc: bool = Opt(False, "--gold-preproc", "-G", help="Use gold preprocessing"), displacy_path: Optional[Path] = Opt(None, "--displacy-path", "-dp", help="Directory to output rendered parses as HTML", exists=True, file_okay=False), displacy_limit: int = Opt(25, "--displacy-limit", "-dl", help="Limit of parses to render as HTML"), - return_scores: bool = Opt(False, "--return-scores", "-R", help="Return dict containing model scores"), - - # fmt: on + # fmt: on ): """ Evaluate a model. To render a sample of parses in a HTML file, set an @@ -31,24 +32,24 @@ def evaluate_cli( evaluate( model, data_path, + output=output, gpu_id=gpu_id, gold_preproc=gold_preproc, displacy_path=displacy_path, displacy_limit=displacy_limit, silent=False, - return_scores=return_scores, ) def evaluate( model: str, data_path: Path, + output: Optional[Path], gpu_id: int = -1, gold_preproc: bool = False, displacy_path: Optional[Path] = None, displacy_limit: int = 25, silent: bool = True, - return_scores: bool = False, ) -> Scorer: msg = Printer(no_print=silent, pretty=not silent) util.fix_random_seed() @@ -56,6 +57,7 @@ def evaluate( util.use_gpu(gpu_id) util.set_env_log(False) data_path = util.ensure_path(data_path) + output_path = util.ensure_path(output) displacy_path = util.ensure_path(displacy_path) if not data_path.exists(): msg.fail("Evaluation data not found", data_path, exits=1) @@ -105,8 +107,11 @@ def evaluate( ents=render_ents, ) msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path) - if return_scores: - return scorer.scores + + data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()} + if output_path is not None: + srsly.write_json(output_path, data) + return data def render_parses(