spaCy/spacy/scorer.py

171 lines
5.6 KiB
Python
Raw Normal View History

# coding: utf8
from __future__ import division, print_function, unicode_literals
2015-03-11 01:07:03 +00:00
from .gold import tags_to_entities, GoldParse
2015-05-27 01:18:16 +00:00
2015-04-05 20:29:30 +00:00
class PRFScore(object):
    """Running counts of true positives, false positives and false negatives,
    from which precision, recall and F-score are derived.

    A tiny epsilon (1e-100) is added to every denominator so that an empty
    score reports 0.0 instead of raising ZeroDivisionError.
    """

    def __init__(self):
        # All three counters start at zero and only ever grow via score_set
        # or direct increments by the caller.
        self.tp = self.fp = self.fn = 0

    def score_set(self, cand, gold):
        """Accumulate counts from one candidate/gold comparison.

        cand (set): Predicted items.
        gold (set): Correct items.
        """
        matched = cand.intersection(gold)
        self.tp += len(matched)
        spurious = cand - gold
        self.fp += len(spurious)
        missed = gold - cand
        self.fn += len(missed)

    @property
    def precision(self):
        """RETURNS (float): tp / (tp + fp); 0.0 when nothing was predicted."""
        predicted = self.tp + self.fp + 1e-100
        return self.tp / predicted

    @property
    def recall(self):
        """RETURNS (float): tp / (tp + fn); 0.0 when nothing was expected."""
        expected = self.tp + self.fn + 1e-100
        return self.tp / expected

    @property
    def fscore(self):
        """RETURNS (float): Harmonic mean of precision and recall."""
        prec, rec = self.precision, self.recall
        return 2 * ((prec * rec) / (prec + rec + 1e-100))
2015-03-11 01:07:03 +00:00
class Scorer(object):
    """Compute evaluation scores."""

    def __init__(self, eval_punct=False):
        """Initialize the Scorer.

        eval_punct (bool): Evaluate the dependency attachments to and from
            punctuation.
        RETURNS (Scorer): The newly created object.

        DOCS: https://spacy.io/api/scorer#init
        """
        self.tokens = PRFScore()
        # NOTE: not updated anywhere in this class's `score` method.
        self.sbd = PRFScore()
        self.unlabelled = PRFScore()
        self.labelled = PRFScore()
        self.tags = PRFScore()
        self.ner = PRFScore()
        self.eval_punct = eval_punct

    @property
    def tags_acc(self):
        """RETURNS (float): Part-of-speech tag accuracy (fine grained tags,
            i.e. `Token.tag`).
        """
        return self.tags.fscore * 100

    @property
    def token_acc(self):
        """RETURNS (float): Tokenization accuracy."""
        # `score` only ever increments tokens.tp / tokens.fp, so precision is
        # the only meaningful figure for tokenization.
        return self.tokens.precision * 100

    @property
    def uas(self):
        """RETURNS (float): Unlabelled dependency score."""
        return self.unlabelled.fscore * 100

    @property
    def las(self):
        """RETURNS (float): Labelled dependency score."""
        return self.labelled.fscore * 100

    @property
    def ents_p(self):
        """RETURNS (float): Named entity accuracy (precision)."""
        return self.ner.precision * 100

    @property
    def ents_r(self):
        """RETURNS (float): Named entity accuracy (recall)."""
        return self.ner.recall * 100

    @property
    def ents_f(self):
        """RETURNS (float): Named entity accuracy (F-score)."""
        return self.ner.fscore * 100

    @property
    def scores(self):
        """RETURNS (dict): All scores with keys `uas`, `las`, `ents_p`,
            `ents_r`, `ents_f`, `tags_acc` and `token_acc`.
        """
        return {
            "uas": self.uas,
            "las": self.las,
            "ents_p": self.ents_p,
            "ents_r": self.ents_r,
            "ents_f": self.ents_f,
            "tags_acc": self.tags_acc,
            "token_acc": self.token_acc,
        }

    def score(self, doc, gold, verbose=False, punct_labels=("p", "punct")):
        """Update the evaluation scores from a single Doc / GoldParse pair.

        doc (Doc): The predicted annotations.
        gold (GoldParse): The correct annotations.
        verbose (bool): Print debugging information.
        punct_labels (tuple): Dependency labels for punctuation. Used to
            evaluate dependency attachments to punctuation if `eval_punct` is
            `True`. Labels are compared in lowercase.

        DOCS: https://spacy.io/api/scorer#score
        """
        if len(doc) != len(gold):
            # Token counts differ: rebuild the GoldParse against `doc` from
            # its original annotation tuples.
            gold = GoldParse.from_annot_tuples(doc, zip(*gold.orig_annot))
        # Dependency labels excluded from attachment scoring. Previously
        # `eval_punct` was ignored and punctuation was always excluded.
        ignored_labels = () if self.eval_punct else punct_labels
        # The last field of each orig_annot row is the gold NER tag.
        gold_ner_tags = [annot[-1] for annot in gold.orig_annot]
        gold_deps = set()
        gold_tags = set()
        gold_ents = set(tags_to_entities(gold_ner_tags))
        for id_, word, tag, head, dep, ner in gold.orig_annot:
            gold_tags.add((id_, tag))
            if dep not in (None, "") and dep.lower() not in ignored_labels:
                gold_deps.add((id_, head, dep.lower()))
        cand_deps = set()
        cand_tags = set()
        for token in doc:
            if token.orth_.isspace():
                continue
            # Map the predicted token onto its gold index; None means the
            # token has no gold counterpart under the alignment.
            gold_i = gold.cand_to_gold[token.i]
            if gold_i is None:
                self.tokens.fp += 1
            else:
                self.tokens.tp += 1
                cand_tags.add((gold_i, token.tag_))
            if token.dep_.lower() not in ignored_labels and token.orth_.strip():
                gold_head = gold.cand_to_gold[token.head.i]
                # None is indistinct, so we can't just add it to the set
                # Multiple (None, None) deps are possible
                if gold_i is None or gold_head is None:
                    self.unlabelled.fp += 1
                    self.labelled.fp += 1
                else:
                    cand_deps.add((gold_i, gold_head, token.dep_.lower()))
        # Entities are only scored when no gold NER tag is "-" (presumably the
        # missing-annotation marker; confirm against spacy.gold).
        if "-" not in gold_ner_tags:
            cand_ents = set()
            for ent in doc.ents:
                first = gold.cand_to_gold[ent.start]
                last = gold.cand_to_gold[ent.end - 1]
                if first is None or last is None:
                    self.ner.fp += 1
                else:
                    cand_ents.add((ent.label_, first, last))
            self.ner.score_set(cand_ents, gold_ents)
        self.tags.score_set(cand_tags, gold_tags)
        self.labelled.score_set(cand_deps, gold_deps)
        # Unlabelled attachment: compare (child, head) pairs only.
        self.unlabelled.score_set(
            set(item[:2] for item in cand_deps), set(item[:2] for item in gold_deps)
        )
        if verbose:
            gold_words = [item[1] for item in gold.orig_annot]
            # "F": predicted but not in gold; "M": in gold but missed.
            for w_id, h_id, dep in cand_deps - gold_deps:
                print("F", gold_words[w_id], dep, gold_words[h_id])
            for w_id, h_id, dep in gold_deps - cand_deps:
                print("M", gold_words[w_id], dep, gold_words[h_id])