mirror of https://github.com/explosion/spaCy.git
Improve NER per type scoring (#4052)
* Improve NER per type scoring
* include all gold labels in per type scoring, not only when recall > 0
* improve efficiency of per type scoring
* Create Scorer tests, initially with NER tests
* move regression test #3968 (per type NER scoring) to Scorer tests
* add new test for per type NER scoring with imperfect P/R/F and per type P/R/F including a case where R == 0.0
This commit is contained in:
parent
f7d950de6d
commit
925a852bb6
|
@ -159,12 +159,19 @@ class Scorer(object):
|
||||||
else:
|
else:
|
||||||
cand_deps.add((gold_i, gold_head, token.dep_.lower()))
|
cand_deps.add((gold_i, gold_head, token.dep_.lower()))
|
||||||
if "-" not in [token[-1] for token in gold.orig_annot]:
|
if "-" not in [token[-1] for token in gold.orig_annot]:
|
||||||
|
# Find all NER labels in gold and doc
|
||||||
|
ent_labels = set([x[0] for x in gold_ents]
|
||||||
|
+ [k.label_ for k in doc.ents])
|
||||||
|
# Set up all labels for per type scoring and prepare gold per type
|
||||||
|
gold_per_ents = {ent_label: set() for ent_label in ent_labels}
|
||||||
|
for ent_label in ent_labels:
|
||||||
|
if ent_label not in self.ner_per_ents:
|
||||||
|
self.ner_per_ents[ent_label] = PRFScore()
|
||||||
|
gold_per_ents[ent_label].update([x for x in gold_ents if x[0] == ent_label])
|
||||||
|
# Find all candidate labels, for all and per type
|
||||||
cand_ents = set()
|
cand_ents = set()
|
||||||
current_ent = {k.label_: set() for k in doc.ents}
|
cand_per_ents = {ent_label: set() for ent_label in ent_labels}
|
||||||
current_gold = {k.label_: set() for k in doc.ents}
|
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
if ent.label_ not in self.ner_per_ents:
|
|
||||||
self.ner_per_ents[ent.label_] = PRFScore()
|
|
||||||
first = gold.cand_to_gold[ent.start]
|
first = gold.cand_to_gold[ent.start]
|
||||||
last = gold.cand_to_gold[ent.end - 1]
|
last = gold.cand_to_gold[ent.end - 1]
|
||||||
if first is None or last is None:
|
if first is None or last is None:
|
||||||
|
@ -172,14 +179,11 @@ class Scorer(object):
|
||||||
self.ner_per_ents[ent.label_].fp += 1
|
self.ner_per_ents[ent.label_].fp += 1
|
||||||
else:
|
else:
|
||||||
cand_ents.add((ent.label_, first, last))
|
cand_ents.add((ent.label_, first, last))
|
||||||
current_ent[ent.label_].update([x for x in cand_ents if x[0] == ent.label_])
|
cand_per_ents[ent.label_].add((ent.label_, first, last))
|
||||||
current_gold[ent.label_].update([x for x in gold_ents if x[0] == ent.label_])
|
|
||||||
# Scores per ent
|
# Scores per ent
|
||||||
[
|
for k, v in self.ner_per_ents.items():
|
||||||
v.score_set(current_ent[k], current_gold[k])
|
if k in cand_per_ents:
|
||||||
for k, v in self.ner_per_ents.items()
|
v.score_set(cand_per_ents[k], gold_per_ents[k])
|
||||||
if k in current_ent
|
|
||||||
]
|
|
||||||
# Score for all ents
|
# Score for all ents
|
||||||
self.ner.score_set(cand_ents, gold_ents)
|
self.ner.score_set(cand_ents, gold_ents)
|
||||||
self.tags.score_set(cand_tags, gold_tags)
|
self.tags.score_set(cand_tags, gold_tags)
|
||||||
|
|
|
@ -1,34 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
from spacy.gold import GoldParse
|
|
||||||
from spacy.scorer import Scorer
|
|
||||||
from ..util import get_doc
|
|
||||||
|
|
||||||
test_samples = [
|
|
||||||
[
|
|
||||||
"100 - 200",
|
|
||||||
{
|
|
||||||
"entities": [
|
|
||||||
[0, 3, "CARDINAL"],
|
|
||||||
[6, 9, "CARDINAL"]
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
]
|
|
||||||
|
|
||||||
def test_issue3625(en_vocab):
|
|
||||||
scorer = Scorer()
|
|
||||||
for input_, annot in test_samples:
|
|
||||||
doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0,1,'CARDINAL'], [2,3,'CARDINAL']]);
|
|
||||||
gold = GoldParse(doc, entities = annot['entities'])
|
|
||||||
scorer.score(doc, gold)
|
|
||||||
results = scorer.scores
|
|
||||||
|
|
||||||
# Expects total accuracy and accuracy for each entity to be 100%
|
|
||||||
assert results['ents_p'] == 100
|
|
||||||
assert results['ents_f'] == 100
|
|
||||||
assert results['ents_r'] == 100
|
|
||||||
assert results['ents_per_type']['CARDINAL']['p'] == 100
|
|
||||||
assert results['ents_per_type']['CARDINAL']['f'] == 100
|
|
||||||
assert results['ents_per_type']['CARDINAL']['r'] == 100
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from pytest import approx
|
||||||
|
from spacy.gold import GoldParse
|
||||||
|
from spacy.scorer import Scorer
|
||||||
|
from .util import get_doc
|
||||||
|
|
||||||
|
test_ner_cardinal = [
|
||||||
|
[
|
||||||
|
"100 - 200",
|
||||||
|
{
|
||||||
|
"entities": [
|
||||||
|
[0, 3, "CARDINAL"],
|
||||||
|
[6, 9, "CARDINAL"]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
test_ner_apple = [
|
||||||
|
[
|
||||||
|
"Apple is looking at buying U.K. startup for $1 billion",
|
||||||
|
{
|
||||||
|
"entities": [
|
||||||
|
(0, 5, "ORG"),
|
||||||
|
(27, 31, "GPE"),
|
||||||
|
(44, 54, "MONEY"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_ner_per_type(en_vocab):
|
||||||
|
# Gold and Doc are identical
|
||||||
|
scorer = Scorer()
|
||||||
|
for input_, annot in test_ner_cardinal:
|
||||||
|
doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0, 1, 'CARDINAL'], [2, 3, 'CARDINAL']])
|
||||||
|
gold = GoldParse(doc, entities = annot['entities'])
|
||||||
|
scorer.score(doc, gold)
|
||||||
|
results = scorer.scores
|
||||||
|
|
||||||
|
assert results['ents_p'] == 100
|
||||||
|
assert results['ents_f'] == 100
|
||||||
|
assert results['ents_r'] == 100
|
||||||
|
assert results['ents_per_type']['CARDINAL']['p'] == 100
|
||||||
|
assert results['ents_per_type']['CARDINAL']['f'] == 100
|
||||||
|
assert results['ents_per_type']['CARDINAL']['r'] == 100
|
||||||
|
|
||||||
|
# Doc has one missing and one extra entity
|
||||||
|
# Entity type MONEY is not present in Doc
|
||||||
|
scorer = Scorer()
|
||||||
|
for input_, annot in test_ner_apple:
|
||||||
|
doc = get_doc(en_vocab, words = input_.split(' '), ents = [[0, 1, 'ORG'], [5, 6, 'GPE'], [6, 7, 'ORG']])
|
||||||
|
gold = GoldParse(doc, entities = annot['entities'])
|
||||||
|
scorer.score(doc, gold)
|
||||||
|
results = scorer.scores
|
||||||
|
|
||||||
|
assert results['ents_p'] == approx(66.66666)
|
||||||
|
assert results['ents_r'] == approx(66.66666)
|
||||||
|
assert results['ents_f'] == approx(66.66666)
|
||||||
|
assert 'GPE' in results['ents_per_type']
|
||||||
|
assert 'MONEY' in results['ents_per_type']
|
||||||
|
assert 'ORG' in results['ents_per_type']
|
||||||
|
assert results['ents_per_type']['GPE']['p'] == 100
|
||||||
|
assert results['ents_per_type']['GPE']['r'] == 100
|
||||||
|
assert results['ents_per_type']['GPE']['f'] == 100
|
||||||
|
assert results['ents_per_type']['MONEY']['p'] == 0
|
||||||
|
assert results['ents_per_type']['MONEY']['r'] == 0
|
||||||
|
assert results['ents_per_type']['MONEY']['f'] == 0
|
||||||
|
assert results['ents_per_type']['ORG']['p'] == 50
|
||||||
|
assert results['ents_per_type']['ORG']['r'] == 100
|
||||||
|
assert results['ents_per_type']['ORG']['f'] == approx(66.66666)
|
Loading…
Reference in New Issue