diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx
index 675bfa52e..238e5670e 100644
--- a/spacy/pipeline.pyx
+++ b/spacy/pipeline.pyx
@@ -128,15 +128,22 @@ class BaseThincComponent(object):
 
     def to_bytes(self, **exclude):
         serialize = OrderedDict((
+            ('cfg', lambda: json_dumps(self.cfg)),
             ('model', lambda: self.model.to_bytes()),
             ('vocab', lambda: self.vocab.to_bytes())
         ))
         return util.to_bytes(serialize, exclude)
 
     def from_bytes(self, bytes_data, **exclude):
-        if self.model is True:
-            self.model = self.Model()
+        def load_model(b):
+            # Build the placeholder model only here, so Model() receives self.cfg
+            # ('cfg' is listed first below; presumably it is restored first).
+            if self.model is True:
+                self.model = self.Model(**self.cfg)
+            self.model.from_bytes(b)
+
         deserialize = OrderedDict((
-            ('model', lambda b: self.model.from_bytes(b)),
+            ('cfg', lambda b: self.cfg.update(ujson.loads(b))),
+            ('model', load_model),
             ('vocab', lambda b: self.vocab.from_bytes(b))
         ))
@@ -145,19 +152,24 @@ class BaseThincComponent(object):
 
     def to_disk(self, path, **exclude):
         serialize = OrderedDict((
+            ('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
             ('model', lambda p: p.open('wb').write(self.model.to_bytes())),
-            ('vocab', lambda p: self.vocab.to_disk(p)),
-            ('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
+            ('vocab', lambda p: self.vocab.to_disk(p))
         ))
         util.to_disk(path, serialize, exclude)
 
     def from_disk(self, path, **exclude):
-        if self.model is True:
-            self.model = self.Model()
+        def load_model(p):
+            # Build the placeholder model only here, so Model() receives self.cfg
+            # ('cfg' is listed first below; presumably it is restored first).
+            if self.model is True:
+                self.model = self.Model(**self.cfg)
+            self.model.from_bytes(p.open('rb').read())
+
         deserialize = OrderedDict((
-            ('model', lambda p: self.model.from_bytes(p.open('rb').read())),
+            ('cfg', lambda p: self.cfg.update(_load_cfg(p))),
+            ('model', load_model),
             ('vocab', lambda p: self.vocab.from_disk(p)),
-            ('cfg', lambda p: self.cfg.update(_load_cfg(p)))
         ))
         util.from_disk(path, deserialize, exclude)
         return self