Try using tensor for vector/similarity methdos

This commit is contained in:
Matthew Honnibal 2017-05-30 23:35:17 +02:00
parent a131981f3b
commit 498ad85309
2 changed files with 58 additions and 37 deletions

View File

@ -30,6 +30,7 @@ from ..syntax.iterators import CHUNKERS
from ..util import normalize_slice from ..util import normalize_slice
from ..compat import is_config from ..compat import is_config
from .. import about from .. import about
from .. import util
DEF PADDING = 5 DEF PADDING = 5
@ -252,8 +253,12 @@ cdef class Doc:
def __get__(self): def __get__(self):
if 'has_vector' in self.user_hooks: if 'has_vector' in self.user_hooks:
return self.user_hooks['has_vector'](self) return self.user_hooks['has_vector'](self)
elif any(token.has_vector for token in self):
return any(token.has_vector for token in self) return True
elif self.tensor:
return True
else:
return False
property vector: property vector:
"""A real-valued meaning representation. Defaults to an average of the """A real-valued meaning representation. Defaults to an average of the
@ -265,12 +270,16 @@ cdef class Doc:
def __get__(self): def __get__(self):
if 'vector' in self.user_hooks: if 'vector' in self.user_hooks:
return self.user_hooks['vector'](self) return self.user_hooks['vector'](self)
if self._vector is None: if self._vector is not None:
if len(self): return self._vector
elif self.has_vector and len(self):
self._vector = sum(t.vector for t in self) / len(self) self._vector = sum(t.vector for t in self) / len(self)
return self._vector
elif self.tensor:
self._vector = self.tensor.mean(axis=0)
return self._vector
else: else:
return numpy.zeros((self.vocab.vectors_length,), dtype='float32') return numpy.zeros((self.vocab.vectors_length,), dtype='float32')
return self._vector
def __set__(self, value): def __set__(self, value):
self._vector = value self._vector = value
@ -295,10 +304,6 @@ cdef class Doc:
def __set__(self, value): def __set__(self, value):
self._vector_norm = value self._vector_norm = value
@property
def string(self):
return self.text
property text: property text:
"""A unicode representation of the document text. """A unicode representation of the document text.
@ -598,15 +603,16 @@ cdef class Doc:
self.is_tagged = bool(TAG in attrs or POS in attrs) self.is_tagged = bool(TAG in attrs or POS in attrs)
return self return self
def to_disk(self, path): def to_disk(self, path, **exclude):
"""Save the current state to a directory. """Save the current state to a directory.
path (unicode or Path): A path to a directory, which will be created if path (unicode or Path): A path to a directory, which will be created if
it doesn't exist. Paths may be either strings or `Path`-like objects. it doesn't exist. Paths may be either strings or `Path`-like objects.
""" """
raise NotImplementedError() with path.open('wb') as file_:
file_.write(self.to_bytes(**exclude))
def from_disk(self, path): def from_disk(self, path, **exclude):
"""Loads state from a directory. Modifies the object in place and """Loads state from a directory. Modifies the object in place and
returns it. returns it.
@ -614,25 +620,28 @@ cdef class Doc:
strings or `Path`-like objects. strings or `Path`-like objects.
RETURNS (Doc): The modified `Doc` object. RETURNS (Doc): The modified `Doc` object.
""" """
raise NotImplementedError() with path.open('rb') as file_:
bytes_data = file_.read()
self.from_bytes(bytes_data, **exclude)
def to_bytes(self): def to_bytes(self, **exclude):
"""Serialize, i.e. export the document contents to a binary string. """Serialize, i.e. export the document contents to a binary string.
RETURNS (bytes): A losslessly serialized copy of the `Doc`, including RETURNS (bytes): A losslessly serialized copy of the `Doc`, including
all annotations. all annotations.
""" """
return dill.dumps( array_head = [LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE]
(self.text, serializers = {
self.to_array([LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE]), 'text': lambda: self.text,
self.sentiment, 'array_head': lambda: array_head,
self.tensor, 'array_body': lambda: self.to_array(array_head),
self.noun_chunks_iterator, 'sentiment': lambda: self.sentiment,
self.user_data, 'tensor': lambda: self.tensor,
(self.user_hooks, self.user_token_hooks, self.user_span_hooks)), 'user_data': lambda: self.user_data
protocol=-1) }
return util.to_bytes(serializers, exclude)
def from_bytes(self, data): def from_bytes(self, bytes_data, **exclude):
"""Deserialize, i.e. import the document contents from a binary string. """Deserialize, i.e. import the document contents from a binary string.
data (bytes): The string to load from. data (bytes): The string to load from.
@ -640,26 +649,35 @@ cdef class Doc:
""" """
if self.length != 0: if self.length != 0:
raise ValueError("Cannot load into non-empty Doc") raise ValueError("Cannot load into non-empty Doc")
deserializers = {
'text': lambda b: None,
'array_head': lambda b: None,
'array_body': lambda b: None,
'sentiment': lambda b: None,
'tensor': lambda b: None,
'user_data': lambda user_data: self.user_data.update(user_data)
}
msg = util.from_bytes(bytes_data, deserializers, exclude)
cdef attr_t[:, :] attrs cdef attr_t[:, :] attrs
cdef int i, start, end, has_space cdef int i, start, end, has_space
fields = dill.loads(data) self.sentiment = msg['sentiment']
text, attrs = fields[:2] self.tensor = msg['tensor']
self.sentiment, self.tensor = fields[2:4]
self.noun_chunks_iterator, self.user_data = fields[4:6]
self.user_hooks, self.user_token_hooks, self.user_span_hooks = fields[6]
start = 0 start = 0
cdef const LexemeC* lex cdef const LexemeC* lex
cdef unicode orth_ cdef unicode orth_
text = msg['text']
attrs = msg['array_body']
for i in range(attrs.shape[0]): for i in range(attrs.shape[0]):
end = start + attrs[i, 0] end = start + attrs[i, 0]
has_space = attrs[i, 1] has_space = attrs[i, 1]
orth_ = text[start:end] orth_ = text[start:end]
lex = self.vocab.get(self.mem, orth_) lex = self.vocab.get(self.mem, orth_)
self.push_back(lex, has_space) self.push_back(lex, has_space)
start = end + has_space start = end + has_space
self.from_array([TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE], self.from_array(msg['array_head'][2:],
attrs[:, 2:]) attrs[:, 2:])
return self return self

View File

@ -245,7 +245,10 @@ cdef class Token:
def __get__(self): def __get__(self):
if 'vector' in self.doc.user_token_hooks: if 'vector' in self.doc.user_token_hooks:
return self.doc.user_token_hooks['vector'](self) return self.doc.user_token_hooks['vector'](self)
if self.has_vector:
return self.vocab.get_vector(self.c.lex.orth) return self.vocab.get_vector(self.c.lex.orth)
else:
return self.doc.tensor[self.i]
property vector_norm: property vector_norm:
"""The L2 norm of the token's vector representation. """The L2 norm of the token's vector representation.