Merge pull request #5514 from adrianeboyd/bugfix/load-vector-name

Improve vector name loading from model meta
This commit is contained in:
Matthew Honnibal 2020-05-27 20:39:23 +02:00 committed by GitHub
commit e7ac12b598
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 29 additions and 9 deletions

View File

@ -934,15 +934,26 @@ class Language(object):
DOCS: https://spacy.io/api/language#from_disk DOCS: https://spacy.io/api/language#from_disk
""" """
def deserialize_meta(path):
if path.exists():
data = srsly.read_json(path)
self.meta.update(data)
# self.meta always overrides meta["vectors"] with the metadata
# from self.vocab.vectors, so set the name directly
self.vocab.vectors.name = data.get("vectors", {}).get("name")
def deserialize_vocab(path):
if path.exists():
self.vocab.from_disk(path)
_fix_pretrained_vectors_name(self)
if disable is not None: if disable is not None:
warnings.warn(Warnings.W014, DeprecationWarning) warnings.warn(Warnings.W014, DeprecationWarning)
exclude = disable exclude = disable
path = util.ensure_path(path) path = util.ensure_path(path)
deserializers = OrderedDict() deserializers = OrderedDict()
deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p)) deserializers["meta.json"] = deserialize_meta
deserializers["vocab"] = lambda p: self.vocab.from_disk( deserializers["vocab"] = deserialize_vocab
p
) and _fix_pretrained_vectors_name(self)
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk( deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
p, exclude=["vocab"] p, exclude=["vocab"]
) )
@ -996,14 +1007,23 @@ class Language(object):
DOCS: https://spacy.io/api/language#from_bytes DOCS: https://spacy.io/api/language#from_bytes
""" """
def deserialize_meta(b):
data = srsly.json_loads(b)
self.meta.update(data)
# self.meta always overrides meta["vectors"] with the metadata
# from self.vocab.vectors, so set the name directly
self.vocab.vectors.name = data.get("vectors", {}).get("name")
def deserialize_vocab(b):
self.vocab.from_bytes(b)
_fix_pretrained_vectors_name(self)
if disable is not None: if disable is not None:
warnings.warn(Warnings.W014, DeprecationWarning) warnings.warn(Warnings.W014, DeprecationWarning)
exclude = disable exclude = disable
deserializers = OrderedDict() deserializers = OrderedDict()
deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b)) deserializers["meta.json"] = deserialize_meta
deserializers["vocab"] = lambda b: self.vocab.from_bytes( deserializers["vocab"] = deserialize_vocab
b
) and _fix_pretrained_vectors_name(self)
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes( deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
b, exclude=["vocab"] b, exclude=["vocab"]
) )
@ -1069,7 +1089,7 @@ class component(object):
def _fix_pretrained_vectors_name(nlp): def _fix_pretrained_vectors_name(nlp):
# TODO: Replace this once we handle vectors consistently as static # TODO: Replace this once we handle vectors consistently as static
# data # data
if "vectors" in nlp.meta and nlp.meta["vectors"].get("name"): if "vectors" in nlp.meta and "name" in nlp.meta["vectors"]:
nlp.vocab.vectors.name = nlp.meta["vectors"]["name"] nlp.vocab.vectors.name = nlp.meta["vectors"]["name"]
elif not nlp.vocab.vectors.size: elif not nlp.vocab.vectors.size:
nlp.vocab.vectors.name = None nlp.vocab.vectors.name = None