mirror of https://github.com/explosion/spaCy.git
Try using tensor for vector/similarity methdos
This commit is contained in:
parent
a131981f3b
commit
498ad85309
|
@ -30,6 +30,7 @@ from ..syntax.iterators import CHUNKERS
|
||||||
from ..util import normalize_slice
|
from ..util import normalize_slice
|
||||||
from ..compat import is_config
|
from ..compat import is_config
|
||||||
from .. import about
|
from .. import about
|
||||||
|
from .. import util
|
||||||
|
|
||||||
|
|
||||||
DEF PADDING = 5
|
DEF PADDING = 5
|
||||||
|
@ -252,8 +253,12 @@ cdef class Doc:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
if 'has_vector' in self.user_hooks:
|
if 'has_vector' in self.user_hooks:
|
||||||
return self.user_hooks['has_vector'](self)
|
return self.user_hooks['has_vector'](self)
|
||||||
|
elif any(token.has_vector for token in self):
|
||||||
return any(token.has_vector for token in self)
|
return True
|
||||||
|
elif self.tensor:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
property vector:
|
property vector:
|
||||||
"""A real-valued meaning representation. Defaults to an average of the
|
"""A real-valued meaning representation. Defaults to an average of the
|
||||||
|
@ -265,12 +270,16 @@ cdef class Doc:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
if 'vector' in self.user_hooks:
|
if 'vector' in self.user_hooks:
|
||||||
return self.user_hooks['vector'](self)
|
return self.user_hooks['vector'](self)
|
||||||
if self._vector is None:
|
if self._vector is not None:
|
||||||
if len(self):
|
return self._vector
|
||||||
|
elif self.has_vector and len(self):
|
||||||
self._vector = sum(t.vector for t in self) / len(self)
|
self._vector = sum(t.vector for t in self) / len(self)
|
||||||
|
return self._vector
|
||||||
|
elif self.tensor:
|
||||||
|
self._vector = self.tensor.mean(axis=0)
|
||||||
|
return self._vector
|
||||||
else:
|
else:
|
||||||
return numpy.zeros((self.vocab.vectors_length,), dtype='float32')
|
return numpy.zeros((self.vocab.vectors_length,), dtype='float32')
|
||||||
return self._vector
|
|
||||||
|
|
||||||
def __set__(self, value):
|
def __set__(self, value):
|
||||||
self._vector = value
|
self._vector = value
|
||||||
|
@ -295,10 +304,6 @@ cdef class Doc:
|
||||||
def __set__(self, value):
|
def __set__(self, value):
|
||||||
self._vector_norm = value
|
self._vector_norm = value
|
||||||
|
|
||||||
@property
|
|
||||||
def string(self):
|
|
||||||
return self.text
|
|
||||||
|
|
||||||
property text:
|
property text:
|
||||||
"""A unicode representation of the document text.
|
"""A unicode representation of the document text.
|
||||||
|
|
||||||
|
@ -598,15 +603,16 @@ cdef class Doc:
|
||||||
self.is_tagged = bool(TAG in attrs or POS in attrs)
|
self.is_tagged = bool(TAG in attrs or POS in attrs)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_disk(self, path):
|
def to_disk(self, path, **exclude):
|
||||||
"""Save the current state to a directory.
|
"""Save the current state to a directory.
|
||||||
|
|
||||||
path (unicode or Path): A path to a directory, which will be created if
|
path (unicode or Path): A path to a directory, which will be created if
|
||||||
it doesn't exist. Paths may be either strings or `Path`-like objects.
|
it doesn't exist. Paths may be either strings or `Path`-like objects.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
with path.open('wb') as file_:
|
||||||
|
file_.write(self.to_bytes(**exclude))
|
||||||
|
|
||||||
def from_disk(self, path):
|
def from_disk(self, path, **exclude):
|
||||||
"""Loads state from a directory. Modifies the object in place and
|
"""Loads state from a directory. Modifies the object in place and
|
||||||
returns it.
|
returns it.
|
||||||
|
|
||||||
|
@ -614,25 +620,28 @@ cdef class Doc:
|
||||||
strings or `Path`-like objects.
|
strings or `Path`-like objects.
|
||||||
RETURNS (Doc): The modified `Doc` object.
|
RETURNS (Doc): The modified `Doc` object.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
with path.open('rb') as file_:
|
||||||
|
bytes_data = file_.read()
|
||||||
|
self.from_bytes(bytes_data, **exclude)
|
||||||
|
|
||||||
def to_bytes(self):
|
def to_bytes(self, **exclude):
|
||||||
"""Serialize, i.e. export the document contents to a binary string.
|
"""Serialize, i.e. export the document contents to a binary string.
|
||||||
|
|
||||||
RETURNS (bytes): A losslessly serialized copy of the `Doc`, including
|
RETURNS (bytes): A losslessly serialized copy of the `Doc`, including
|
||||||
all annotations.
|
all annotations.
|
||||||
"""
|
"""
|
||||||
return dill.dumps(
|
array_head = [LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE]
|
||||||
(self.text,
|
serializers = {
|
||||||
self.to_array([LENGTH,SPACY,TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE]),
|
'text': lambda: self.text,
|
||||||
self.sentiment,
|
'array_head': lambda: array_head,
|
||||||
self.tensor,
|
'array_body': lambda: self.to_array(array_head),
|
||||||
self.noun_chunks_iterator,
|
'sentiment': lambda: self.sentiment,
|
||||||
self.user_data,
|
'tensor': lambda: self.tensor,
|
||||||
(self.user_hooks, self.user_token_hooks, self.user_span_hooks)),
|
'user_data': lambda: self.user_data
|
||||||
protocol=-1)
|
}
|
||||||
|
return util.to_bytes(serializers, exclude)
|
||||||
|
|
||||||
def from_bytes(self, data):
|
def from_bytes(self, bytes_data, **exclude):
|
||||||
"""Deserialize, i.e. import the document contents from a binary string.
|
"""Deserialize, i.e. import the document contents from a binary string.
|
||||||
|
|
||||||
data (bytes): The string to load from.
|
data (bytes): The string to load from.
|
||||||
|
@ -640,26 +649,35 @@ cdef class Doc:
|
||||||
"""
|
"""
|
||||||
if self.length != 0:
|
if self.length != 0:
|
||||||
raise ValueError("Cannot load into non-empty Doc")
|
raise ValueError("Cannot load into non-empty Doc")
|
||||||
|
deserializers = {
|
||||||
|
'text': lambda b: None,
|
||||||
|
'array_head': lambda b: None,
|
||||||
|
'array_body': lambda b: None,
|
||||||
|
'sentiment': lambda b: None,
|
||||||
|
'tensor': lambda b: None,
|
||||||
|
'user_data': lambda user_data: self.user_data.update(user_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = util.from_bytes(bytes_data, deserializers, exclude)
|
||||||
|
|
||||||
cdef attr_t[:, :] attrs
|
cdef attr_t[:, :] attrs
|
||||||
cdef int i, start, end, has_space
|
cdef int i, start, end, has_space
|
||||||
fields = dill.loads(data)
|
self.sentiment = msg['sentiment']
|
||||||
text, attrs = fields[:2]
|
self.tensor = msg['tensor']
|
||||||
self.sentiment, self.tensor = fields[2:4]
|
|
||||||
self.noun_chunks_iterator, self.user_data = fields[4:6]
|
|
||||||
self.user_hooks, self.user_token_hooks, self.user_span_hooks = fields[6]
|
|
||||||
|
|
||||||
start = 0
|
start = 0
|
||||||
cdef const LexemeC* lex
|
cdef const LexemeC* lex
|
||||||
cdef unicode orth_
|
cdef unicode orth_
|
||||||
|
text = msg['text']
|
||||||
|
attrs = msg['array_body']
|
||||||
for i in range(attrs.shape[0]):
|
for i in range(attrs.shape[0]):
|
||||||
end = start + attrs[i, 0]
|
end = start + attrs[i, 0]
|
||||||
has_space = attrs[i, 1]
|
has_space = attrs[i, 1]
|
||||||
orth_ = text[start:end]
|
orth_ = text[start:end]
|
||||||
lex = self.vocab.get(self.mem, orth_)
|
lex = self.vocab.get(self.mem, orth_)
|
||||||
self.push_back(lex, has_space)
|
self.push_back(lex, has_space)
|
||||||
|
|
||||||
start = end + has_space
|
start = end + has_space
|
||||||
self.from_array([TAG,LEMMA,HEAD,DEP,ENT_IOB,ENT_TYPE],
|
self.from_array(msg['array_head'][2:],
|
||||||
attrs[:, 2:])
|
attrs[:, 2:])
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -245,7 +245,10 @@ cdef class Token:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
if 'vector' in self.doc.user_token_hooks:
|
if 'vector' in self.doc.user_token_hooks:
|
||||||
return self.doc.user_token_hooks['vector'](self)
|
return self.doc.user_token_hooks['vector'](self)
|
||||||
|
if self.has_vector:
|
||||||
return self.vocab.get_vector(self.c.lex.orth)
|
return self.vocab.get_vector(self.c.lex.orth)
|
||||||
|
else:
|
||||||
|
return self.doc.tensor[self.i]
|
||||||
|
|
||||||
property vector_norm:
|
property vector_norm:
|
||||||
"""The L2 norm of the token's vector representation.
|
"""The L2 norm of the token's vector representation.
|
||||||
|
|
Loading…
Reference in New Issue