Fix textcat serialization

This commit is contained in:
Matthew Honnibal 2017-09-02 15:17:20 +02:00
parent e3ea6ee02b
commit 9e378bdac5
1 changed files with 16 additions and 8 deletions

View File

@ -128,15 +128,20 @@ class BaseThincComponent(object):
def to_bytes(self, **exclude):
serialize = OrderedDict((
('cfg', lambda: json_dumps(self.cfg)),
('model', lambda: self.model.to_bytes()),
('vocab', lambda: self.vocab.to_bytes())
))
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, **exclude):
if self.model is True:
self.model = self.Model()
def load_model(b):
if self.model is True:
self.model = self.Model(**self.cfg)
self.model.from_bytes(b)
deserialize = OrderedDict((
('cfg', lambda b: self.cfg.update(ujson.loads(b))),
('model', lambda b: self.model.from_bytes(b)),
('vocab', lambda b: self.vocab.from_bytes(b))
))
@ -145,19 +150,22 @@ class BaseThincComponent(object):
def to_disk(self, path, **exclude):
serialize = OrderedDict((
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
('vocab', lambda p: self.vocab.to_disk(p)),
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
('vocab', lambda p: self.vocab.to_disk(p))
))
util.to_disk(path, serialize, exclude)
def from_disk(self, path, **exclude):
if self.model is True:
self.model = self.Model()
def load_model(p):
if self.model is True:
self.model = self.Model(**self.cfg)
self.model.from_bytes(p.open('rb').read())
deserialize = OrderedDict((
('model', lambda p: self.model.from_bytes(p.open('rb').read())),
('cfg', lambda p: self.cfg.update(_load_cfg(p))),
('model', load_model),
('vocab', lambda p: self.vocab.from_disk(p)),
('cfg', lambda p: self.cfg.update(_load_cfg(p)))
))
util.from_disk(path, deserialize, exclude)
return self