* Allow user to load different sized vectors.

This commit is contained in:
Matthew Honnibal 2015-06-05 16:26:39 +02:00
parent 0aed9c9a33
commit c04e6ebca6
4 changed files with 32 additions and 8 deletions

View File

@ -42,9 +42,9 @@ cdef class Lexeme:
# Workaround for an apparent bug in the way the decorator is handled ---
# TODO: post bug report / patch to Cython.
@staticmethod
cdef inline Lexeme from_ptr(const LexemeC* ptr, StringStore strings):
cdef Lexeme py = Lexeme.__new__(Lexeme, 300)
for i in range(300):
cdef inline Lexeme from_ptr(const LexemeC* ptr, StringStore strings, int repvec_length):
cdef Lexeme py = Lexeme.__new__(Lexeme, repvec_length)
for i in range(repvec_length):
py.repvec[i] = ptr.repvec[i]
py.l2_norm = ptr.l2_norm
py.flags = ptr.flags

View File

@ -464,7 +464,7 @@ cdef class Token:
property repvec:
def __get__(self):
return numpy.asarray(<float[:300,]> self.c.lex.repvec)
return numpy.asarray(<float[:self.vocab.repvec_length,]> self.c.lex.repvec)
property n_lefts:
def __get__(self):

View File

@ -33,3 +33,4 @@ cdef class Vocab:
cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex) except -1
cdef PreshMap _map
cdef readonly int repvec_length

View File

@ -36,7 +36,7 @@ cdef class Vocab:
self.strings = StringStore()
self.lexemes.push_back(&EMPTY_LEXEME)
self.lexeme_props_getter = get_lex_props
self.repvec_length = 0
if data_dir is not None:
if not path.exists(data_dir):
raise IOError("Directory %s not found -- cannot load Vocab." % data_dir)
@ -46,7 +46,7 @@ cdef class Vocab:
self.load_lexemes(path.join(data_dir, 'strings.txt'),
path.join(data_dir, 'lexemes.bin'))
if load_vectors and path.exists(path.join(data_dir, 'vec.bin')):
self.load_rep_vectors(path.join(data_dir, 'vec.bin'))
self.repvec_length = self.load_rep_vectors(path.join(data_dir, 'vec.bin'))
def __len__(self):
"""The current number of lexemes stored."""
@ -107,7 +107,7 @@ cdef class Vocab:
raise ValueError("Vocab unable to map type: "
"%s. Maps unicode --> Lexeme or "
"int --> Lexeme" % str(type(id_or_string)))
return Lexeme.from_ptr(lexeme, self.strings)
return Lexeme.from_ptr(lexeme, self.strings, self.repvec_length)
def __setitem__(self, unicode py_str, dict props):
cdef UniStr c_str
@ -180,6 +180,7 @@ cdef class Vocab:
file_ = _CFile(loc, b'rb')
cdef int32_t word_len
cdef int32_t vec_len
cdef int32_t prev_vec_len = 0
cdef float* vec
cdef Address mem
cdef id_t string_id
@ -192,7 +193,10 @@ cdef class Vocab:
except IOError:
break
file_.read(&vec_len, sizeof(vec_len), 1)
# Validate the size header that was just read, before any memory is allocated
# for the vector.
# NOTE(review): no assignment to prev_vec_len is visible in this hunk; the
# mismatch check below can only fire if prev_vec_len is updated after each
# record -- confirm against the full function.
if prev_vec_len != 0 and vec_len != prev_vec_len:
    # mismatched_sizes takes (loc, prev_size, curr_size), in that order, so
    # the previous size goes first to keep the message labels correct.
    raise VectorReadError.mismatched_sizes(loc, prev_vec_len, vec_len)
# The bounds must be tested separately: a chained `0 >= vec_len >= MAX`
# means `0 >= vec_len and vec_len >= MAX`, which cannot hold for any
# positive MAX, so the guard would never reject anything.
if vec_len <= 0 or vec_len >= MAX_VEC_SIZE:
    raise VectorReadError.bad_size(loc, vec_len)
mem = Address(word_len, sizeof(char))
chars = <char*>mem.ptr
vec = <float*>self.mem.alloc(vec_len, sizeof(float))
@ -216,6 +220,7 @@ cdef class Vocab:
lex.l2_norm = math.sqrt(lex.l2_norm)
else:
lex.repvec = EMPTY_VEC
return vec_len
def write_binary_vectors(in_loc, out_loc):
@ -272,3 +277,21 @@ cdef class _CFile:
cdef bytes py_bytes = value.encode('utf8')
cdef char* chars = <char*>py_bytes
self.write(sizeof(char), len(py_bytes), chars)
class VectorReadError(Exception):
    """Error raised when word vectors cannot be read from a file.

    Instances are created through the classmethod constructors below; each
    one formats the message for a single failure mode.
    """

    @classmethod
    def mismatched_sizes(cls, loc, prev_size, curr_size):
        """Return the error for a vector whose size differs from the last one.

        ``loc`` is interpolated into the message as the file location;
        ``prev_size`` and ``curr_size`` are reported under the "Prev size"
        and "Curr size" labels respectively and are formatted with ``%d``.
        """
        template = (
            "Error reading word vectors from %s.\n"
            "All vectors must be the same size.\n"
            "Prev size: %d\n"
            "Curr size: %d"
        )
        message = template % (loc, prev_size, curr_size)
        return cls(message)

    @classmethod
    def bad_size(cls, loc, size):
        """Return the error for a vector size outside the accepted range.

        ``loc`` is the file location and ``size`` the offending size. The
        upper limit shown in the message is read from the module-level
        ``MAX_VEC_SIZE``, which is defined outside this class.
        """
        template = (
            "Error reading word vectors from %s.\n"
            "Vector size: %d\n"
            "Max size: %d\n"
            "Min size: 1\n"
        )
        details = (loc, size, MAX_VEC_SIZE)
        return cls(template % details)