Update vocab

This commit is contained in:
Matthew Honnibal 2019-08-29 21:19:54 +02:00
parent fc0a3c8c38
commit f3c3ce7f1e
1 changed files with 8 additions and 2 deletions

View File

@ -433,6 +433,8 @@ cdef class Vocab:
file_.write(self.lexemes_to_bytes()) file_.write(self.lexemes_to_bytes())
if "vectors" not in "exclude" and self.vectors is not None: if "vectors" not in "exclude" and self.vectors is not None:
self.vectors.to_disk(path) self.vectors.to_disk(path)
if "morphology" not in exclude:
self.morphology.to_disk(path / "morphology.bin")
def from_disk(self, path, exclude=tuple(), **kwargs): def from_disk(self, path, exclude=tuple(), **kwargs):
"""Loads state from a directory. Modifies the object in place and """Loads state from a directory. Modifies the object in place and
@ -457,6 +459,8 @@ cdef class Vocab:
self.vectors.from_disk(path, exclude=["strings"]) self.vectors.from_disk(path, exclude=["strings"])
if self.vectors.name is not None: if self.vectors.name is not None:
link_vectors_to_models(self) link_vectors_to_models(self)
if "morphology" not in exclude:
self.morphology.from_disk(path / "morphology.bin")
return self return self
def to_bytes(self, exclude=tuple(), **kwargs): def to_bytes(self, exclude=tuple(), **kwargs):
@ -476,7 +480,8 @@ cdef class Vocab:
getters = OrderedDict(( getters = OrderedDict((
("strings", lambda: self.strings.to_bytes()), ("strings", lambda: self.strings.to_bytes()),
("lexemes", lambda: self.lexemes_to_bytes()), ("lexemes", lambda: self.lexemes_to_bytes()),
("vectors", deserialize_vectors) ("vectors", deserialize_vectors),
("morphology", lambda: self.morphology.to_bytes())
)) ))
exclude = util.get_serialization_exclude(getters, exclude, kwargs) exclude = util.get_serialization_exclude(getters, exclude, kwargs)
return util.to_bytes(getters, exclude) return util.to_bytes(getters, exclude)
@ -499,7 +504,8 @@ cdef class Vocab:
setters = OrderedDict(( setters = OrderedDict((
("strings", lambda b: self.strings.from_bytes(b)), ("strings", lambda b: self.strings.from_bytes(b)),
("lexemes", lambda b: self.lexemes_from_bytes(b)), ("lexemes", lambda b: self.lexemes_from_bytes(b)),
("vectors", lambda b: serialize_vectors(b)) ("vectors", lambda b: serialize_vectors(b)),
("morphology", lambda b: self.morphology.from_bytes(b))
)) ))
exclude = util.get_serialization_exclude(setters, exclude, kwargs) exclude = util.get_serialization_exclude(setters, exclude, kwargs)
util.from_bytes(bytes_data, setters, exclude) util.from_bytes(bytes_data, setters, exclude)