diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 89e7576bf..12c3e382f 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -234,10 +234,11 @@ class EntityLinker(TrainablePipe): nO = self.kb.entity_vector_length doc_sample = [] vector_sample = [] - for example in islice(get_examples(), 10): - doc = example.x + for eg in islice(get_examples(), 10): + doc = eg.x if self.use_gold_ents: - doc.ents = example.y.ents + ents, _ = eg.get_aligned_ents_and_ner() + doc.ents = ents doc_sample.append(doc) vector_sample.append(self.model.ops.alloc1f(nO)) assert len(doc_sample) > 0, Errors.E923.format(name=self.name) @@ -312,7 +313,8 @@ class EntityLinker(TrainablePipe): for doc, ex in zip(docs, examples): if self.use_gold_ents: - doc.ents = ex.reference.ents + ents, _ = ex.get_aligned_ents_and_ner() + doc.ents = ents else: # only keep matching ents doc.ents = ex.get_matching_ents() @@ -345,7 +347,7 @@ class EntityLinker(TrainablePipe): for eg in examples: kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) - for ent in eg.reference.ents: + for ent in eg.get_matching_ents(): kb_id = kb_ids[ent.start] if kb_id: entity_encoding = self.kb.get_vector(kb_id) @@ -356,7 +358,11 @@ class EntityLinker(TrainablePipe): entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") selected_encodings = sentence_encodings[keep_ents] - # If the entity encodings list is empty, then + # if there are no matches, short circuit + if not keep_ents: + out = self.model.ops.alloc2f(*sentence_encodings.shape) + return 0, out + if selected_encodings.shape != entity_encodings.shape: err = Errors.E147.format( method="get_loss", msg="gold entities do not match up" diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 83d5bf0e2..ccf26f890 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -14,7 +14,7 @@ from spacy.pipeline.legacy import EntityLinker_v1 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.scorer import Scorer from spacy.tests.util import make_tempdir -from spacy.tokens import Span +from spacy.tokens import Span, Doc from spacy.training import Example from spacy.util import ensure_path from spacy.vocab import Vocab @@ -1075,3 +1075,32 @@ def test_no_gold_ents(patterns): # this will run the pipeline on the examples and shouldn't crash results = nlp.evaluate(train_examples) + +@pytest.mark.issue(9575) +def test_tokenization_mismatch(): + nlp = English() + # include a matching entity so that update isn't skipped + doc1 = Doc(nlp.vocab, words=["Kirby", "123456"], spaces=[True, False], ents=["B-CHARACTER", "B-CARDINAL"]) + doc2 = Doc(nlp.vocab, words=["Kirby", "123", "456"], spaces=[True, False, False], ents=["B-CHARACTER", "B-CARDINAL", "B-CARDINAL"]) + + eg = Example(doc1, doc2) + train_examples = [eg] + vector_length = 3 + + def create_kb(vocab): + # create placeholder KB + mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) + mykb.add_entity(entity="Q613241", freq=12, entity_vector=[6, -4, 3]) + mykb.add_alias("Kirby", ["Q613241"], [0.9]) + return mykb + + entity_linker = nlp.add_pipe("entity_linker", last=True) + entity_linker.set_kb(create_kb) + + optimizer = nlp.initialize(get_examples=lambda: train_examples) + for i in range(2): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + + nlp.add_pipe("sentencizer", first=True) + results = nlp.evaluate(train_examples)