Use OrderedDict

This commit is contained in:
ines 2017-06-02 21:07:56 +02:00
parent 2f1025a94c
commit 6669583f4e
1 changed files with 19 additions and 17 deletions

View File

@ -5,7 +5,7 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals, print_function from __future__ import unicode_literals, print_function
from collections import Counter from collections import Counter, OrderedDict
import ujson import ujson
import contextlib import contextlib
@ -668,13 +668,13 @@ cdef class Parser:
return self return self
def to_bytes(self, **exclude): def to_bytes(self, **exclude):
serializers = { serializers = OrderedDict((
'lower_model': lambda: self.model[0].to_bytes(), ('lower_model', lambda: self.model[0].to_bytes()),
'upper_model': lambda: self.model[1].to_bytes(), ('upper_model', lambda: self.model[1].to_bytes()),
'vocab': lambda: self.vocab.to_bytes(), ('vocab', lambda: self.vocab.to_bytes()),
'moves': lambda: self.moves.to_bytes(strings=False), ('moves', lambda: self.moves.to_bytes(strings=False)),
'cfg': lambda: ujson.dumps(self.cfg) ('cfg', lambda: ujson.dumps(self.cfg))
} ))
if 'model' in exclude: if 'model' in exclude:
exclude['lower_model'] = True exclude['lower_model'] = True
exclude['upper_model'] = True exclude['upper_model'] = True
@ -682,21 +682,23 @@ cdef class Parser:
return util.to_bytes(serializers, exclude) return util.to_bytes(serializers, exclude)
def from_bytes(self, bytes_data, **exclude): def from_bytes(self, bytes_data, **exclude):
deserializers = { deserializers = OrderedDict((
'vocab': lambda b: self.vocab.from_bytes(b), ('vocab', lambda b: self.vocab.from_bytes(b)),
'moves': lambda b: self.moves.from_bytes(b, strings=False), ('moves', lambda b: self.moves.from_bytes(b, strings=False)),
'cfg': lambda b: self.cfg.update(ujson.loads(b)), ('cfg', lambda b: self.cfg.update(ujson.loads(b))),
'lower_model': lambda b: None, ('lower_model', lambda b: None),
'upper_model': lambda b: None ('upper_model', lambda b: None)
} ))
msg = util.from_bytes(bytes_data, deserializers, exclude) msg = util.from_bytes(bytes_data, deserializers, exclude)
if 'model' not in exclude: if 'model' not in exclude:
if self.model is True: if self.model is True:
self.model, cfg = self.Model(self.moves.n_moves) self.model, cfg = self.Model(self.moves.n_moves)
else: else:
cfg = {} cfg = {}
self.model[0].from_bytes(msg['lower_model']) if 'lower_model' in msg:
self.model[1].from_bytes(msg['upper_model']) self.model[0].from_bytes(msg['lower_model'])
if 'upper_model' in msg:
self.model[1].from_bytes(msg['upper_model'])
self.cfg.update(cfg) self.cfg.update(cfg)
return self return self