code cleanup

This commit is contained in:
svlandeg 2019-07-15 17:36:43 +02:00
parent cdc589d344
commit a63d15a142
1 changed file with 5 additions and 14 deletions

View File

@ -14,7 +14,6 @@ from thinc.neural.util import to_categorical
from thinc.neural.util import get_array_module
from spacy.kb import KnowledgeBase
from ..cli.pretrain import get_cossim_loss
from .functions import merge_subtokens
from ..tokens.doc cimport Doc
from ..syntax.nn_parser cimport Parser
@ -1164,7 +1163,6 @@ class EntityLinker(Pipe):
candidates = self.kb.get_candidates(mention)
random.shuffle(candidates)
nr_neg = 0
for c in candidates:
kb_id = c.entity_
entity_encoding = c.entity_vector
@ -1180,21 +1178,20 @@ class EntityLinker(Pipe):
if kb_id == gold_kb:
cats.append([1])
else:
nr_neg += 1
cats.append([0])
if len(entity_encodings) > 0:
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
cats = self.model.ops.asarray(cats, dtype="float32")
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i]
for i in range(len(entity_encodings))]
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
cats = self.model.ops.asarray(cats, dtype="float32")
loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None)
loss, d_scores = self.get_loss(scores=pred, golds=cats, docs=docs)
mention_gradient = bp_mention(d_scores, sgd=sgd)
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
@ -1205,18 +1202,12 @@ class EntityLinker(Pipe):
return loss
return 0
def get_loss(self, docs, golds, prediction):
d_scores = (prediction - golds)
def get_loss(self, docs, golds, scores):
d_scores = (scores - golds)
loss = (d_scores ** 2).sum()
loss = loss / len(golds)
return loss, d_scores
def get_loss_old(self, docs, golds, scores):
# this loss function assumes we're only using positive examples
loss, gradients = get_cossim_loss(yh=scores, y=golds)
loss = loss / len(golds)
return loss, gradients
def __call__(self, doc):
entities, kb_ids = self.predict([doc])
self.set_annotations([doc], entities, kb_ids)