mirror of https://github.com/explosion/spaCy.git
Fix textcat serialization
This commit is contained in:
parent
e3ea6ee02b
commit
9e378bdac5
|
@ -128,15 +128,20 @@ class BaseThincComponent(object):
|
||||||
|
|
||||||
def to_bytes(self, **exclude):
|
def to_bytes(self, **exclude):
|
||||||
serialize = OrderedDict((
|
serialize = OrderedDict((
|
||||||
|
('cfg', lambda: json_dumps(self.cfg)),
|
||||||
('model', lambda: self.model.to_bytes()),
|
('model', lambda: self.model.to_bytes()),
|
||||||
('vocab', lambda: self.vocab.to_bytes())
|
('vocab', lambda: self.vocab.to_bytes())
|
||||||
))
|
))
|
||||||
return util.to_bytes(serialize, exclude)
|
return util.to_bytes(serialize, exclude)
|
||||||
|
|
||||||
def from_bytes(self, bytes_data, **exclude):
|
def from_bytes(self, bytes_data, **exclude):
|
||||||
|
def load_model(b):
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.model = self.Model()
|
self.model = self.Model(**self.cfg)
|
||||||
|
self.model.from_bytes(b)
|
||||||
|
|
||||||
deserialize = OrderedDict((
|
deserialize = OrderedDict((
|
||||||
|
('cfg', lambda b: self.cfg.update(ujson.loads(b))),
|
||||||
('model', lambda b: self.model.from_bytes(b)),
|
('model', lambda b: self.model.from_bytes(b)),
|
||||||
('vocab', lambda b: self.vocab.from_bytes(b))
|
('vocab', lambda b: self.vocab.from_bytes(b))
|
||||||
))
|
))
|
||||||
|
@ -145,19 +150,22 @@ class BaseThincComponent(object):
|
||||||
|
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
serialize = OrderedDict((
|
serialize = OrderedDict((
|
||||||
|
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg))),
|
||||||
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
||||||
('vocab', lambda p: self.vocab.to_disk(p)),
|
('vocab', lambda p: self.vocab.to_disk(p))
|
||||||
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
|
|
||||||
))
|
))
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
def from_disk(self, path, **exclude):
|
def from_disk(self, path, **exclude):
|
||||||
|
def load_model(p):
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
self.model = self.Model()
|
self.model = self.Model(**self.cfg)
|
||||||
|
self.model.from_bytes(p.open('rb').read())
|
||||||
|
|
||||||
deserialize = OrderedDict((
|
deserialize = OrderedDict((
|
||||||
('model', lambda p: self.model.from_bytes(p.open('rb').read())),
|
('cfg', lambda p: self.cfg.update(_load_cfg(p))),
|
||||||
|
('model', load_model),
|
||||||
('vocab', lambda p: self.vocab.from_disk(p)),
|
('vocab', lambda p: self.vocab.from_disk(p)),
|
||||||
('cfg', lambda p: self.cfg.update(_load_cfg(p)))
|
|
||||||
))
|
))
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
return self
|
return self
|
||||||
|
|
Loading…
Reference in New Issue