From c04e6ebca6a7819314a363a670c87b51822fa99a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 5 Jun 2015 16:26:39 +0200 Subject: [PATCH] * Allow user to load different sized vectors. Vocab.load_rep_vectors now returns the vector length it read from vec.bin; Vocab stores it as repvec_length and passes it to Lexeme.from_ptr in place of the hard-coded 300. While reading, VectorReadError is raised if an entry's size differs from the previous entry's size or lies outside 1..MAX_VEC_SIZE. An empty vectors file yields a length of 0. --- spacy/lexeme.pxd | 6 +++--- spacy/tokens.pyx | 2 +- spacy/vocab.pxd | 1 + spacy/vocab.pyx | 32 ++++++++++++++++++++++++++++---- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/spacy/lexeme.pxd b/spacy/lexeme.pxd index 87354d532..bdf8dedf3 100644 --- a/spacy/lexeme.pxd +++ b/spacy/lexeme.pxd @@ -42,9 +42,9 @@ cdef class Lexeme: # Workaround for an apparent bug in the way the decorator is handled --- # TODO: post bug report / patch to Cython. @staticmethod - cdef inline Lexeme from_ptr(const LexemeC* ptr, StringStore strings): - cdef Lexeme py = Lexeme.__new__(Lexeme, 300) - for i in range(300): + cdef inline Lexeme from_ptr(const LexemeC* ptr, StringStore strings, int repvec_length): + cdef Lexeme py = Lexeme.__new__(Lexeme, repvec_length) + for i in range(repvec_length): py.repvec[i] = ptr.repvec[i] py.l2_norm = ptr.l2_norm py.flags = ptr.flags diff --git a/spacy/tokens.pyx b/spacy/tokens.pyx index 5c4aabd63..3ee559dcf 100644 --- a/spacy/tokens.pyx +++ b/spacy/tokens.pyx @@ -464,7 +464,7 @@ cdef class Token: property repvec: def __get__(self): - return numpy.asarray( self.c.lex.repvec) + return numpy.asarray( self.c.lex.repvec) property n_lefts: def __get__(self): diff --git a/spacy/vocab.pxd b/spacy/vocab.pxd index 092bedda7..7f6d8bede 100644 --- a/spacy/vocab.pxd +++ b/spacy/vocab.pxd @@ -33,3 +33,4 @@ cdef class Vocab: cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex) except -1 cdef PreshMap _map + cdef readonly int repvec_length diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index 512106757..c93e4202f 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -36,7 +36,7 @@ cdef class Vocab: self.strings = StringStore() self.lexemes.push_back(&EMPTY_LEXEME) self.lexeme_props_getter = get_lex_props - + self.repvec_length = 0 if data_dir is not 
None: if not path.exists(data_dir): raise IOError("Directory %s not found -- cannot load Vocab." % data_dir) @@ -46,7 +46,7 @@ cdef class Vocab: self.load_lexemes(path.join(data_dir, 'strings.txt'), path.join(data_dir, 'lexemes.bin')) if load_vectors and path.exists(path.join(data_dir, 'vec.bin')): - self.load_rep_vectors(path.join(data_dir, 'vec.bin')) + self.repvec_length = self.load_rep_vectors(path.join(data_dir, 'vec.bin')) def __len__(self): """The current number of lexemes stored.""" @@ -107,7 +107,7 @@ cdef class Vocab: raise ValueError("Vocab unable to map type: " "%s. Maps unicode --> Lexeme or " "int --> Lexeme" % str(type(id_or_string))) - return Lexeme.from_ptr(lexeme, self.strings) + return Lexeme.from_ptr(lexeme, self.strings, self.repvec_length) def __setitem__(self, unicode py_str, dict props): cdef UniStr c_str @@ -180,6 +180,7 @@ cdef class Vocab: file_ = _CFile(loc, b'rb') cdef int32_t word_len cdef int32_t vec_len + cdef int32_t prev_vec_len = 0 cdef float* vec cdef Address mem cdef id_t string_id @@ -192,7 +193,11 @@ cdef class Vocab: except IOError: break file_.read(&vec_len, sizeof(vec_len), 1) - + if prev_vec_len != 0 and vec_len != prev_vec_len: + raise VectorReadError.mismatched_sizes(loc, prev_vec_len, vec_len) + if not (1 <= vec_len <= MAX_VEC_SIZE): + raise VectorReadError.bad_size(loc, vec_len) + prev_vec_len = vec_len mem = Address(word_len, sizeof(char)) chars = mem.ptr vec = self.mem.alloc(vec_len, sizeof(float)) @@ -216,6 +221,7 @@ cdef class Vocab: lex.l2_norm = math.sqrt(lex.l2_norm) else: lex.repvec = EMPTY_VEC + return prev_vec_len def write_binary_vectors(in_loc, out_loc): @@ -272,3 +278,21 @@ cdef class _CFile: cdef bytes py_bytes = value.encode('utf8') cdef char* chars = py_bytes self.write(sizeof(char), len(py_bytes), chars) + + +class VectorReadError(Exception): + @classmethod + def mismatched_sizes(cls, loc, prev_size, curr_size): + return cls( + "Error reading word vectors from %s.\n" + "All vectors must be the same size.\n" + "Prev size: %d\n" + "Curr 
size: %d" % (loc, prev_size, curr_size)) + + @classmethod + def bad_size(cls, loc, size): + return cls( + "Error reading word vectors from %s.\n" + "Vector size: %d\n" + "Max size: %d\n" + "Min size: 1\n" % (loc, size, MAX_VEC_SIZE))