Fix vectors.add()

This commit is contained in:
Explosion Bot 2017-10-30 16:08:09 +01:00
parent 41d0f1665a
commit ab5d5ed880
1 changed files with 15 additions and 17 deletions

View File

@ -56,10 +56,11 @@ cdef class Vectors:
self.i = 0
self.key2row = {}
self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
for i, string in enumerate(self.strings):
if i >= self.data.shape[0]:
break
self.add(self.strings[string], vector=self.data[i])
if data is not None:
for i, string in enumerate(self.strings):
if i >= self.data.shape[0]:
break
self.add(self.strings[string], vector=self.data[i])
def __reduce__(self):
return (Vectors, (self.strings, self.data))
@ -124,25 +125,22 @@ cdef class Vectors:
vector (numpy.ndarray / None): A vector to add for the key.
row (int / None): The row-number of a vector to map the key to.
"""
if row is not None and vector is not None:
raise ValueError("Only one of 'row' and 'vector' may be set")
if isinstance(key, basestring_):
key = self.strings.add(key)
if key in self.key2row and vector is not None:
if key in self.key2row and row is None:
row = self.key2row[key]
elif key in self.key2row and row is not None:
self.key2row[key] = row
elif key not in self.key2row:
if row is not None:
self.key2row[key] = row
else:
self.key2row[key] = self.i
row = self.i
if row >= self.keys.shape[0]:
self.keys.resize((row*2,))
self.data.resize((row*2, self.data.shape[1]))
self.keys[self.i] = key
elif row is None:
row = self.i
self.i += 1
if row >= self.keys.shape[0]:
self.keys.resize((row*2,))
self.data.resize((row*2, self.data.shape[1]))
self.keys[self.i] = key
self.key2row[key] = row
self.keys[row] = key
if vector is not None:
self.data[row] = vector
return row