mirror of https://github.com/explosion/spaCy.git
* Load/dump strings with a json file, instead of the hacky strings file we were using.
This commit is contained in:
parent
9baf0abd59
commit
2348a08481
|
@ -191,7 +191,8 @@ def setup_vocab(get_lex_attr, tag_map, src_dir, dst_dir):
|
|||
else:
|
||||
lexeme.cluster = 0
|
||||
vocab.dump(str(dst_dir / 'lexemes.bin'))
|
||||
vocab.strings.dump(str(dst_dir / 'strings.txt'))
|
||||
with (dst_dir / 'strings.json').open('w') as file_:
|
||||
vocab.strings.dump(file_)
|
||||
with (dst_dir / 'oov_prob').open('w') as file_:
|
||||
file_.write('%f' % oov_prob)
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -133,7 +133,7 @@ def cython_setup(mod_names, language, includes):
|
|||
"data/wordnet/*", "data/tokenizer/*",
|
||||
"data/vocab/tag_map.json",
|
||||
"data/vocab/lexemes.bin",
|
||||
"data/vocab/strings.txt"],
|
||||
"data/vocab/strings.json"],
|
||||
"spacy.syntax": ["*.pxd"]},
|
||||
ext_modules=exts,
|
||||
cmdclass={'build_ext': build_ext_cython_subclass},
|
||||
|
|
|
@ -247,8 +247,10 @@ class Language(object):
|
|||
self.parser.model.end_training(path.join(data_dir, 'deps', 'model'))
|
||||
self.entity.model.end_training(path.join(data_dir, 'ner', 'model'))
|
||||
self.tagger.model.end_training(path.join(data_dir, 'pos', 'model'))
|
||||
self.vocab.dump(path.join(data_dir, 'vocab', 'lexemes.bin'))
|
||||
self.vocab.strings.dump(path.join(data_dir, 'vocab', 'strings.txt'))
|
||||
|
||||
strings_loc = path.join(data_dir, 'vocab', 'strings.txt')
|
||||
with io.open(strings_loc, 'w', encoding='utf8'):
|
||||
self.vocab.strings.dump(file_)
|
||||
|
||||
with open(path.join(data_dir, 'vocab', 'serializer.json'), 'w') as file_:
|
||||
file_.write(
|
||||
|
|
|
@ -12,8 +12,15 @@ from libc.stdint cimport int64_t
|
|||
|
||||
from .typedefs cimport hash_t, attr_t
|
||||
|
||||
try:
|
||||
import codecs as io
|
||||
except ImportError:
|
||||
import io
|
||||
|
||||
SEPARATOR = '\n|-SEP-|\n'
|
||||
try:
|
||||
import ujson as json
|
||||
except ImportError:
|
||||
import json
|
||||
|
||||
|
||||
cpdef hash_t hash_string(unicode string) except 0:
|
||||
|
@ -114,7 +121,11 @@ cdef class StringStore:
|
|||
def __iter__(self):
|
||||
cdef int i
|
||||
for i in range(self.size):
|
||||
yield self[i]
|
||||
if i == 0:
|
||||
yield u''
|
||||
else:
|
||||
utf8str = &self.c[i]
|
||||
yield _decode(utf8str)
|
||||
|
||||
def __reduce__(self):
|
||||
strings = [""]
|
||||
|
@ -138,28 +149,22 @@ cdef class StringStore:
|
|||
self.size += 1
|
||||
return &self.c[self.size-1]
|
||||
|
||||
def dump(self, loc):
|
||||
cdef Utf8Str* string
|
||||
cdef unicode py_string
|
||||
cdef int i
|
||||
with codecs.open(loc, 'w', 'utf8') as file_:
|
||||
for i in range(1, self.size):
|
||||
string = &self.c[i]
|
||||
py_string = _decode(string)
|
||||
file_.write(py_string)
|
||||
if (i+1) != self.size:
|
||||
file_.write(SEPARATOR)
|
||||
def dump(self, file_):
|
||||
string_data = json.dumps([s for s in self])
|
||||
if not isinstance(string_data, unicode):
|
||||
string_data = string_data.decode('utf8')
|
||||
file_.write(string_data)
|
||||
|
||||
def load(self, loc):
|
||||
with codecs.open(loc, 'r', 'utf8') as file_:
|
||||
strings = file_.read().split(SEPARATOR)
|
||||
def load(self, file_):
|
||||
strings = json.load(file_)
|
||||
if strings == ['']:
|
||||
return None
|
||||
cdef unicode string
|
||||
cdef bytes byte_string
|
||||
for string in strings:
|
||||
byte_string = string.encode('utf8')
|
||||
self.intern(byte_string, len(byte_string))
|
||||
if string:
|
||||
byte_string = string.encode('utf8')
|
||||
self.intern(byte_string, len(byte_string))
|
||||
|
||||
def _realloc(self):
|
||||
# We want to map straight to pointers, but they'll be invalidated if
|
||||
|
|
|
@ -62,7 +62,9 @@ cdef class Vocab:
|
|||
cdef Vocab self = cls(get_lex_attr=get_lex_attr, tag_map=tag_map,
|
||||
lemmatizer=lemmatizer, serializer_freqs=serializer_freqs)
|
||||
|
||||
self.load_lexemes(path.join(data_dir, 'strings.txt'), path.join(data_dir, 'lexemes.bin'))
|
||||
with io.open(path.join(data_dir, 'strings.json'), 'r', encoding='utf8') as file_:
|
||||
self.strings.load(file_)
|
||||
self.load_lexemes(path.join(data_dir, 'lexemes.bin'))
|
||||
if path.exists(path.join(data_dir, 'vec.bin')):
|
||||
self.vectors_length = self.load_vectors_from_bin_loc(path.join(data_dir, 'vec.bin'))
|
||||
return self
|
||||
|
@ -106,11 +108,12 @@ cdef class Vocab:
|
|||
# TODO: Dump vectors
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
lex_loc = path.join(tmp_dir, 'lexemes.bin')
|
||||
str_loc = path.join(tmp_dir, 'strings.txt')
|
||||
str_loc = path.join(tmp_dir, 'strings.json')
|
||||
vec_loc = path.join(self.data_dir, 'vec.bin') if self.data_dir is not None else None
|
||||
|
||||
self.dump(lex_loc)
|
||||
self.strings.dump(str_loc)
|
||||
with io.open(str_loc, 'w', encoding='utf8') as file_:
|
||||
self.strings.dump(file_)
|
||||
|
||||
state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
|
||||
self.serializer_freqs, self.data_dir)
|
||||
|
@ -250,8 +253,7 @@ cdef class Vocab:
|
|||
fp.write_from(&lexeme.l2_norm, sizeof(lexeme.l2_norm), 1)
|
||||
fp.close()
|
||||
|
||||
def load_lexemes(self, strings_loc, loc):
|
||||
self.strings.load(strings_loc)
|
||||
def load_lexemes(self, loc):
|
||||
if not path.exists(loc):
|
||||
raise IOError('LexemeCs file not found at %s' % loc)
|
||||
fp = CFile(loc, 'rb')
|
||||
|
@ -369,7 +371,9 @@ def unpickle_vocab(strings_loc, lex_loc, vec_loc, morphology, get_lex_attr,
|
|||
vocab.data_dir = data_dir
|
||||
vocab.serializer_freqs = serializer_freqs
|
||||
|
||||
vocab.load_lexemes(strings_loc, lex_loc)
|
||||
with io.open(strings_loc, 'r', encoding='utf8') as file_:
|
||||
vocab.strings.load(file_)
|
||||
vocab.load_lexemes(lex_loc)
|
||||
if vec_loc is not None:
|
||||
vocab.load_vectors_from_bin_loc(vec_loc)
|
||||
return vocab
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
# -*- coding: utf8 -*-
|
||||
from __future__ import unicode_literals
|
||||
import pickle
|
||||
import io
|
||||
|
||||
from spacy.strings import StringStore
|
||||
|
||||
import pytest
|
||||
|
||||
import io
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sstore():
|
||||
|
@ -92,4 +93,12 @@ def test_pickle_string_store(sstore):
|
|||
assert loaded[hello_id] == u'Hi'
|
||||
|
||||
|
||||
|
||||
def test_dump_load(sstore):
|
||||
id_ = sstore[u'qqqqq']
|
||||
loc = '/tmp/sstore.json'
|
||||
with io.open(loc, 'w', encoding='utf8') as file_:
|
||||
sstore.dump(file_)
|
||||
new_store = StringStore()
|
||||
with io.open(loc, 'r', encoding='utf8') as file_:
|
||||
new_store.load(file_)
|
||||
assert new_store[id_] == u'qqqqq'
|
||||
|
|
Loading…
Reference in New Issue