Fix 2/3 problem for json save/load

This commit is contained in:
Matthew Honnibal 2017-03-08 01:39:13 +01:00
parent 40703988bc
commit cd33b39a04
1 changed files with 31 additions and 16 deletions

View File

@ -5,7 +5,7 @@ import pathlib
from contextlib import contextmanager
import shutil
import ujson as json
import ujson
try:
@ -13,6 +13,10 @@ try:
except NameError:
basestring = str
try:
unicode
except NameError:
unicode = str
from .tokenizer import Tokenizer
from .vocab import Vocab
@ -226,12 +230,21 @@ class Language(object):
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
with (dep_model_dir / 'config.json').open('w') as file_:
json.dump(parser_cfg, file_)
with (ner_model_dir / 'config.json').open('w') as file_:
json.dump(entity_cfg, file_)
with (pos_model_dir / 'config.json').open('w') as file_:
json.dump(tagger_cfg, file_)
with (dep_model_dir / 'config.json').open('wb') as file_:
data = ujson.dumps(parser_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
with (ner_model_dir / 'config.json').open('wb') as file_:
data = ujson.dumps(entity_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
with (pos_model_dir / 'config.json').open('wb') as file_:
data = ujson.dumps(tagger_cfg)
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)
self = cls(
path=path,
@ -391,12 +404,14 @@ class Language(object):
else:
entity_iob_freqs = []
entity_type_freqs = []
with (path / 'vocab' / 'serializer.json').open('w') as file_:
file_.write(
json.dumps([
with (path / 'vocab' / 'serializer.json').open('wb') as file_:
data = ujson.dumps([
(TAG, tagger_freqs),
(DEP, dep_freqs),
(ENT_IOB, entity_iob_freqs),
(ENT_TYPE, entity_type_freqs),
(HEAD, head_freqs)
]))
])
if isinstance(data, unicode):
data = data.encode('utf8')
file_.write(data)