From dbfa292ed3cce441462f94787acac9e13fcb572d Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sun, 28 Jun 2020 15:34:28 +0200 Subject: [PATCH] Output more stats in evaluate --- spacy/cli/evaluate.py | 52 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py index 67123ecf1..a9ddfe9be 100644 --- a/spacy/cli/evaluate.py +++ b/spacy/cli/evaluate.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Dict from timeit import default_timer as timer from wasabi import Printer from pathlib import Path @@ -89,8 +89,20 @@ def evaluate( "Sent R": f"{scorer.sent_r:.2f}", "Sent F": f"{scorer.sent_f:.2f}", } + data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()} + msg.table(results, title="Results") + if scorer.ents_per_type: + data["ents_per_type"] = scorer.ents_per_type + print_ents_per_type(msg, scorer.ents_per_type) + if scorer.textcats_f_per_cat: + data["textcats_f_per_cat"] = scorer.textcats_f_per_cat + print_textcats_f_per_cat(msg, scorer.textcats_f_per_cat) + if scorer.textcats_auc_per_cat: + data["textcats_auc_per_cat"] = scorer.textcats_auc_per_cat + print_textcats_auc_per_cat(msg, scorer.textcats_auc_per_cat) + if displacy_path: docs = [ex.predicted for ex in dev_dataset] render_deps = "parser" in nlp.meta.get("pipeline", []) @@ -105,7 +117,6 @@ def evaluate( ) msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path) - data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()} if output_path is not None: srsly.write_json(output_path, data) msg.good(f"Saved results to {output_path}") @@ -131,3 +142,40 @@ def render_parses( ) with (output_path / "parses.html").open("w", encoding="utf8") as file_: file_.write(html) + + +def print_ents_per_type(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None: + data = [ + (k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}") + for k, v in scores.items() + ] + msg.table( + data, + header=("", "P", "R", "F"), + aligns=("l", "r", "r", "r"), + title="NER (per type)", + ) + + +def print_textcats_f_per_cat(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None: + data = [ + (k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}") + for k, v in scores.items() + ] + msg.table( + data, + header=("", "P", "R", "F"), + aligns=("l", "r", "r", "r"), + title="Textcat F (per type)", + ) + + +def print_textcats_auc_per_cat( + msg: Printer, scores: Dict[str, Dict[str, float]] +) -> None: + msg.table( + [(k, f"{v['roc_auc_score']:.2f}") for k, v in scores.items()], + header=("", "ROC AUC"), + aligns=("l", "r"), + title="Textcat ROC AUC (per label)", + )