diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index e569b46e7..b6ddf0818 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -18,6 +18,7 @@ cdef class Vectors: cdef readonly StringStore strings cdef public object key2row cdef public object keys + cdef public int i def __init__(self, strings, data_or_width): self.strings = StringStore() @@ -26,13 +27,12 @@ cdef class Vectors: dtype='f') else: data = data_or_width + self.i = 0 self.data = data self.key2row = {} self.keys = np.ndarray((self.data.shape[0],), dtype='uint64') - for i, string in enumerate(strings): - key = self.strings.add(string) - self.key2row[key] = i - self.keys[i] = key + for string in strings: + self.add_key(string) def __reduce__(self): return (Vectors, (self.strings, self.data)) @@ -56,21 +56,29 @@ cdef class Vectors: yield from self.data def __len__(self): - # TODO: Fix the quadratic behaviour here! - return max(self.key2row.values()) + return self.i def __contains__(self, key): if isinstance(key, basestring_): key = self.strings[key] return key in self.key2row - def add_key(self, string, vector=None): - key = self.strings.add(string) - next_i = len(self) + 1 - self.keys[next_i] = key - self.key2row[key] = next_i + def add(self, key, vector=None): + if isinstance(key, basestring_): + key = self.strings.add(key) + if key not in self.key2row: + i = self.i + if i >= self.keys.shape[0]: + self.keys.resize((self.keys.shape[0]*2,)) + self.data.resize((self.data.shape[0]*2, self.data.shape[1])) + self.key2row[key] = self.i + self.keys[self.i] = key + self.i += 1 + else: + i = self.key2row[key] if vector is not None: - self.data[next_i] = vector + self.data[i] = vector + return i def items(self): for i, string in enumerate(self.strings): @@ -139,5 +147,5 @@ cdef class Vectors: ('strings', lambda b: self.strings.from_bytes(b)), ('vectors', deserialize_weights) )) - util.from_bytes(deserializers, exclude) + util.from_bytes(data, deserializers, exclude) return self diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index 1fc3f5e39..1c992b56c 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -246,11 +246,13 @@ cdef class Vocab: def vectors_length(self): return len(self.vectors) - def clear_vectors(self): + def clear_vectors(self, new_dim=None): """Drop the current vector table. Because all vectors must be the same width, you have to call this to change the size of the vectors. """ - raise NotImplementedError + if new_dim is None: + new_dim = self.vectors.data.shape[1] + self.vectors = Vectors(self.strings, new_dim) def get_vector(self, orth): """Retrieve a vector for a word in the vocabulary. @@ -278,7 +280,7 @@ cdef class Vocab: """ if not isinstance(orth, basestring_): orth = self.strings[orth] - self.vectors.add_key(orth, vector=vector) + self.vectors.add(orth, vector=vector) def has_vector(self, orth): """Check whether a word has a vector. Returns False if no