mirror of https://github.com/explosion/spaCy.git
deuglify kb deserializer
This commit is contained in:
parent
8840d4b1b3
commit
668b17ea4a
|
@ -118,7 +118,7 @@ class Language(object):
|
|||
"tagger": lambda nlp, **cfg: Tagger(nlp.vocab, **cfg),
|
||||
"parser": lambda nlp, **cfg: DependencyParser(nlp.vocab, **cfg),
|
||||
"ner": lambda nlp, **cfg: EntityRecognizer(nlp.vocab, **cfg),
|
||||
"entity_linker": lambda nlp, **cfg: EntityLinker(**cfg),
|
||||
"entity_linker": lambda nlp, **cfg: EntityLinker(nlp.vocab, **cfg),
|
||||
"similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg),
|
||||
"textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg),
|
||||
"sentencizer": lambda nlp, **cfg: Sentencizer(**cfg),
|
||||
|
@ -811,13 +811,6 @@ class Language(object):
|
|||
exclude = list(exclude) + ["vocab"]
|
||||
util.from_disk(path, deserializers, exclude)
|
||||
|
||||
# load the KB for the entity linking component from disk - requires the vocab
|
||||
for pipe_name, pipe in self.pipeline:
|
||||
if pipe_name == "entity_linker":
|
||||
kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=pipe.cfg["entity_width"])
|
||||
kb.load_bulk(path / pipe_name / "kb")
|
||||
pipe.set_kb(kb)
|
||||
|
||||
self._path = path
|
||||
return self
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from thinc.misc import LayerNorm
|
|||
from thinc.neural.util import to_categorical
|
||||
from thinc.neural.util import get_array_module
|
||||
|
||||
from spacy.kb import KnowledgeBase
|
||||
from ..cli.pretrain import get_cossim_loss
|
||||
from .functions import merge_subtokens
|
||||
from ..tokens.doc cimport Doc
|
||||
|
@ -1079,7 +1080,8 @@ class EntityLinker(Pipe):
|
|||
model = build_nel_encoder(embed_width=embed_width, hidden_width=hidden_width, ner_types=len(type_to_int), **cfg)
|
||||
return model
|
||||
|
||||
def __init__(self, **cfg):
|
||||
def __init__(self, vocab, **cfg):
|
||||
self.vocab = vocab
|
||||
self.model = True
|
||||
self.kb = None
|
||||
self.cfg = dict(cfg)
|
||||
|
@ -1277,6 +1279,7 @@ class EntityLinker(Pipe):
|
|||
def to_disk(self, path, exclude=tuple(), **kwargs):
|
||||
serialize = OrderedDict()
|
||||
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
||||
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||
serialize["kb"] = lambda p: self.kb.dump(p)
|
||||
if self.model not in (None, True, False):
|
||||
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
|
||||
|
@ -1289,8 +1292,15 @@ class EntityLinker(Pipe):
|
|||
self.model = self.Model(**self.cfg)
|
||||
self.model.from_bytes(p.open("rb").read())
|
||||
|
||||
def load_kb(p):
|
||||
kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"])
|
||||
kb.load_bulk(p)
|
||||
self.set_kb(kb)
|
||||
|
||||
deserialize = OrderedDict()
|
||||
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
|
||||
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
|
||||
deserialize["kb"] = load_kb
|
||||
deserialize["model"] = load_model
|
||||
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
|
|
Loading…
Reference in New Issue