Try using tensor for vector/similarity methdos

This commit is contained in:
Matthew Honnibal 2017-05-30 23:35:17 +02:00
parent a131981f3b
commit 498ad85309
2 changed files with 58 additions and 37 deletions

View File

@ -30,6 +30,7 @@ from ..syntax.iterators import CHUNKERS
from ..util import normalize_slice
from ..compat import is_config
from .. import about
from .. import util
DEF PADDING = 5
@ -252,8 +253,12 @@ cdef class Doc:
def __get__(self):
if 'has_vector' in self.user_hooks:
return self.user_hooks['has_vector'](self)
return any(token.has_vector for token in self)
elif any(token.has_vector for token in self):
return True
elif self.tensor:
return True
else:
return False
property vector:
"""A real-valued meaning representation. Defaults to an average of the
@ -265,12 +270,16 @@ cdef class Doc:
def __get__(self):
if 'vector' in self.user_hooks:
return self.user_hooks['vector'](self)
if self._vector is None:
if len(self):
self._vector = sum(t.vector for t in self) / len(self)
else:
return numpy.zeros((self.vocab.vectors_length,), dtype='float32')
return self._vector
if self._vector is not None:
return self._vector
elif self.has_vector and len(self):
self._vector = sum(t.vector for t in self) / len(self)
return self._vector
elif self.tensor:
self._vector = self.tensor.mean(axis=0)
return self._vector
else:
return numpy.zeros((self.vocab.vectors_length,), dtype='float32')
def __set__(self, value):
self._vector = value
@ -295,10 +304,6 @@ cdef class Doc:
def __set__(self, value):
self._vector_norm = value
@property
def string(self):
return self.text
property text:
"""A unicode representation of the document text.
@ -598,15 +603,16 @@ cdef class Doc:
self.is_tagged = bool(TAG in attrs or POS in attrs)
return self
def to_disk(self, path):
def to_disk(self, path, **exclude):
"""Save the current state to a directory.
path (unicode or Path): A path to a directory, which will be created if
it doesn't exist. Paths may be either strings or `Path`-like objects.
"""
raise NotImplementedError()
with path.open('wb') as file_:
file_.write(self.to_bytes(**exclude))
def from_disk(self, path):
def from_disk(self, path, **exclude):
"""Loads state from a directory. Modifies the object in place and
returns it.
@ -614,25 +620,28 @@ cdef class Doc:
strings or `Path`-like objects.
RETURNS (Doc): The modified `Doc` object.
"""
raise NotImplementedError()
with path.open('rb') as file_:
bytes_data = file_.read()
self.from_bytes(bytes_data, **exclude)
def to_bytes(self):
def to_bytes(self, **exclude):
"""Serialize, i.e. export the document contents to a binary string.
RETURNS (bytes): A losslessly serialized copy of the `Doc`, including
all annotations.
"""
return dill.dumps(
(self.text,
self.to_array([LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE]),
self.sentiment,
self.tensor,
self.noun_chunks_iterator,
self.user_data,
(self.user_hooks, self.user_token_hooks, self.user_span_hooks)),
protocol=-1)
array_head = [LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE]
serializers = {
'text': lambda: self.text,
'array_head': lambda: array_head,
'array_body': lambda: self.to_array(array_head),
'sentiment': lambda: self.sentiment,
'tensor': lambda: self.tensor,
'user_data': lambda: self.user_data
}
return util.to_bytes(serializers, exclude)
def from_bytes(self, data):
def from_bytes(self, bytes_data, **exclude):
"""Deserialize, i.e. import the document contents from a binary string.
data (bytes): The string to load from.
@ -640,27 +649,36 @@ cdef class Doc:
"""
if self.length != 0:
raise ValueError("Cannot load into non-empty Doc")
deserializers = {
'text': lambda b: None,
'array_head': lambda b: None,
'array_body': lambda b: None,
'sentiment': lambda b: None,
'tensor': lambda b: None,
'user_data': lambda user_data: self.user_data.update(user_data)
}
msg = util.from_bytes(bytes_data, deserializers, exclude)
cdef attr_t[:, :] attrs
cdef int i, start, end, has_space
fields = dill.loads(data)
text, attrs = fields[:2]
self.sentiment, self.tensor = fields[2:4]
self.noun_chunks_iterator, self.user_data = fields[4:6]
self.user_hooks, self.user_token_hooks, self.user_span_hooks = fields[6]
self.sentiment = msg['sentiment']
self.tensor = msg['tensor']
start = 0
cdef const LexemeC* lex
cdef unicode orth_
text = msg['text']
attrs = msg['array_body']
for i in range(attrs.shape[0]):
end = start + attrs[i, 0]
has_space = attrs[i, 1]
orth_ = text[start:end]
lex = self.vocab.get(self.mem, orth_)
self.push_back(lex, has_space)
start = end + has_space
self.from_array([TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE],
attrs[:, 2:])
self.from_array(msg['array_head'][2:],
attrs[:, 2:])
return self
def merge(self, int start_idx, int end_idx, *args, **attributes):

View File

@ -111,7 +111,7 @@ cdef class Token:
RETURNS (float): A scalar similarity score. Higher is more similar.
"""
if 'similarity' in self.doc.user_token_hooks:
return self.doc.user_token_hooks['similarity'](self)
return self.doc.user_token_hooks['similarity'](self)
if self.vector_norm == 0 or other.vector_norm == 0:
return 0.0
return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)
@ -245,7 +245,10 @@ cdef class Token:
def __get__(self):
if 'vector' in self.doc.user_token_hooks:
return self.doc.user_token_hooks['vector'](self)
return self.vocab.get_vector(self.c.lex.orth)
if self.has_vector:
return self.vocab.get_vector(self.c.lex.orth)
else:
return self.doc.tensor[self.i]
property vector_norm:
"""The L2 norm of the token's vector representation.