diff --git a/spacy/scorer.py b/spacy/scorer.py index 253c1bd1a..1d27375d2 100644 --- a/spacy/scorer.py +++ b/spacy/scorer.py @@ -1,78 +1,102 @@ from __future__ import division +class PRFScore(object): + """A precision / recall / F score""" + def __init__(self): + self.tp = 0 + self.fp = 0 + self.fn = 0 + + def score_set(self, cand, gold): + self.tp += len(cand.intersection(gold)) + self.fp += len(cand - gold) + self.fn += len(gold - cand) + + @property + def precision(self): + return self.tp / (self.tp + self.fp + 1e-100) + + @property + def recall(self): + return self.tp / (self.tp + self.fn + 1e-100) + + @property + def fscore(self): + p = self.precision + r = self.recall + return 2 * ((p * r) / (p + r + 1e-100)) + + class Scorer(object): def __init__(self, eval_punct=False): - self.heads_corr = 0 - self.labels_corr = 0 - self.tags_corr = 0 - self.ents_tp = 0 - self.ents_fp = 0 - self.ents_fn = 0 - self.total = 1e-100 - self.mistokened = 0 - self.n_tokens = 0 + self.tokens = PRFScore() + self.sbd = PRFScore() + self.unlabelled = PRFScore() + self.labelled = PRFScore() + self.tags = PRFScore() + self.ner = PRFScore() self.eval_punct = eval_punct @property def tags_acc(self): - return (self.tags_corr / (self.n_tokens - self.mistokened)) * 100 + return self.tags.fscore * 100 @property def token_acc(self): - return (self.mistokened / self.n_tokens) * 100 - + return self.tokens.fscore * 100 @property def uas(self): - return (self.heads_corr / self.total) * 100 + return self.unlabelled.fscore * 100 @property def las(self): - return (self.labels_corr / self.total) * 100 + return self.labelled.fscore * 100 @property def ents_p(self): - return (self.ents_tp / (self.ents_tp + self.ents_fp + 1e-100)) * 100 + return self.ner.precision @property def ents_r(self): - return (self.ents_tp / (self.ents_tp + self.ents_fn + 1e-100)) * 100 + return self.ner.recall @property def ents_f(self): - return (2 * self.ents_p * self.ents_r) / (self.ents_p + self.ents_r + 1e-100) + return self.ner.fscore def score(self, tokens, gold, verbose=False): assert len(tokens) == len(gold) - for i, token in enumerate(tokens): - if not self.skip_token(i, token, gold): - self.total += 1 - if verbose: - print token.orth_, token.tag_, token.dep_, token.head.orth_, token.head.i == gold.heads[i] - if token.head.i == gold.heads[i]: - self.heads_corr += 1 - self.labels_corr += token.dep_.lower() == gold.labels[i].lower() - if gold.tags[i] != None: - self.tags_corr += token.tag_ == gold.tags[i] - self.n_tokens += 1 - gold_ents = set((start, end, label) for (start, end, label) in gold.ents) - guess_ents = set((e.start, e.end, e.label_) for e in tokens.ents) - if verbose and gold_ents: - for start, end, label in guess_ents: - mark = 'T' if (start, end, label) in gold_ents else 'F' - ent_str = ' '.join(tokens[i].orth_ for i in range(start, end)) - print mark, label, ent_str - for start, end, label in gold_ents: - if (start, end, label) not in guess_ents: - ent_str = ' '.join(tokens[i].orth_ for i in range(start, end)) - print 'M', label, ent_str - print - if gold_ents: - self.ents_tp += len(gold_ents.intersection(guess_ents)) - self.ents_fn += len(gold_ents - guess_ents) - self.ents_fp += len(guess_ents - gold_ents) + gold_deps = set() + gold_tags = set() + gold_tags = set() + for id_, word, tag, head, dep, ner in gold.orig_annot: + if dep.lower() not in ('p', 'punct'): + gold_deps.add((id_, head, dep)) + gold_tags.add((id_, tag)) + cand_deps = set() + cand_tags = set() + for token in tokens: + if token.dep_ not in ('p', 'punct') and token.orth_.strip(): + gold_i = gold.cand_to_gold[token.i] + gold_head = gold.cand_to_gold[token.head.i] + # None is indistinct, so we can't just add it to the set + # Multiple (None, None) deps are possible + if gold_i is None or gold_head is None: + self.unlabelled.fp += 1 + self.labelled.fp += 1 + else: + cand_deps.add((gold_i, gold_head, token.dep_)) + if gold_i is None: + self.tags.fp += 1 + else: + cand_tags.add((gold_i, token.tag_)) - def skip_token(self, i, token, gold): - return gold.labels[i] in ('P', 'punct') or gold.heads[i] == None + self.tags.score_set(cand_tags, cand_deps) + self.labelled.score_set(cand_deps, gold_deps) + self.unlabelled.score_set( + set(item[:2] for item in cand_deps), + set(item[:2] for item in gold_deps), + )