mirror of https://github.com/explosion/spaCy.git
Improve output handling in evaluate
This commit is contained in:
parent
df22d490b1
commit
42eb381ec6
|
@@ -2,6 +2,8 @@ from typing import Optional, List
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
from wasabi import Printer
|
from wasabi import Printer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
import srsly
|
||||||
|
|
||||||
from ..gold import Corpus
|
from ..gold import Corpus
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
|
@@ -16,12 +18,11 @@ def evaluate_cli(
|
||||||
# fmt: off
|
# fmt: off
|
||||||
model: str = Arg(..., help="Model name or path"),
|
model: str = Arg(..., help="Model name or path"),
|
||||||
data_path: Path = Arg(..., help="Location of JSON-formatted evaluation data", exists=True),
|
data_path: Path = Arg(..., help="Location of JSON-formatted evaluation data", exists=True),
|
||||||
|
output: Optional[Path] = Opt(None, "--output", "-o", help="Output JSON file for metrics", dir_okay=False),
|
||||||
gpu_id: int = Opt(-1, "--gpu-id", "-g", help="Use GPU"),
|
gpu_id: int = Opt(-1, "--gpu-id", "-g", help="Use GPU"),
|
||||||
gold_preproc: bool = Opt(False, "--gold-preproc", "-G", help="Use gold preprocessing"),
|
gold_preproc: bool = Opt(False, "--gold-preproc", "-G", help="Use gold preprocessing"),
|
||||||
displacy_path: Optional[Path] = Opt(None, "--displacy-path", "-dp", help="Directory to output rendered parses as HTML", exists=True, file_okay=False),
|
displacy_path: Optional[Path] = Opt(None, "--displacy-path", "-dp", help="Directory to output rendered parses as HTML", exists=True, file_okay=False),
|
||||||
displacy_limit: int = Opt(25, "--displacy-limit", "-dl", help="Limit of parses to render as HTML"),
|
displacy_limit: int = Opt(25, "--displacy-limit", "-dl", help="Limit of parses to render as HTML"),
|
||||||
return_scores: bool = Opt(False, "--return-scores", "-R", help="Return dict containing model scores"),
|
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@@ -31,24 +32,24 @@ def evaluate_cli(
|
||||||
evaluate(
|
evaluate(
|
||||||
model,
|
model,
|
||||||
data_path,
|
data_path,
|
||||||
|
output=output,
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
gold_preproc=gold_preproc,
|
gold_preproc=gold_preproc,
|
||||||
displacy_path=displacy_path,
|
displacy_path=displacy_path,
|
||||||
displacy_limit=displacy_limit,
|
displacy_limit=displacy_limit,
|
||||||
silent=False,
|
silent=False,
|
||||||
return_scores=return_scores,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
model: str,
|
model: str,
|
||||||
data_path: Path,
|
data_path: Path,
|
||||||
|
output: Optional[Path],
|
||||||
gpu_id: int = -1,
|
gpu_id: int = -1,
|
||||||
gold_preproc: bool = False,
|
gold_preproc: bool = False,
|
||||||
displacy_path: Optional[Path] = None,
|
displacy_path: Optional[Path] = None,
|
||||||
displacy_limit: int = 25,
|
displacy_limit: int = 25,
|
||||||
silent: bool = True,
|
silent: bool = True,
|
||||||
return_scores: bool = False,
|
|
||||||
) -> Scorer:
|
) -> Scorer:
|
||||||
msg = Printer(no_print=silent, pretty=not silent)
|
msg = Printer(no_print=silent, pretty=not silent)
|
||||||
util.fix_random_seed()
|
util.fix_random_seed()
|
||||||
|
@@ -56,6 +57,7 @@ def evaluate(
|
||||||
util.use_gpu(gpu_id)
|
util.use_gpu(gpu_id)
|
||||||
util.set_env_log(False)
|
util.set_env_log(False)
|
||||||
data_path = util.ensure_path(data_path)
|
data_path = util.ensure_path(data_path)
|
||||||
|
output_path = util.ensure_path(output)
|
||||||
displacy_path = util.ensure_path(displacy_path)
|
displacy_path = util.ensure_path(displacy_path)
|
||||||
if not data_path.exists():
|
if not data_path.exists():
|
||||||
msg.fail("Evaluation data not found", data_path, exits=1)
|
msg.fail("Evaluation data not found", data_path, exits=1)
|
||||||
|
@@ -105,8 +107,11 @@ def evaluate(
|
||||||
ents=render_ents,
|
ents=render_ents,
|
||||||
)
|
)
|
||||||
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
|
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
|
||||||
if return_scores:
|
|
||||||
return scorer.scores
|
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
|
||||||
|
if output_path is not None:
|
||||||
|
srsly.write_json(output_path, data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def render_parses(
|
def render_parses(
|
||||||
|
|
Loading…
Reference in New Issue