Improve vector handling

This commit is contained in:
Matthew Honnibal 2017-08-19 20:35:33 +02:00
parent ef87562741
commit 1157294434
2 changed files with 26 additions and 16 deletions

View File

@ -18,6 +18,7 @@ cdef class Vectors:
cdef readonly StringStore strings cdef readonly StringStore strings
cdef public object key2row cdef public object key2row
cdef public object keys cdef public object keys
cdef public int i
def __init__(self, strings, data_or_width): def __init__(self, strings, data_or_width):
self.strings = StringStore() self.strings = StringStore()
@ -26,13 +27,12 @@ cdef class Vectors:
dtype='f') dtype='f')
else: else:
data = data_or_width data = data_or_width
self.i = 0
self.data = data self.data = data
self.key2row = {} self.key2row = {}
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64') self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
for i, string in enumerate(strings): for string in strings:
key = self.strings.add(string) self.add_key(string)
self.key2row[key] = i
self.keys[i] = key
def __reduce__(self): def __reduce__(self):
return (Vectors, (self.strings, self.data)) return (Vectors, (self.strings, self.data))
@ -56,21 +56,29 @@ cdef class Vectors:
yield from self.data yield from self.data
def __len__(self): def __len__(self):
# TODO: Fix the quadratic behaviour here! return self.i
return max(self.key2row.values())
def __contains__(self, key): def __contains__(self, key):
if isinstance(key, basestring_): if isinstance(key, basestring_):
key = self.strings[key] key = self.strings[key]
return key in self.key2row return key in self.key2row
def add_key(self, string, vector=None): def add(self, key, vector=None):
key = self.strings.add(string) if isinstance(key, basestring_):
next_i = len(self) + 1 key = self.strings.add(key)
self.keys[next_i] = key if key not in self.key2row:
self.key2row[key] = next_i i = self.i
if i >= self.keys.shape[0]:
self.keys.resize((self.keys.shape[0]*2,))
self.data.resize((self.data.shape[0]*2, self.data.shape[1]))
self.key2row[key] = self.i
self.keys[self.i] = key
self.i += 1
else:
i = self.key2row[key]
if vector is not None: if vector is not None:
self.data[next_i] = vector self.data[i] = vector
return i
def items(self): def items(self):
for i, string in enumerate(self.strings): for i, string in enumerate(self.strings):
@ -139,5 +147,5 @@ cdef class Vectors:
('strings', lambda b: self.strings.from_bytes(b)), ('strings', lambda b: self.strings.from_bytes(b)),
('vectors', deserialize_weights) ('vectors', deserialize_weights)
)) ))
util.from_bytes(deserializers, exclude) util.from_bytes(data, deserializers, exclude)
return self return self

View File

@ -246,11 +246,13 @@ cdef class Vocab:
def vectors_length(self): def vectors_length(self):
return len(self.vectors) return len(self.vectors)
def clear_vectors(self): def clear_vectors(self, new_dim=None):
"""Drop the current vector table. Because all vectors must be the same """Drop the current vector table. Because all vectors must be the same
width, you have to call this to change the size of the vectors. width, you have to call this to change the size of the vectors.
""" """
raise NotImplementedError if new_dim is None:
new_dim = self.vectors.data.shape[1]
self.vectors = Vectors(self.strings, new_dim)
def get_vector(self, orth): def get_vector(self, orth):
"""Retrieve a vector for a word in the vocabulary. """Retrieve a vector for a word in the vocabulary.
@ -278,7 +280,7 @@ cdef class Vocab:
""" """
if not isinstance(orth, basestring_): if not isinstance(orth, basestring_):
orth = self.strings[orth] orth = self.strings[orth]
self.vectors.add_key(orth, vector=vector) self.vectors.add(orth, vector=vector)
def has_vector(self, orth): def has_vector(self, orth):
"""Check whether a word has a vector. Returns False if no """Check whether a word has a vector. Returns False if no