mirror of https://github.com/explosion/spaCy.git
Output more stats in evaluate
This commit is contained in:
parent
90b7fa8fed
commit
dbfa292ed3
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Dict
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
from wasabi import Printer
|
from wasabi import Printer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -89,8 +89,20 @@ def evaluate(
|
||||||
"Sent R": f"{scorer.sent_r:.2f}",
|
"Sent R": f"{scorer.sent_r:.2f}",
|
||||||
"Sent F": f"{scorer.sent_f:.2f}",
|
"Sent F": f"{scorer.sent_f:.2f}",
|
||||||
}
|
}
|
||||||
|
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
|
||||||
|
|
||||||
msg.table(results, title="Results")
|
msg.table(results, title="Results")
|
||||||
|
|
||||||
|
if scorer.ents_per_type:
|
||||||
|
data["ents_per_type"] = scorer.ents_per_type
|
||||||
|
print_ents_per_type(msg, scorer.ents_per_type)
|
||||||
|
if scorer.textcats_f_per_cat:
|
||||||
|
data["textcats_f_per_cat"] = scorer.textcats_f_per_cat
|
||||||
|
print_textcats_f_per_cat(msg, scorer.textcats_f_per_cat)
|
||||||
|
if scorer.textcats_auc_per_cat:
|
||||||
|
data["textcats_auc_per_cat"] = scorer.textcats_auc_per_cat
|
||||||
|
print_textcats_auc_per_cat(msg, scorer.textcats_auc_per_cat)
|
||||||
|
|
||||||
if displacy_path:
|
if displacy_path:
|
||||||
docs = [ex.predicted for ex in dev_dataset]
|
docs = [ex.predicted for ex in dev_dataset]
|
||||||
render_deps = "parser" in nlp.meta.get("pipeline", [])
|
render_deps = "parser" in nlp.meta.get("pipeline", [])
|
||||||
|
@ -105,7 +117,6 @@ def evaluate(
|
||||||
)
|
)
|
||||||
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
|
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
|
||||||
|
|
||||||
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
|
|
||||||
if output_path is not None:
|
if output_path is not None:
|
||||||
srsly.write_json(output_path, data)
|
srsly.write_json(output_path, data)
|
||||||
msg.good(f"Saved results to {output_path}")
|
msg.good(f"Saved results to {output_path}")
|
||||||
|
@ -131,3 +142,40 @@ def render_parses(
|
||||||
)
|
)
|
||||||
with (output_path / "parses.html").open("w", encoding="utf8") as file_:
|
with (output_path / "parses.html").open("w", encoding="utf8") as file_:
|
||||||
file_.write(html)
|
file_.write(html)
|
||||||
|
|
||||||
|
|
||||||
|
def print_ents_per_type(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None:
|
||||||
|
data = [
|
||||||
|
(k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}")
|
||||||
|
for k, v in scores.items()
|
||||||
|
]
|
||||||
|
msg.table(
|
||||||
|
data,
|
||||||
|
header=("", "P", "R", "F"),
|
||||||
|
aligns=("l", "r", "r", "r"),
|
||||||
|
title="NER (per type)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def print_textcats_f_per_cat(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None:
|
||||||
|
data = [
|
||||||
|
(k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}")
|
||||||
|
for k, v in scores.items()
|
||||||
|
]
|
||||||
|
msg.table(
|
||||||
|
data,
|
||||||
|
header=("", "P", "R", "F"),
|
||||||
|
aligns=("l", "r", "r", "r"),
|
||||||
|
title="Textcat F (per type)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def print_textcats_auc_per_cat(
|
||||||
|
msg: Printer, scores: Dict[str, Dict[str, float]]
|
||||||
|
) -> None:
|
||||||
|
msg.table(
|
||||||
|
[(k, f"{v['roc_auc_score']:.2f}") for k, v in scores.items()],
|
||||||
|
header=("", "ROC AUC"),
|
||||||
|
aligns=("l", "r"),
|
||||||
|
title="Textcat ROC AUC (per label)",
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue