diff --git a/spacy/vocab.pxd b/spacy/vocab.pxd index e491a48e3..55f69bb6c 100644 --- a/spacy/vocab.pxd +++ b/spacy/vocab.pxd @@ -34,6 +34,7 @@ cdef class Vocab: cdef public object data_dir cdef public object get_lex_attr cdef public object pos_tags + cdef public object serializer_freqs cdef const LexemeC* get(self, Pool mem, unicode string) except NULL cdef const LexemeC* get_by_orth(self, Pool mem, attr_t orth) except NULL diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index 939ea9db3..6f711260b 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -47,16 +47,17 @@ cdef class Vocab: tag_map = json.load(open(path.join(data_dir, 'tag_map.json'))) lemmatizer = Lemmatizer.from_dir(path.join(data_dir, '..')) - + serializer_freqs = json.load(open(path.join(data_dir, 'serializer.json'))) cdef Vocab self = cls(get_lex_attr=get_lex_attr, vectors=vectors, tag_map=tag_map, - lemmatizer=lemmatizer) + lemmatizer=lemmatizer, serializer_freqs=serializer_freqs) self.load_lexemes(path.join(data_dir, 'strings.txt'), path.join(data_dir, 'lexemes.bin')) if vectors is None and path.exists(path.join(data_dir, 'vec.bin')): self.repvec_length = self.load_rep_vectors(path.join(data_dir, 'vec.bin')) return self - def __init__(self, get_lex_attr=None, tag_map=None, vectors=None, lemmatizer=None): + def __init__(self, get_lex_attr=None, tag_map=None, vectors=None, lemmatizer=None, + serializer_freqs=None): if tag_map is None: tag_map = {} if lemmatizer is None: @@ -67,6 +68,7 @@ cdef class Vocab: self.strings = StringStore() self.get_lex_attr = get_lex_attr self.morphology = Morphology(self.strings, tag_map, lemmatizer) + self.serializer_freqs = serializer_freqs self.length = 1 self._serializer = None @@ -75,11 +77,7 @@ cdef class Vocab: def __get__(self): if self._serializer is None: freqs = [] - if self.data_dir is not None: - freqs_loc = path.join(self.data_dir, 'serializer.json') - if path.exists(freqs_loc): - freqs = json.load(open(freqs_loc)) - self._serializer = Packer(self, freqs) + self._serializer = Packer(self, self.serializer_freqs) return self._serializer def __len__(self):