mirror of https://github.com/explosion/spaCy.git
* Fix vector length error reporting, and ensure vec_len is returned
This commit is contained in:
parent
ba4e563701
commit
ac459278d1
|
@ -263,9 +263,10 @@ cdef class Vocab:
|
|||
|
||||
def load_vectors(self, loc):
|
||||
if loc.endswith('bz2'):
|
||||
self.load_vectors_bz2(loc)
|
||||
vec_len = self.load_vectors_bz2(loc)
|
||||
else:
|
||||
self.load_vectors_bin(loc)
|
||||
vec_len = self.load_vectors_bin(loc)
|
||||
return vec_len
|
||||
|
||||
def load_vectors_bz2(self, loc):
|
||||
cdef LexemeC* lexeme
|
||||
|
@ -278,10 +279,8 @@ cdef class Vocab:
|
|||
if vec_len == -1:
|
||||
vec_len = len(pieces)
|
||||
elif vec_len != len(pieces):
|
||||
raise IOError(
|
||||
"Error loading word vectors: all vectors must be same "
|
||||
"length. Previous vector was length %d, vector on line %d "
|
||||
"was length %d." % (vec_len, line_num, len(pieces)))
|
||||
raise VectorReadError.mismatched_sizes(loc, line_num,
|
||||
vec_len, len(pieces))
|
||||
orth = self.strings[word_str]
|
||||
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
|
||||
lexeme.repvec = <float*>self.mem.alloc(len(pieces), sizeof(float))
|
||||
|
@ -293,14 +292,14 @@ cdef class Vocab:
|
|||
def load_vectors_bin(self, loc):
|
||||
cdef CFile file_ = CFile(loc, b'rb')
|
||||
cdef int32_t word_len
|
||||
cdef int32_t vec_len
|
||||
cdef int32_t vec_len = 0
|
||||
cdef int32_t prev_vec_len = 0
|
||||
cdef float* vec
|
||||
cdef Address mem
|
||||
cdef attr_t string_id
|
||||
cdef bytes py_word
|
||||
cdef vector[float*] vectors
|
||||
cdef int i
|
||||
cdef int line_num = 0
|
||||
cdef Pool tmp_mem = Pool()
|
||||
while True:
|
||||
try:
|
||||
|
@ -309,7 +308,8 @@ cdef class Vocab:
|
|||
break
|
||||
file_.read_into(&vec_len, sizeof(vec_len), 1)
|
||||
if prev_vec_len != 0 and vec_len != prev_vec_len:
|
||||
raise VectorReadError.mismatched_sizes(loc, vec_len, prev_vec_len)
|
||||
raise VectorReadError.mismatched_sizes(loc, line_num,
|
||||
vec_len, prev_vec_len)
|
||||
if 0 >= vec_len >= MAX_VEC_SIZE:
|
||||
raise VectorReadError.bad_size(loc, vec_len)
|
||||
|
||||
|
@ -321,8 +321,10 @@ cdef class Vocab:
|
|||
vectors.push_back(EMPTY_VEC)
|
||||
assert vec != NULL
|
||||
vectors[string_id] = vec
|
||||
line_num += 1
|
||||
cdef LexemeC* lex
|
||||
cdef size_t lex_addr
|
||||
cdef int i
|
||||
for orth, lex_addr in self._by_orth.items():
|
||||
lex = <LexemeC*>lex_addr
|
||||
if lex.lower < vectors.size():
|
||||
|
@ -363,12 +365,12 @@ def write_binary_vectors(in_loc, out_loc):
|
|||
|
||||
class VectorReadError(Exception):
|
||||
@classmethod
|
||||
def mismatched_sizes(cls, loc, prev_size, curr_size):
|
||||
def mismatched_sizes(cls, loc, line_num, prev_size, curr_size):
|
||||
return cls(
|
||||
"Error reading word vectors from %s.\n"
|
||||
"Error reading word vectors from %s on line %d.\n"
|
||||
"All vectors must be the same size.\n"
|
||||
"Prev size: %d\n"
|
||||
"Curr size: %d" % (loc, prev_size, curr_size))
|
||||
"Curr size: %d" % (loc, line_num, prev_size, curr_size))
|
||||
|
||||
@classmethod
|
||||
def bad_size(cls, loc, size):
|
||||
|
|
Loading…
Reference in New Issue