mirror of https://github.com/explosion/spaCy.git
Improve vector handling
This commit is contained in:
parent
ef87562741
commit
1157294434
|
@ -18,6 +18,7 @@ cdef class Vectors:
|
||||||
cdef readonly StringStore strings
|
cdef readonly StringStore strings
|
||||||
cdef public object key2row
|
cdef public object key2row
|
||||||
cdef public object keys
|
cdef public object keys
|
||||||
|
cdef public int i
|
||||||
|
|
||||||
def __init__(self, strings, data_or_width):
|
def __init__(self, strings, data_or_width):
|
||||||
self.strings = StringStore()
|
self.strings = StringStore()
|
||||||
|
@ -26,13 +27,12 @@ cdef class Vectors:
|
||||||
dtype='f')
|
dtype='f')
|
||||||
else:
|
else:
|
||||||
data = data_or_width
|
data = data_or_width
|
||||||
|
self.i = 0
|
||||||
self.data = data
|
self.data = data
|
||||||
self.key2row = {}
|
self.key2row = {}
|
||||||
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
|
self.keys = np.ndarray((self.data.shape[0],), dtype='uint64')
|
||||||
for i, string in enumerate(strings):
|
for string in strings:
|
||||||
key = self.strings.add(string)
|
self.add_key(string)
|
||||||
self.key2row[key] = i
|
|
||||||
self.keys[i] = key
|
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (Vectors, (self.strings, self.data))
|
return (Vectors, (self.strings, self.data))
|
||||||
|
@ -56,21 +56,29 @@ cdef class Vectors:
|
||||||
yield from self.data
|
yield from self.data
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
# TODO: Fix the quadratic behaviour here!
|
return self.i
|
||||||
return max(self.key2row.values())
|
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
if isinstance(key, basestring_):
|
if isinstance(key, basestring_):
|
||||||
key = self.strings[key]
|
key = self.strings[key]
|
||||||
return key in self.key2row
|
return key in self.key2row
|
||||||
|
|
||||||
def add_key(self, string, vector=None):
|
def add(self, key, vector=None):
|
||||||
key = self.strings.add(string)
|
if isinstance(key, basestring_):
|
||||||
next_i = len(self) + 1
|
key = self.strings.add(key)
|
||||||
self.keys[next_i] = key
|
if key not in self.key2row:
|
||||||
self.key2row[key] = next_i
|
i = self.i
|
||||||
|
if i >= self.keys.shape[0]:
|
||||||
|
self.keys.resize((self.keys.shape[0]*2,))
|
||||||
|
self.data.resize((self.data.shape[0]*2, self.data.shape[1]))
|
||||||
|
self.key2row[key] = self.i
|
||||||
|
self.keys[self.i] = key
|
||||||
|
self.i += 1
|
||||||
|
else:
|
||||||
|
i = self.key2row[key]
|
||||||
if vector is not None:
|
if vector is not None:
|
||||||
self.data[next_i] = vector
|
self.data[i] = vector
|
||||||
|
return i
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
for i, string in enumerate(self.strings):
|
for i, string in enumerate(self.strings):
|
||||||
|
@ -139,5 +147,5 @@ cdef class Vectors:
|
||||||
('strings', lambda b: self.strings.from_bytes(b)),
|
('strings', lambda b: self.strings.from_bytes(b)),
|
||||||
('vectors', deserialize_weights)
|
('vectors', deserialize_weights)
|
||||||
))
|
))
|
||||||
util.from_bytes(deserializers, exclude)
|
util.from_bytes(data, deserializers, exclude)
|
||||||
return self
|
return self
|
||||||
|
|
|
@ -246,11 +246,13 @@ cdef class Vocab:
|
||||||
def vectors_length(self):
|
def vectors_length(self):
|
||||||
return len(self.vectors)
|
return len(self.vectors)
|
||||||
|
|
||||||
def clear_vectors(self):
|
def clear_vectors(self, new_dim=None):
|
||||||
"""Drop the current vector table. Because all vectors must be the same
|
"""Drop the current vector table. Because all vectors must be the same
|
||||||
width, you have to call this to change the size of the vectors.
|
width, you have to call this to change the size of the vectors.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
if new_dim is None:
|
||||||
|
new_dim = self.vectors.data.shape[1]
|
||||||
|
self.vectors = Vectors(self.strings, new_dim)
|
||||||
|
|
||||||
def get_vector(self, orth):
|
def get_vector(self, orth):
|
||||||
"""Retrieve a vector for a word in the vocabulary.
|
"""Retrieve a vector for a word in the vocabulary.
|
||||||
|
@ -278,7 +280,7 @@ cdef class Vocab:
|
||||||
"""
|
"""
|
||||||
if not isinstance(orth, basestring_):
|
if not isinstance(orth, basestring_):
|
||||||
orth = self.strings[orth]
|
orth = self.strings[orth]
|
||||||
self.vectors.add_key(orth, vector=vector)
|
self.vectors.add(orth, vector=vector)
|
||||||
|
|
||||||
def has_vector(self, orth):
|
def has_vector(self, orth):
|
||||||
"""Check whether a word has a vector. Returns False if no
|
"""Check whether a word has a vector. Returns False if no
|
||||||
|
|
Loading…
Reference in New Issue