mirror of https://github.com/explosion/spaCy.git
Fix 2/3 problem for json save/load
This commit is contained in:
parent
40703988bc
commit
cd33b39a04
|
@ -5,7 +5,7 @@ import pathlib
|
|||
from contextlib import contextmanager
|
||||
import shutil
|
||||
|
||||
import ujson as json
|
||||
import ujson
|
||||
|
||||
|
||||
try:
|
||||
|
@ -13,6 +13,10 @@ try:
|
|||
except NameError:
|
||||
basestring = str
|
||||
|
||||
try:
|
||||
unicode
|
||||
except NameError:
|
||||
unicode = str
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .vocab import Vocab
|
||||
|
@ -226,12 +230,21 @@ class Language(object):
|
|||
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
||||
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
||||
|
||||
with (dep_model_dir / 'config.json').open('w') as file_:
|
||||
json.dump(parser_cfg, file_)
|
||||
with (ner_model_dir / 'config.json').open('w') as file_:
|
||||
json.dump(entity_cfg, file_)
|
||||
with (pos_model_dir / 'config.json').open('w') as file_:
|
||||
json.dump(tagger_cfg, file_)
|
||||
with (dep_model_dir / 'config.json').open('wb') as file_:
|
||||
data = ujson.dumps(parser_cfg)
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
with (ner_model_dir / 'config.json').open('wb') as file_:
|
||||
data = ujson.dumps(entity_cfg)
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
with (pos_model_dir / 'config.json').open('wb') as file_:
|
||||
data = ujson.dumps(tagger_cfg)
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
|
||||
self = cls(
|
||||
path=path,
|
||||
|
@ -391,12 +404,14 @@ class Language(object):
|
|||
else:
|
||||
entity_iob_freqs = []
|
||||
entity_type_freqs = []
|
||||
with (path / 'vocab' / 'serializer.json').open('w') as file_:
|
||||
file_.write(
|
||||
json.dumps([
|
||||
with (path / 'vocab' / 'serializer.json').open('wb') as file_:
|
||||
data = ujson.dumps([
|
||||
(TAG, tagger_freqs),
|
||||
(DEP, dep_freqs),
|
||||
(ENT_IOB, entity_iob_freqs),
|
||||
(ENT_TYPE, entity_type_freqs),
|
||||
(HEAD, head_freqs)
|
||||
]))
|
||||
])
|
||||
if isinstance(data, unicode):
|
||||
data = data.encode('utf8')
|
||||
file_.write(data)
|
||||
|
|
Loading…
Reference in New Issue