* Fix vector length error reporting, and ensure vec_len is returned

This commit is contained in:
Matthew Honnibal 2015-09-21 18:08:32 +10:00
parent ba4e563701
commit ac459278d1
1 changed files with 14 additions and 12 deletions

View File

@ -263,9 +263,10 @@ cdef class Vocab:
def load_vectors(self, loc):
if loc.endswith('bz2'):
self.load_vectors_bz2(loc)
vec_len = self.load_vectors_bz2(loc)
else:
self.load_vectors_bin(loc)
vec_len = self.load_vectors_bin(loc)
return vec_len
def load_vectors_bz2(self, loc):
cdef LexemeC* lexeme
@ -278,10 +279,8 @@ cdef class Vocab:
if vec_len == -1:
vec_len = len(pieces)
elif vec_len != len(pieces):
raise IOError(
"Error loading word vectors: all vectors must be same "
"length. Previous vector was length %d, vector on line %d "
"was length %d." % (vec_len, line_num, len(pieces)))
raise VectorReadError.mismatched_sizes(loc, line_num,
vec_len, len(pieces))
orth = self.strings[word_str]
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
lexeme.repvec = <float*>self.mem.alloc(len(pieces), sizeof(float))
@ -293,14 +292,14 @@ cdef class Vocab:
def load_vectors_bin(self, loc):
cdef CFile file_ = CFile(loc, b'rb')
cdef int32_t word_len
cdef int32_t vec_len
cdef int32_t vec_len = 0
cdef int32_t prev_vec_len = 0
cdef float* vec
cdef Address mem
cdef attr_t string_id
cdef bytes py_word
cdef vector[float*] vectors
cdef int i
cdef int line_num = 0
cdef Pool tmp_mem = Pool()
while True:
try:
@ -309,7 +308,8 @@ cdef class Vocab:
break
file_.read_into(&vec_len, sizeof(vec_len), 1)
if prev_vec_len != 0 and vec_len != prev_vec_len:
raise VectorReadError.mismatched_sizes(loc, vec_len, prev_vec_len)
raise VectorReadError.mismatched_sizes(loc, line_num,
vec_len, prev_vec_len)
if 0 >= vec_len >= MAX_VEC_SIZE:
raise VectorReadError.bad_size(loc, vec_len)
@ -321,8 +321,10 @@ cdef class Vocab:
vectors.push_back(EMPTY_VEC)
assert vec != NULL
vectors[string_id] = vec
line_num += 1
cdef LexemeC* lex
cdef size_t lex_addr
cdef int i
for orth, lex_addr in self._by_orth.items():
lex = <LexemeC*>lex_addr
if lex.lower < vectors.size():
@ -363,12 +365,12 @@ def write_binary_vectors(in_loc, out_loc):
class VectorReadError(Exception):
@classmethod
def mismatched_sizes(cls, loc, prev_size, curr_size):
def mismatched_sizes(cls, loc, line_num, prev_size, curr_size):
return cls(
"Error reading word vectors from %s.\n"
"Error reading word vectors from %s on line %d.\n"
"All vectors must be the same size.\n"
"Prev size: %d\n"
"Curr size: %d" % (loc, prev_size, curr_size))
"Curr size: %d" % (loc, line_num, prev_size, curr_size))
@classmethod
def bad_size(cls, loc, size):