From 7ecb52c0edb1ec9f48e4de9ffc4bbed4ab5d21f2 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 10 Mar 2015 21:07:03 -0400 Subject: [PATCH] * Add scorer script --- spacy/scorer.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 spacy/scorer.py diff --git a/spacy/scorer.py b/spacy/scorer.py new file mode 100644 index 000000000..586976bad --- /dev/null +++ b/spacy/scorer.py @@ -0,0 +1,66 @@ +from __future__ import division + +class Scorer(object): + def __init__(self, eval_punct=False): + self.heads_corr = 0 + self.labels_corr = 0 + self.tags_corr = 0 + self.ents_tp = 0 + self.ents_fp = 0 + self.ents_fn = 0 + self.total = 1e-100 + self.eval_punct = eval_punct + + @property + def tags_acc(self): + return (self.tags_corr / self.total) * 100 + + @property + def uas(self): + return (self.heads_corr / self.total) * 100 + + @property + def las(self): + return (self.labels_corr / self.total) * 100 + + @property + def ents_p(self): + return (self.ents_tp / (self.ents_tp + self.ents_fp + 1e-100)) * 100 + + @property + def ents_r(self): + return (self.ents_tp / (self.ents_tp + self.ents_fn + 1e-100)) * 100 + + @property + def ents_f(self): + return (2 * self.ents_p * self.ents_r) / (self.ents_p + self.ents_r + 1e-100) + + def score(self, tokens, gold, verbose=False): + assert len(tokens) == len(gold) + + for i, token in enumerate(tokens): + if not self.skip_token(i, token, gold): + self.total += 1 + if token.head.i == gold.heads[i]: + self.heads_corr += 1 + self.labels_corr += token.dep_ == gold.labels[i] + self.tags_corr += token.tag_ == gold.tags[i] + gold_ents = set((start, end, label) for (start, end, label) in gold.ents) + guess_ents = set(tokens.ents) + if verbose: + for start, end, label in guess_ents: + mark = 'T' if (start, end, label) in gold_ents else 'F' + ent_str = ' '.join(tokens[i].orth_ for i in range(start, end)) + print mark, label, ent_str + for start, end, label in gold_ents: + if (start, end, label) not in guess_ents: + ent_str = ' '.join(tokens[i].orth_ for i in range(start, end)) + print 'M', label, ent_str + print + if gold_ents: + self.ents_tp += len(gold_ents.intersection(guess_ents)) + self.ents_fn += len(gold_ents - guess_ents) + self.ents_fp += len(guess_ents - gold_ents) + + def skip_token(self, i, token, gold): + return gold.labels[i] == 'P'