* Accept punct_labels as an argument to the scorer

This commit is contained in:
Matthew Honnibal 2016-02-02 22:59:06 +01:00
parent e2ed6251d7
commit 99b8906100
1 changed files with 4 additions and 4 deletions

View File

@ -70,7 +70,7 @@ class Scorer(object):
def ents_f(self): def ents_f(self):
return self.ner.fscore * 100 return self.ner.fscore * 100
def score(self, tokens, gold, verbose=False): def score(self, tokens, gold, verbose=False, punct_labels=('p', 'punct')):
assert len(tokens) == len(gold) assert len(tokens) == len(gold)
gold_deps = set() gold_deps = set()
@ -78,7 +78,7 @@ class Scorer(object):
gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot])) gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot]))
for id_, word, tag, head, dep, ner in gold.orig_annot: for id_, word, tag, head, dep, ner in gold.orig_annot:
gold_tags.add((id_, tag)) gold_tags.add((id_, tag))
if dep.lower() not in ('p', 'punct'): if dep.lower() not in punct_labels:
gold_deps.add((id_, head, dep.lower())) gold_deps.add((id_, head, dep.lower()))
cand_deps = set() cand_deps = set()
cand_tags = set() cand_tags = set()
@ -87,12 +87,12 @@ class Scorer(object):
continue continue
gold_i = gold.cand_to_gold[token.i] gold_i = gold.cand_to_gold[token.i]
if gold_i is None: if gold_i is None:
if token.dep_.lower() not in ('p', 'punct'): if token.dep_.lower() not in punct_labels:
self.tokens.fp += 1 self.tokens.fp += 1
else: else:
self.tokens.tp += 1 self.tokens.tp += 1
cand_tags.add((gold_i, token.tag_)) cand_tags.add((gold_i, token.tag_))
if token.dep_.lower() not in ('p', 'punct') and token.orth_.strip(): if token.dep_.lower() not in punct_labels and token.orth_.strip():
gold_head = gold.cand_to_gold[token.head.i] gold_head = gold.cand_to_gold[token.head.i]
# None is indistinct, so we can't just add it to the set # None is indistinct, so we can't just add it to the set
# Multiple (None, None) deps are possible # Multiple (None, None) deps are possible