* Ensure vectors are same length, and return vector length in load_vectors_bz2

This commit is contained in:
Matthew Honnibal 2015-09-21 18:03:08 +10:00
parent d00fe2bbc6
commit ba4e563701
1 changed file with 10 additions and 1 deletion

View File

@ -270,16 +270,25 @@ cdef class Vocab:
def load_vectors_bz2(self, loc):
    """Load word vectors from a bz2-compressed text file into the vocab.

    Each non-blank line is split on whitespace: the first token is the word
    and the remaining tokens are parsed with float() as its vector values.
    The values are copied into a float buffer allocated from self.mem and
    stored on the word's lexeme as repvec.

    Args:
        loc: Passed straight to bz2.BZ2File as the file to read.

    Returns:
        The number of values per vector, or -1 if no vector line was read.

    Raises:
        IOError: If a vector's length differs from the first vector's length.
    """
    cdef LexemeC* lexeme
    cdef attr_t orth
    # -1 is the "no vector seen yet" sentinel; set by the first data line.
    cdef int32_t vec_len = -1
    with bz2.BZ2File(loc, 'r') as file_:
        # Count from 1 so the line number in the error message matches what
        # an editor or `sed -n` would show for the offending line.
        for line_num, line in enumerate(file_, 1):
            pieces = line.split()
            if not pieces:
                # Blank line (e.g. trailing newline): nothing to load, and
                # pop(0) below would otherwise fail with a bare IndexError.
                continue
            word_str = pieces.pop(0)
            if vec_len == -1:
                vec_len = len(pieces)
            elif vec_len != len(pieces):
                raise IOError(
                    "Error loading word vectors: all vectors must be same "
                    "length. Previous vector was length %d, vector on line %d "
                    "was length %d." % (vec_len, line_num, len(pieces)))
            orth = self.strings[word_str]
            lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
            lexeme.repvec = <float*>self.mem.alloc(len(pieces), sizeof(float))
            for i, val_str in enumerate(pieces):
                lexeme.repvec[i] = float(val_str)
    return vec_len
def load_vectors_bin(self, loc): def load_vectors_bin(self, loc):
cdef CFile file_ = CFile(loc, b'rb') cdef CFile file_ = CFile(loc, b'rb')