Fix loading models with pretrained vectors

This commit is contained in:
Matthew Honnibal 2018-04-03 23:11:48 +02:00
parent 96b612873b
commit 81f4005f3d
1 changed files with 9 additions and 9 deletions

View File

@ -636,11 +636,11 @@ class Language(object):
"""
path = util.ensure_path(path)
deserializers = OrderedDict((
('vocab', lambda p: self.vocab.from_disk(p)),
('meta.json', lambda p: self.meta.update(util.read_json(p))),
('vocab', lambda p: (
self.vocab.from_disk(p) and _fix_pretrained_vectors_name(self))),
('tokenizer', lambda p: self.tokenizer.from_disk(p, vocab=False)),
('meta.json', lambda p: self.meta.update(util.read_json(p)))
))
_fix_pretrained_vectors_name(self)
for name, proc in self.pipeline:
if name in disable:
continue
@ -682,11 +682,11 @@ class Language(object):
RETURNS (Language): The `Language` object.
"""
deserializers = OrderedDict((
('vocab', lambda b: self.vocab.from_bytes(b)),
('meta', lambda b: self.meta.update(ujson.loads(b))),
('vocab', lambda b: (
self.vocab.from_bytes(b) and _fix_pretrained_vectors_name(self))),
('tokenizer', lambda b: self.tokenizer.from_bytes(b, vocab=False)),
('meta', lambda b: self.meta.update(ujson.loads(b)))
))
_fix_pretrained_vectors_name(self)
for i, (name, proc) in enumerate(self.pipeline):
if name in disable:
continue
@ -708,12 +708,12 @@ def _fix_pretrained_vectors_name(nlp):
nlp.vocab.vectors.name = vectors_name
else:
raise ValueError(Errors.E092)
link_vectors_to_models(nlp.vocab)
for name, proc in nlp.pipeline:
if not hasattr(proc, 'cfg'):
continue
if proc.cfg.get('pretrained_dims'):
assert nlp.vocab.vectors.name
proc.cfg['pretrained_vectors'] = nlp.vocab.vectors.name
proc.cfg.setdefault('deprecation_fixes', {})
proc.cfg['deprecation_fixes']['vectors_name'] = nlp.vocab.vectors.name
class DisabledPipes(list):