diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 0b1bd8ccf..4ee470606 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1490,6 +1490,7 @@ class EntityLinker(Pipe): def to_disk(self, path, exclude=tuple(), **kwargs): serialize = {} + self.cfg["entity_width"] = self.kb.entity_vector_length serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg) serialize["vocab"] = lambda p: self.vocab.to_disk(p) serialize["kb"] = lambda p: self.kb.dump(p) @@ -1561,6 +1562,11 @@ class Sentencizer(Pipe): def from_nlp(cls, nlp, model=None, **cfg): return cls(**cfg) + def begin_training( + self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs + ): + pass + def __call__(self, example): """Apply the sentencizer to a Doc and set Token.is_sent_start. diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 9ff5f8194..cdd8451fd 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -1,8 +1,11 @@ import pytest from spacy.kb import KnowledgeBase + +from spacy import util from spacy.lang.en import English from spacy.pipeline import EntityRuler +from spacy.tests.util import make_tempdir from spacy.tokens import Span @@ -245,3 +248,72 @@ def test_preserving_links_ents_2(nlp): assert len(list(doc.ents)) == 1 assert list(doc.ents)[0].label_ == "LOC" assert list(doc.ents)[0].kb_id_ == "Q1" + + +# fmt: off +TRAIN_DATA = [ + ("Russ Cochran captured his first major title with his son as caddie.", {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}}}), + ("Russ Cochran his reprints include EC Comics.", {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}}), + ("Russ Cochran has been publishing comic art.", {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}}), + ("Russ Cochran was a member of University of Kentucky's golf team.", {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}}}), +] +GOLD_entities = ["Q2146908", "Q7381115", "Q7381115", "Q2146908"] +# fmt: on + + +def test_overfitting_IO(): + # Simple test to try and quickly overfit the NEL component - ensuring the ML models work correctly + nlp = English() + nlp.add_pipe(nlp.create_pipe('sentencizer')) + + # Add a custom component to recognize "Russ Cochran" as an entity for the example training data + ruler = EntityRuler(nlp) + patterns = [{"label": "PERSON", "pattern": [{"LOWER": "russ"}, {"LOWER": "cochran"}]}] + ruler.add_patterns(patterns) + nlp.add_pipe(ruler) + + # Convert the texts to docs to make sure we have doc.ents set for the training examples + TRAIN_DOCS = [] + for text, annotation in TRAIN_DATA: + doc = nlp(text) + annotation_clean = annotation + TRAIN_DOCS.append((doc, annotation_clean)) + + # create artificial KB - assign same prior weight to the two russ cochran's + # Q2146908 (Russ Cochran): American golfer + # Q7381115 (Russ Cochran): publisher + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) + mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) + mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7]) + mykb.add_alias(alias="Russ Cochran", entities=["Q2146908", "Q7381115"], probabilities=[0.5, 0.5]) + + # Create the Entity Linker component and add it to the pipeline + entity_linker = nlp.create_pipe("entity_linker") + entity_linker.set_kb(mykb) + nlp.add_pipe(entity_linker, last=True) + + # train the NEL pipe + optimizer = nlp.begin_training() + for i in range(50): + losses = {} + nlp.update(TRAIN_DOCS, sgd=optimizer, losses=losses) + assert losses["entity_linker"] < 0.001 + + # test the trained model + predictions = [] + for text, annotation in TRAIN_DATA: + doc = nlp(text) + for ent in doc.ents: + predictions.append(ent.kb_id_) + assert predictions == GOLD_entities + + # Also test the results are still the same after IO + with make_tempdir() as tmp_dir: + nlp.to_disk(tmp_dir) + nlp2 = util.load_model_from_path(tmp_dir) + predictions = [] + for text, annotation in TRAIN_DATA: + doc2 = nlp2(text) + for ent in doc2.ents: + predictions.append(ent.kb_id_) + assert predictions == GOLD_entities