mirror of https://github.com/explosion/spaCy.git
code cleanup
This commit is contained in:
parent
cdc589d344
commit
a63d15a142
|
@ -14,7 +14,6 @@ from thinc.neural.util import to_categorical
|
|||
from thinc.neural.util import get_array_module
|
||||
|
||||
from spacy.kb import KnowledgeBase
|
||||
from ..cli.pretrain import get_cossim_loss
|
||||
from .functions import merge_subtokens
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..syntax.nn_parser cimport Parser
|
||||
|
@ -1164,7 +1163,6 @@ class EntityLinker(Pipe):
|
|||
|
||||
candidates = self.kb.get_candidates(mention)
|
||||
random.shuffle(candidates)
|
||||
nr_neg = 0
|
||||
for c in candidates:
|
||||
kb_id = c.entity_
|
||||
entity_encoding = c.entity_vector
|
||||
|
@ -1180,21 +1178,20 @@ class EntityLinker(Pipe):
|
|||
if kb_id == gold_kb:
|
||||
cats.append([1])
|
||||
else:
|
||||
nr_neg += 1
|
||||
cats.append([0])
|
||||
|
||||
if len(entity_encodings) > 0:
|
||||
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
|
||||
|
||||
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
|
||||
cats = self.model.ops.asarray(cats, dtype="float32")
|
||||
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||
|
||||
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
|
||||
mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i]
|
||||
for i in range(len(entity_encodings))]
|
||||
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
|
||||
cats = self.model.ops.asarray(cats, dtype="float32")
|
||||
|
||||
loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None)
|
||||
loss, d_scores = self.get_loss(scores=pred, golds=cats, docs=docs)
|
||||
mention_gradient = bp_mention(d_scores, sgd=sgd)
|
||||
|
||||
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
|
||||
|
@ -1205,18 +1202,12 @@ class EntityLinker(Pipe):
|
|||
return loss
|
||||
return 0
|
||||
|
||||
def get_loss(self, docs, golds, prediction):
|
||||
d_scores = (prediction - golds)
|
||||
def get_loss(self, docs, golds, scores):
    """Compute a squared-error loss and its gradient signal.

    docs: Not used by this method.
    golds: Target values; must support `len()` and element-wise
        subtraction from `scores`.
    scores: Predicted values, same layout as `golds`.
    RETURNS (tuple): `(loss, d_scores)`, where `d_scores` is
        `scores - golds` and `loss` is the sum of the squared
        differences divided by `len(golds)`.
    """
    residual = scores - golds
    # Normalise by the number of gold rows, not by the element count.
    mean_squared = (residual ** 2).sum() / len(golds)
    return mean_squared, residual
|
||||
|
||||
def get_loss_old(self, docs, golds, scores):
    """Loss computed via `get_cossim_loss`, divided by `len(golds)`.

    docs: Not used by this method.
    golds: Target values, passed to `get_cossim_loss` as `y`.
    scores: Predicted values, passed to `get_cossim_loss` as `yh`.
    RETURNS (tuple): `(loss, gradients)`; the gradients are returned
        exactly as `get_cossim_loss` produced them.
    """
    # this loss function assumes we're only using positive examples
    total_loss, d_scores = get_cossim_loss(yh=scores, y=golds)
    return total_loss / len(golds), d_scores
|
||||
|
||||
def __call__(self, doc):
|
||||
entities, kb_ids = self.predict([doc])
|
||||
self.set_annotations([doc], entities, kb_ids)
|
||||
|
|
Loading…
Reference in New Issue