Fix 2/3 problem for json save/load

This commit is contained in:
Matthew Honnibal 2017-03-08 01:39:13 +01:00
parent 40703988bc
commit cd33b39a04
1 changed files with 31 additions and 16 deletions

View File

@ -5,7 +5,7 @@ import pathlib
from contextlib import contextmanager from contextlib import contextmanager
import shutil import shutil
import ujson as json import ujson
try: try:
@ -13,6 +13,10 @@ try:
except NameError: except NameError:
basestring = str basestring = str
try:
unicode
except NameError:
unicode = str
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
from .vocab import Vocab from .vocab import Vocab
@ -226,12 +230,21 @@ class Language(object):
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples) parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples) entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
with (dep_model_dir / 'config.json').open('w') as file_: with (dep_model_dir / 'config.json').open('wb') as file_:
json.dump(parser_cfg, file_) data = ujson.dumps(parser_cfg)
with (ner_model_dir / 'config.json').open('w') as file_: if isinstance(data, unicode):
json.dump(entity_cfg, file_) data = data.encode('utf8')
with (pos_model_dir / 'config.json').open('w') as file_: file_.write(data)
json.dump(tagger_cfg, file_) with (ner_model_dir / 'config.json').open('wb') as file_:
data = ujson.dumps(entity_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
with (pos_model_dir / 'config.json').open('wb') as file_:
data = ujson.dumps(tagger_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
self = cls( self = cls(
path=path, path=path,
@ -391,12 +404,14 @@ class Language(object):
else: else:
entity_iob_freqs = [] entity_iob_freqs = []
entity_type_freqs = [] entity_type_freqs = []
with (path / 'vocab' / 'serializer.json').open('w') as file_: with (path / 'vocab' / 'serializer.json').open('wb') as file_:
file_.write( data = ujson.dumps([
json.dumps([ (TAG, tagger_freqs),
(TAG, tagger_freqs), (DEP, dep_freqs),
(DEP, dep_freqs), (ENT_IOB, entity_iob_freqs),
(ENT_IOB, entity_iob_freqs), (ENT_TYPE, entity_type_freqs),
(ENT_TYPE, entity_type_freqs), (HEAD, head_freqs)
(HEAD, head_freqs) ])
])) if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)