Try to fix python3.5 serialization

This commit is contained in:
Matthew Honnibal 2017-11-08 12:10:49 +01:00
parent e262e8d942
commit 072ff38a01
1 changed files with 6 additions and 6 deletions

View File

@ -370,7 +370,7 @@ class Tagger(Pipe):
def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab
self.model = model
self.cfg = dict(cfg)
self.cfg = OrderedDict(sorted(cfg.items()))
self.cfg.setdefault('cnn_maxout_pieces', 2)
self.cfg.setdefault('pretrained_dims',
self.vocab.vectors.data.shape[1])
@ -469,7 +469,7 @@ class Tagger(Pipe):
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
orig_tag_map = dict(self.vocab.morphology.tag_map)
new_tag_map = {}
new_tag_map = OrderedDict()
for raw_text, annots_brackets in gold_tuples:
for annots, brackets in annots_brackets:
ids, words, tags, heads, deps, ents = annots
@ -533,8 +533,9 @@ class Tagger(Pipe):
serialize['model'] = self.model.to_bytes
serialize['vocab'] = self.vocab.to_bytes
tag_map = OrderedDict(sorted(self.vocab.morphology.item()))
serialize['tag_map'] = lambda: msgpack.dumps(
self.vocab.morphology.tag_map, use_bin_type=True, encoding='utf8')
tag_map, use_bin_type=True, encoding='utf8')
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, **exclude):
@ -565,12 +566,11 @@ class Tagger(Pipe):
def to_disk(self, path, **exclude):
self.cfg['pretrained_dims'] = self.vocab.vectors.data.shape[1]
tag_map = OrderedDict(sorted(self.vocab.morphology.item()))
serialize = OrderedDict((
('vocab', lambda p: self.vocab.to_disk(p)),
('tag_map', lambda p: p.open('wb').write(msgpack.dumps(
self.vocab.morphology.tag_map,
use_bin_type=True,
encoding='utf8'))),
tag_map, use_bin_type=True, encoding='utf8'))),
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
))