From fc0a3c8c3877694c19ec5c4c5bab969e7ae2c93b Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 29 Aug 2019 21:17:34 +0200 Subject: [PATCH] Add morphology serialization --- spacy/morphology.pyx | 45 ++++++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/spacy/morphology.pyx b/spacy/morphology.pyx index ccfb214bc..a7a1bee57 100644 --- a/spacy/morphology.pyx +++ b/spacy/morphology.pyx @@ -15,6 +15,7 @@ from .parts_of_speech cimport SPACE from .parts_of_speech import IDS as POS_IDS from .lexeme cimport Lexeme from .errors import Errors +from .util import ensure_path cdef enum univ_field_t: @@ -162,12 +163,7 @@ cdef class Morphology: self.n_tags = len(tag_map) self.reverse_index = {} self._feat_map = MorphologyClassMap(FEATURES) - for i, (tag_str, attrs) in enumerate(sorted(tag_map.items())): - attrs = _normalize_props(attrs) - self.add({self._feat_map.id2feat[feat] for feat in attrs - if feat in self._feat_map.id2feat}) - self.tag_map[tag_str] = dict(attrs) - self.reverse_index[self.strings.add(tag_str)] = i + self._load_from_tag_map(tag_map) self._cache = PreshMapArray(self.n_tags) self.exc = {} @@ -177,6 +173,14 @@ cdef class Morphology: self.add_special_case( self.strings.as_string(tag), self.strings.as_string(orth), attrs) + def _load_from_tag_map(self, tag_map): + for i, (tag_str, attrs) in enumerate(sorted(tag_map.items())): + attrs = _normalize_props(attrs) + self.add({self._feat_map.id2feat[feat] for feat in attrs + if feat in self._feat_map.id2feat}) + self.tag_map[tag_str] = dict(attrs) + self.reverse_index[self.strings.add(tag_str)] = i + def __reduce__(self): return (Morphology, (self.strings, self.tag_map, self.lemmatizer, self.exc), None, None) @@ -188,6 +192,7 @@ cdef class Morphology: for f in features: if isinstance(f, basestring_): self.strings.add(f) + string_features = features features = intify_features(features) cdef attr_t feature for feature in features: @@ -321,22 +326,34 @@ cdef class Morphology: for form_str, attrs in entries.items(): self.add_special_case(tag_str, form_str, attrs) - def to_bytes(self): - json_tags = [] + def to_bytes(self, exclude=tuple(), **kwargs): + tag_map = {} for key in self.tags: tag_ptr = self.tags.get(key) if tag_ptr != NULL: - json_tags.append(tag_to_json(tag_ptr)) - return srsly.json_dumps(json_tags) + tag_map[key] = tag_to_json(tag_ptr) + exceptions = {} + for (tag_str, orth_int), attrs in sorted(self.exc.items()): + exceptions.setdefault(tag_str, {}) + exceptions[tag_str][self.strings[orth_int]] = attrs + data = {"tag_map": tag_map, "exceptions": exceptions} + return srsly.msgpack_dumps(data) def from_bytes(self, byte_string): - raise NotImplementedError + msg = srsly.msgpack_loads(byte_string) + self._load_from_tag_map(msg["tag_map"]) + self.load_morph_exceptions(msg["exceptions"]) + return self - def to_disk(self, path): - raise NotImplementedError + def to_disk(self, path, exclude=tuple(), **kwargs): + path = ensure_path(path) + with path.open("wb") as file_: + file_.write(self.to_bytes()) def from_disk(self, path): - raise NotImplementedError + with path.open("rb") as file_: + byte_string = file_.read() + return self.from_bytes(byte_string) @classmethod def create_class_map(cls):