From 6be09bbd07243b546ed52d5a03cdf54c8e028566 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 24 May 2022 03:42:26 +0900 Subject: [PATCH] Fix Entity Linker with tokenization mismatches (fix #9575) (#10457) * Add failing test * Partial fix for issue This kind of works. The issue with token length mismatches is gone. The problem is that when you get empty lists of encodings to compare, it fails because the sizes are not the same, even though they're both zero: (0, 3) vs (0,). Not sure why that happens... * Short circuit on empties * Remove spurious check The check here isn't needed now the the short circuit is fixed. * Update spacy/tests/pipeline/test_entity_linker.py Co-authored-by: Sofie Van Landeghem * Use "eg", not "example" Co-authored-by: Sofie Van Landeghem --- spacy/pipeline/entity_linker.py | 18 ++++++++----- spacy/tests/pipeline/test_entity_linker.py | 31 +++++++++++++++++++++- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 89e7576bf..12c3e382f 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -234,10 +234,11 @@ class EntityLinker(TrainablePipe): nO = self.kb.entity_vector_length doc_sample = [] vector_sample = [] - for example in islice(get_examples(), 10): - doc = example.x + for eg in islice(get_examples(), 10): + doc = eg.x if self.use_gold_ents: - doc.ents = example.y.ents + ents, _ = eg.get_aligned_ents_and_ner() + doc.ents = ents doc_sample.append(doc) vector_sample.append(self.model.ops.alloc1f(nO)) assert len(doc_sample) > 0, Errors.E923.format(name=self.name) @@ -312,7 +313,8 @@ class EntityLinker(TrainablePipe): for doc, ex in zip(docs, examples): if self.use_gold_ents: - doc.ents = ex.reference.ents + ents, _ = ex.get_aligned_ents_and_ner() + doc.ents = ents else: # only keep matching ents doc.ents = ex.get_matching_ents() @@ -345,7 +347,7 @@ class EntityLinker(TrainablePipe): for eg in examples: kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True) - for ent in eg.reference.ents: + for ent in eg.get_matching_ents(): kb_id = kb_ids[ent.start] if kb_id: entity_encoding = self.kb.get_vector(kb_id) @@ -356,7 +358,11 @@ class EntityLinker(TrainablePipe): entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") selected_encodings = sentence_encodings[keep_ents] - # If the entity encodings list is empty, then + # if there are no matches, short circuit + if not keep_ents: + out = self.model.ops.alloc2f(*sentence_encodings.shape) + return 0, out + if selected_encodings.shape != entity_encodings.shape: err = Errors.E147.format( method="get_loss", msg="gold entities do not match up" diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 83d5bf0e2..ccf26f890 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -14,7 +14,7 @@ from spacy.pipeline.legacy import EntityLinker_v1 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.scorer import Scorer from spacy.tests.util import make_tempdir -from spacy.tokens import Span +from spacy.tokens import Span, Doc from spacy.training import Example from spacy.util import ensure_path from spacy.vocab import Vocab @@ -1075,3 +1075,32 @@ def test_no_gold_ents(patterns): # this will run the pipeline on the examples and shouldn't crash results = nlp.evaluate(train_examples) + +@pytest.mark.issue(9575) +def test_tokenization_mismatch(): + nlp = English() + # include a matching entity so that update isn't skipped + doc1 = Doc(nlp.vocab, words=["Kirby", "123456"], spaces=[True, False], ents=["B-CHARACTER", "B-CARDINAL"]) + doc2 = Doc(nlp.vocab, words=["Kirby", "123", "456"], spaces=[True, False, False], ents=["B-CHARACTER", "B-CARDINAL", "B-CARDINAL"]) + + eg = Example(doc1, doc2) + train_examples = [eg] + vector_length = 3 + + def create_kb(vocab): + # create placeholder KB + mykb = KnowledgeBase(vocab, entity_vector_length=vector_length) + mykb.add_entity(entity="Q613241", freq=12, entity_vector=[6, -4, 3]) + mykb.add_alias("Kirby", ["Q613241"], [0.9]) + return mykb + + entity_linker = nlp.add_pipe("entity_linker", last=True) + entity_linker.set_kb(create_kb) + + optimizer = nlp.initialize(get_examples=lambda: train_examples) + for i in range(2): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + + nlp.add_pipe("sentencizer", first=True) + results = nlp.evaluate(train_examples)