mirror of https://github.com/explosion/spaCy.git
* Add failing test * Partial fix for issue This kind of works. The issue with token length mismatches is gone. The problem is that when you get empty lists of encodings to compare, it fails because the sizes are not the same, even though they're both zero: (0, 3) vs (0,). Not sure why that happens... * Short circuit on empties * Remove spurious check The check here isn't needed now the the short circuit is fixed. * Update spacy/tests/pipeline/test_entity_linker.py Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com> * Use "eg", not "example" Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
This commit is contained in:
parent
1d34aa2b3d
commit
6be09bbd07
|
@ -234,10 +234,11 @@ class EntityLinker(TrainablePipe):
|
|||
nO = self.kb.entity_vector_length
|
||||
doc_sample = []
|
||||
vector_sample = []
|
||||
for example in islice(get_examples(), 10):
|
||||
doc = example.x
|
||||
for eg in islice(get_examples(), 10):
|
||||
doc = eg.x
|
||||
if self.use_gold_ents:
|
||||
doc.ents = example.y.ents
|
||||
ents, _ = eg.get_aligned_ents_and_ner()
|
||||
doc.ents = ents
|
||||
doc_sample.append(doc)
|
||||
vector_sample.append(self.model.ops.alloc1f(nO))
|
||||
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||
|
@ -312,7 +313,8 @@ class EntityLinker(TrainablePipe):
|
|||
|
||||
for doc, ex in zip(docs, examples):
|
||||
if self.use_gold_ents:
|
||||
doc.ents = ex.reference.ents
|
||||
ents, _ = ex.get_aligned_ents_and_ner()
|
||||
doc.ents = ents
|
||||
else:
|
||||
# only keep matching ents
|
||||
doc.ents = ex.get_matching_ents()
|
||||
|
@ -345,7 +347,7 @@ class EntityLinker(TrainablePipe):
|
|||
for eg in examples:
|
||||
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
||||
|
||||
for ent in eg.reference.ents:
|
||||
for ent in eg.get_matching_ents():
|
||||
kb_id = kb_ids[ent.start]
|
||||
if kb_id:
|
||||
entity_encoding = self.kb.get_vector(kb_id)
|
||||
|
@ -356,7 +358,11 @@ class EntityLinker(TrainablePipe):
|
|||
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||
selected_encodings = sentence_encodings[keep_ents]
|
||||
|
||||
# If the entity encodings list is empty, then
|
||||
# if there are no matches, short circuit
|
||||
if not keep_ents:
|
||||
out = self.model.ops.alloc2f(*sentence_encodings.shape)
|
||||
return 0, out
|
||||
|
||||
if selected_encodings.shape != entity_encodings.shape:
|
||||
err = Errors.E147.format(
|
||||
method="get_loss", msg="gold entities do not match up"
|
||||
|
|
|
@ -14,7 +14,7 @@ from spacy.pipeline.legacy import EntityLinker_v1
|
|||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||
from spacy.scorer import Scorer
|
||||
from spacy.tests.util import make_tempdir
|
||||
from spacy.tokens import Span
|
||||
from spacy.tokens import Span, Doc
|
||||
from spacy.training import Example
|
||||
from spacy.util import ensure_path
|
||||
from spacy.vocab import Vocab
|
||||
|
@ -1075,3 +1075,32 @@ def test_no_gold_ents(patterns):
|
|||
|
||||
# this will run the pipeline on the examples and shouldn't crash
|
||||
results = nlp.evaluate(train_examples)
|
||||
|
||||
@pytest.mark.issue(9575)
|
||||
def test_tokenization_mismatch():
|
||||
nlp = English()
|
||||
# include a matching entity so that update isn't skipped
|
||||
doc1 = Doc(nlp.vocab, words=["Kirby", "123456"], spaces=[True, False], ents=["B-CHARACTER", "B-CARDINAL"])
|
||||
doc2 = Doc(nlp.vocab, words=["Kirby", "123", "456"], spaces=[True, False, False], ents=["B-CHARACTER", "B-CARDINAL", "B-CARDINAL"])
|
||||
|
||||
eg = Example(doc1, doc2)
|
||||
train_examples = [eg]
|
||||
vector_length = 3
|
||||
|
||||
def create_kb(vocab):
|
||||
# create placeholder KB
|
||||
mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
|
||||
mykb.add_entity(entity="Q613241", freq=12, entity_vector=[6, -4, 3])
|
||||
mykb.add_alias("Kirby", ["Q613241"], [0.9])
|
||||
return mykb
|
||||
|
||||
entity_linker = nlp.add_pipe("entity_linker", last=True)
|
||||
entity_linker.set_kb(create_kb)
|
||||
|
||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||
for i in range(2):
|
||||
losses = {}
|
||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||
|
||||
nlp.add_pipe("sentencizer", first=True)
|
||||
results = nlp.evaluate(train_examples)
|
||||
|
|
Loading…
Reference in New Issue