mirror of https://github.com/explosion/spaCy.git
Fix 2/3 problem for json save/load
This commit is contained in:
parent
40703988bc
commit
cd33b39a04
|
@ -5,7 +5,7 @@ import pathlib
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
import ujson as json
|
import ujson
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -13,6 +13,10 @@ try:
|
||||||
except NameError:
|
except NameError:
|
||||||
basestring = str
|
basestring = str
|
||||||
|
|
||||||
|
try:
|
||||||
|
unicode
|
||||||
|
except NameError:
|
||||||
|
unicode = str
|
||||||
|
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .vocab import Vocab
|
from .vocab import Vocab
|
||||||
|
@ -226,12 +230,21 @@ class Language(object):
|
||||||
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
|
||||||
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
|
||||||
|
|
||||||
with (dep_model_dir / 'config.json').open('w') as file_:
|
with (dep_model_dir / 'config.json').open('wb') as file_:
|
||||||
json.dump(parser_cfg, file_)
|
data = ujson.dumps(parser_cfg)
|
||||||
with (ner_model_dir / 'config.json').open('w') as file_:
|
if isinstance(data, unicode):
|
||||||
json.dump(entity_cfg, file_)
|
data = data.encode('utf8')
|
||||||
with (pos_model_dir / 'config.json').open('w') as file_:
|
file_.write(data)
|
||||||
json.dump(tagger_cfg, file_)
|
with (ner_model_dir / 'config.json').open('wb') as file_:
|
||||||
|
data = ujson.dumps(entity_cfg)
|
||||||
|
if isinstance(data, unicode):
|
||||||
|
data = data.encode('utf8')
|
||||||
|
file_.write(data)
|
||||||
|
with (pos_model_dir / 'config.json').open('wb') as file_:
|
||||||
|
data = ujson.dumps(tagger_cfg)
|
||||||
|
if isinstance(data, unicode):
|
||||||
|
data = data.encode('utf8')
|
||||||
|
file_.write(data)
|
||||||
|
|
||||||
self = cls(
|
self = cls(
|
||||||
path=path,
|
path=path,
|
||||||
|
@ -391,12 +404,14 @@ class Language(object):
|
||||||
else:
|
else:
|
||||||
entity_iob_freqs = []
|
entity_iob_freqs = []
|
||||||
entity_type_freqs = []
|
entity_type_freqs = []
|
||||||
with (path / 'vocab' / 'serializer.json').open('w') as file_:
|
with (path / 'vocab' / 'serializer.json').open('wb') as file_:
|
||||||
file_.write(
|
data = ujson.dumps([
|
||||||
json.dumps([
|
(TAG, tagger_freqs),
|
||||||
(TAG, tagger_freqs),
|
(DEP, dep_freqs),
|
||||||
(DEP, dep_freqs),
|
(ENT_IOB, entity_iob_freqs),
|
||||||
(ENT_IOB, entity_iob_freqs),
|
(ENT_TYPE, entity_type_freqs),
|
||||||
(ENT_TYPE, entity_type_freqs),
|
(HEAD, head_freqs)
|
||||||
(HEAD, head_freqs)
|
])
|
||||||
]))
|
if isinstance(data, unicode):
|
||||||
|
data = data.encode('utf8')
|
||||||
|
file_.write(data)
|
||||||
|
|
Loading…
Reference in New Issue