Fix textcat serialization

This commit is contained in:
Matthew Honnibal 2017-09-02 15:17:20 +02:00
parent e3ea6ee02b
commit 9e378bdac5
1 changed files with 16 additions and 8 deletions

View File

@ -128,15 +128,20 @@ class BaseThincComponent(object):
def to_bytes(self, **exclude): def to_bytes(self, **exclude):
serialize = OrderedDict(( serialize = OrderedDict((
('cfg', lambda: json_dumps(self.cfg)),
('model', lambda: self.model.to_bytes()), ('model', lambda: self.model.to_bytes()),
('vocab', lambda: self.vocab.to_bytes()) ('vocab', lambda: self.vocab.to_bytes())
)) ))
return util.to_bytes(serialize, exclude) return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, **exclude): def from_bytes(self, bytes_data, **exclude):
def load_model(b):
if self.model is True: if self.model is True:
self.model = self.Model() self.model = self.Model(**self.cfg)
self.model.from_bytes(b)
deserialize = OrderedDict(( deserialize = OrderedDict((
('cfg', lambda b: self.cfg.update(ujson.loads(b))),
('model', lambda b: self.model.from_bytes(b)), ('model', lambda b: self.model.from_bytes(b)),
('vocab', lambda b: self.vocab.from_bytes(b)) ('vocab', lambda b: self.vocab.from_bytes(b))
)) ))
@ -145,19 +150,22 @@ class BaseThincComponent(object):
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):
serialize = OrderedDict(( serialize = OrderedDict((
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
('model', lambda p: p.open('wb').write(self.model.to_bytes())), ('model', lambda p: p.open('wb').write(self.model.to_bytes())),
('vocab', lambda p: self.vocab.to_disk(p)), ('vocab', lambda p: self.vocab.to_disk(p))
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
)) ))
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)
def from_disk(self, path, **exclude): def from_disk(self, path, **exclude):
def load_model(p):
if self.model is True: if self.model is True:
self.model = self.Model() self.model = self.Model(**self.cfg)
self.model.from_bytes(p.open('rb').read())
deserialize = OrderedDict(( deserialize = OrderedDict((
('model', lambda p: self.model.from_bytes(p.open('rb').read())), ('cfg', lambda p: self.cfg.update(_load_cfg(p))),
('model', load_model),
('vocab', lambda p: self.vocab.from_disk(p)), ('vocab', lambda p: self.vocab.from_disk(p)),
('cfg', lambda p: self.cfg.update(_load_cfg(p)))
)) ))
util.from_disk(path, deserialize, exclude) util.from_disk(path, deserialize, exclude)
return self return self