mirror of https://github.com/explosion/spaCy.git
Improve vector handling
This commit is contained in:
parent
ef87562741
commit
1157294434
|
@ -18,6 +18,7 @@ cdef class Vectors:
|
|||
cdef readonly StringStore strings
|
||||
cdef public object key2row
|
||||
cdef public object keys
|
||||
cdef public int i
|
||||
|
||||
def __init__(self, strings, data_or_width):
|
||||
self.strings = StringStore()
|
||||
|
@ -26,13 +27,12 @@ cdef class Vectors:
|
|||
dtype='f')
|
||||
else:
|
||||
data = data_or_width
|
||||
self.i = 0
|
||||
self.data = data
|
||||
self.key2row = {}
|
||||
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
|
||||
for i, string in enumerate(strings):
|
||||
key = self.strings.add(string)
|
||||
self.key2row[key] = i
|
||||
self.keys[i] = key
|
||||
for string in strings:
|
||||
self.add_key(string)
|
||||
|
||||
def __reduce__(self):
|
||||
return (Vectors, (self.strings, self.data))
|
||||
|
@ -56,21 +56,29 @@ cdef class Vectors:
|
|||
yield from self.data
|
||||
|
||||
def __len__(self):
|
||||
# TODO: Fix the quadratic behaviour here!
|
||||
return max(self.key2row.values())
|
||||
return self.i
|
||||
|
||||
def __contains__(self, key):
|
||||
if isinstance(key, basestring_):
|
||||
key = self.strings[key]
|
||||
return key in self.key2row
|
||||
|
||||
def add_key(self, string, vector=None):
|
||||
key = self.strings.add(string)
|
||||
next_i = len(self) + 1
|
||||
self.keys[next_i] = key
|
||||
self.key2row[key] = next_i
|
||||
def add(self, key, vector=None):
|
||||
if isinstance(key, basestring_):
|
||||
key = self.strings.add(key)
|
||||
if key not in self.key2row:
|
||||
i = self.i
|
||||
if i >= self.keys.shape[0]:
|
||||
self.keys.resize((self.keys.shape[0]*2,))
|
||||
self.data.resize((self.data.shape[0]*2, self.data.shape[1]))
|
||||
self.key2row[key] = self.i
|
||||
self.keys[self.i] = key
|
||||
self.i += 1
|
||||
else:
|
||||
i = self.key2row[key]
|
||||
if vector is not None:
|
||||
self.data[next_i] = vector
|
||||
self.data[i] = vector
|
||||
return i
|
||||
|
||||
def items(self):
|
||||
for i, string in enumerate(self.strings):
|
||||
|
@ -139,5 +147,5 @@ cdef class Vectors:
|
|||
('strings', lambda b: self.strings.from_bytes(b)),
|
||||
('vectors', deserialize_weights)
|
||||
))
|
||||
util.from_bytes(deserializers, exclude)
|
||||
util.from_bytes(data, deserializers, exclude)
|
||||
return self
|
||||
|
|
|
@ -246,11 +246,13 @@ cdef class Vocab:
|
|||
def vectors_length(self):
|
||||
return len(self.vectors)
|
||||
|
||||
def clear_vectors(self):
|
||||
def clear_vectors(self, new_dim=None):
|
||||
"""Drop the current vector table. Because all vectors must be the same
|
||||
width, you have to call this to change the size of the vectors.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
if new_dim is None:
|
||||
new_dim = self.vectors.data.shape[1]
|
||||
self.vectors = Vectors(self.strings, new_dim)
|
||||
|
||||
def get_vector(self, orth):
|
||||
"""Retrieve a vector for a word in the vocabulary.
|
||||
|
@ -278,7 +280,7 @@ cdef class Vocab:
|
|||
"""
|
||||
if not isinstance(orth, basestring_):
|
||||
orth = self.strings[orth]
|
||||
self.vectors.add_key(orth, vector=vector)
|
||||
self.vectors.add(orth, vector=vector)
|
||||
|
||||
def has_vector(self, orth):
|
||||
"""Check whether a word has a vector. Returns False if no
|
||||
|
|
Loading…
Reference in New Issue