diff --git a/spacy/language.py b/spacy/language.py index 2225a763e..570630eb3 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -118,7 +118,7 @@ class Language(object): "tagger": lambda nlp, **cfg: Tagger(nlp.vocab, **cfg), "parser": lambda nlp, **cfg: DependencyParser(nlp.vocab, **cfg), "ner": lambda nlp, **cfg: EntityRecognizer(nlp.vocab, **cfg), - "entity_linker": lambda nlp, **cfg: EntityLinker(**cfg), + "entity_linker": lambda nlp, **cfg: EntityLinker(nlp.vocab, **cfg), "similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg), "textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg), "sentencizer": lambda nlp, **cfg: Sentencizer(**cfg), @@ -811,13 +811,6 @@ class Language(object): exclude = list(exclude) + ["vocab"] util.from_disk(path, deserializers, exclude) - # download the KB for the entity linking component - requires the vocab - for pipe_name, pipe in self.pipeline: - if pipe_name == "entity_linker": - kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=pipe.cfg["entity_width"]) - kb.load_bulk(path / pipe_name / "kb") - pipe.set_kb(kb) - self._path = path return self diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 91f5e7044..f4dc08251 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -13,6 +13,7 @@ from thinc.misc import LayerNorm from thinc.neural.util import to_categorical from thinc.neural.util import get_array_module +from spacy.kb import KnowledgeBase from ..cli.pretrain import get_cossim_loss from .functions import merge_subtokens from ..tokens.doc cimport Doc @@ -1079,7 +1080,8 @@ class EntityLinker(Pipe): model = build_nel_encoder(embed_width=embed_width, hidden_width=hidden_width, ner_types=len(type_to_int), **cfg) return model - def __init__(self, **cfg): + def __init__(self, vocab, **cfg): + self.vocab = vocab self.model = True self.kb = None self.cfg = dict(cfg) @@ -1277,6 +1279,7 @@ class EntityLinker(Pipe): def to_disk(self, path, exclude=tuple(), **kwargs): serialize = OrderedDict() serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg) + serialize["vocab"] = lambda p: self.vocab.to_disk(p) serialize["kb"] = lambda p: self.kb.dump(p) if self.model not in (None, True, False): serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes()) @@ -1289,8 +1292,15 @@ class EntityLinker(Pipe): self.model = self.Model(**self.cfg) self.model.from_bytes(p.open("rb").read()) + def load_kb(p): + kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"]) + kb.load_bulk(p) + self.set_kb(kb) + deserialize = OrderedDict() deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) + deserialize["vocab"] = lambda p: self.vocab.from_disk(p) + deserialize["kb"] = load_kb deserialize["model"] = load_model exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) util.from_disk(path, deserialize, exclude)