mirror of https://github.com/explosion/spaCy.git
Improve output handling in evaluate
This commit is contained in:
parent
df22d490b1
commit
42eb381ec6
|
@ -2,6 +2,8 @@ from typing import Optional, List
|
|||
from timeit import default_timer as timer
|
||||
from wasabi import Printer
|
||||
from pathlib import Path
|
||||
import re
|
||||
import srsly
|
||||
|
||||
from ..gold import Corpus
|
||||
from ..tokens import Doc
|
||||
|
@ -16,12 +18,11 @@ def evaluate_cli(
|
|||
# fmt: off
|
||||
model: str = Arg(..., help="Model name or path"),
|
||||
data_path: Path = Arg(..., help="Location of JSON-formatted evaluation data", exists=True),
|
||||
output: Optional[Path] = Opt(None, "--output", "-o", help="Output JSON file for metrics", dir_okay=False),
|
||||
gpu_id: int = Opt(-1, "--gpu-id", "-g", help="Use GPU"),
|
||||
gold_preproc: bool = Opt(False, "--gold-preproc", "-G", help="Use gold preprocessing"),
|
||||
displacy_path: Optional[Path] = Opt(None, "--displacy-path", "-dp", help="Directory to output rendered parses as HTML", exists=True, file_okay=False),
|
||||
displacy_limit: int = Opt(25, "--displacy-limit", "-dl", help="Limit of parses to render as HTML"),
|
||||
return_scores: bool = Opt(False, "--return-scores", "-R", help="Return dict containing model scores"),
|
||||
|
||||
# fmt: on
|
||||
):
|
||||
"""
|
||||
|
@ -31,24 +32,24 @@ def evaluate_cli(
|
|||
evaluate(
|
||||
model,
|
||||
data_path,
|
||||
output=output,
|
||||
gpu_id=gpu_id,
|
||||
gold_preproc=gold_preproc,
|
||||
displacy_path=displacy_path,
|
||||
displacy_limit=displacy_limit,
|
||||
silent=False,
|
||||
return_scores=return_scores,
|
||||
)
|
||||
|
||||
|
||||
def evaluate(
|
||||
model: str,
|
||||
data_path: Path,
|
||||
output: Optional[Path],
|
||||
gpu_id: int = -1,
|
||||
gold_preproc: bool = False,
|
||||
displacy_path: Optional[Path] = None,
|
||||
displacy_limit: int = 25,
|
||||
silent: bool = True,
|
||||
return_scores: bool = False,
|
||||
) -> Scorer:
|
||||
msg = Printer(no_print=silent, pretty=not silent)
|
||||
util.fix_random_seed()
|
||||
|
@ -56,6 +57,7 @@ def evaluate(
|
|||
util.use_gpu(gpu_id)
|
||||
util.set_env_log(False)
|
||||
data_path = util.ensure_path(data_path)
|
||||
output_path = util.ensure_path(output)
|
||||
displacy_path = util.ensure_path(displacy_path)
|
||||
if not data_path.exists():
|
||||
msg.fail("Evaluation data not found", data_path, exits=1)
|
||||
|
@ -105,8 +107,11 @@ def evaluate(
|
|||
ents=render_ents,
|
||||
)
|
||||
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
|
||||
if return_scores:
|
||||
return scorer.scores
|
||||
|
||||
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
|
||||
if output_path is not None:
|
||||
srsly.write_json(output_path, data)
|
||||
return data
|
||||
|
||||
|
||||
def render_parses(
|
||||
|
|
Loading…
Reference in New Issue