From abf0d930af5b79225a8d7e15b68ddb03952c0e51 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 23 Sep 2015 23:51:08 +1000 Subject: [PATCH] * Fix API for loading word vectors from a file. --- spacy/vocab.pyx | 40 ++++++++++++++++------------------------ 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index e3ac67bf7..0e35cdd6d 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -55,7 +55,7 @@ cdef class Vocab: self.load_lexemes(path.join(data_dir, 'strings.txt'), path.join(data_dir, 'lexemes.bin')) if path.exists(path.join(data_dir, 'vec.bin')): - self.vectors_length = self.load_vectors(path.join(data_dir, 'vec.bin')) + self.vectors_length = self.load_vectors_from_bin_loc(path.join(data_dir, 'vec.bin')) return self def __init__(self, get_lex_attr=None, tag_map=None, lemmatizer=None, serializer_freqs=None): @@ -258,35 +258,27 @@ cdef class Vocab: i += 1 fp.close() - def load_vectors(self, loc): - if loc.endswith('bz2'): - vec_len = self.load_vectors_bz2(loc) - else: - vec_len = self.load_vectors_bin(loc) - return vec_len - - def load_vectors_bz2(self, loc): + def load_vectors(self, loc_or_file): cdef LexemeC* lexeme cdef attr_t orth cdef int32_t vec_len = -1 - with bz2.BZ2File(loc, 'r') as file_: - for line_num, line in enumerate(file_): - pieces = line.split() - word_str = pieces.pop(0) - if vec_len == -1: - vec_len = len(pieces) - elif vec_len != len(pieces): - raise VectorReadError.mismatched_sizes(loc, line_num, - vec_len, len(pieces)) - orth = self.strings[word_str] - lexeme = self.get_by_orth(self.mem, orth) - lexeme.repvec = self.mem.alloc(self.vectors_length, sizeof(float)) + for line_num, line in enumerate(loc_or_file): + pieces = line.split() + word_str = pieces.pop(0) + if vec_len == -1: + vec_len = len(pieces) + elif vec_len != len(pieces): + raise VectorReadError.mismatched_sizes(loc_or_file, line_num, + vec_len, len(pieces)) + orth = self.strings[word_str] + lexeme = self.get_by_orth(self.mem, orth) + lexeme.repvec = self.mem.alloc(self.vectors_length, sizeof(float)) - for i, val_str in enumerate(pieces): - lexeme.repvec[i] = float(val_str) + for i, val_str in enumerate(pieces): + lexeme.repvec[i] = float(val_str) return vec_len - def load_vectors_bin(self, loc): + def load_vectors_from_bin_loc(self, loc): cdef CFile file_ = CFile(loc, b'rb') cdef int32_t word_len cdef int32_t vec_len = 0