mirror of https://github.com/explosion/spaCy.git
* Ensure vectors are same length, and return vector length in load_vectors_bz2
This commit is contained in:
parent
d00fe2bbc6
commit
ba4e563701
|
@ -270,16 +270,25 @@ cdef class Vocab:
|
||||||
def load_vectors_bz2(self, loc):
    """Load word vectors from a bz2-compressed text file.

    Each non-blank line is split on whitespace: the first token is the
    word and every remaining token is parsed with float() as one
    component of that word's vector. The values are copied into a
    freshly allocated float array stored on the word's lexeme (repvec).

    loc: Path of the bz2 file, passed straight to bz2.BZ2File.

    Returns the number of components per vector, or -1 if the file
    contained no vectors at all.

    Raises IOError if a line holds a different number of components
    than the first vector read.
    """
    cdef LexemeC* lexeme
    cdef attr_t orth
    # -1 is the "no vector seen yet" sentinel; it is also what gets
    # returned when the file has no usable lines.
    cdef int32_t vec_len = -1
    with bz2.BZ2File(loc, 'r') as file_:
        # Number lines from 1 so the line reported in the error message
        # matches what a text editor or `sed -n` would show.
        for line_num, line in enumerate(file_, 1):
            pieces = line.split()
            if not pieces:
                # Blank line (e.g. trailing newline at end of file):
                # there is no word to pop, so skip it instead of
                # failing with an IndexError.
                continue
            word_str = pieces.pop(0)
            if vec_len == -1:
                # The first vector fixes the expected length.
                vec_len = len(pieces)
            elif vec_len != len(pieces):
                raise IOError(
                    "Error loading word vectors: all vectors must be same "
                    "length. Previous vector was length %d, vector on line %d "
                    "was length %d." % (vec_len, line_num, len(pieces)))
            orth = self.strings[word_str]
            lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
            lexeme.repvec = <float*>self.mem.alloc(len(pieces), sizeof(float))

            for i, val_str in enumerate(pieces):
                lexeme.repvec[i] = float(val_str)
    return vec_len
|
||||||
|
|
||||||
def load_vectors_bin(self, loc):
|
def load_vectors_bin(self, loc):
|
||||||
cdef CFile file_ = CFile(loc, b'rb')
|
cdef CFile file_ = CFile(loc, b'rb')
|
||||||
|
|
Loading…
Reference in New Issue