mirror of https://github.com/explosion/spaCy.git
Fix multitask objectives
This commit is contained in:
parent
d1246c95fb
commit
8f06903e09
|
@ -681,13 +681,19 @@ class MultitaskObjective(Tagger):
|
||||||
return tokvecs, scores
|
return tokvecs, scores
|
||||||
|
|
||||||
def get_loss(self, docs, golds, scores):
|
def get_loss(self, docs, golds, scores):
|
||||||
|
assert len(docs) == len(golds)
|
||||||
cdef int idx = 0
|
cdef int idx = 0
|
||||||
correct = numpy.zeros((scores.shape[0],), dtype='i')
|
correct = numpy.zeros((scores.shape[0],), dtype='i')
|
||||||
guesses = scores.argmax(axis=1)
|
guesses = scores.argmax(axis=1)
|
||||||
for gold in golds:
|
for i, gold in enumerate(golds):
|
||||||
for i in range(len(gold.labels)):
|
for j in range(len(docs[i])):
|
||||||
label = self.make_label(i, gold.words, gold.tags, gold.heads,
|
# Handes alignment for tokenization differences
|
||||||
gold.labels, gold.ents)
|
gold_idx = gold.cand_to_gold[j]
|
||||||
|
if gold_idx is None:
|
||||||
|
idx += 1
|
||||||
|
continue
|
||||||
|
label = self.make_label(gold_idx, gold.words, gold.tags,
|
||||||
|
gold.heads, gold.labels, gold.ents)
|
||||||
if label is None or label not in self.labels:
|
if label is None or label not in self.labels:
|
||||||
correct[idx] = guesses[idx]
|
correct[idx] = guesses[idx]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -542,6 +542,7 @@ cdef class Parser:
|
||||||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||||
if not any(self.moves.has_gold(gold) for gold in golds):
|
if not any(self.moves.has_gold(gold) for gold in golds):
|
||||||
return None
|
return None
|
||||||
|
assert len(docs) == len(golds)
|
||||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0:
|
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0:
|
||||||
return self.update_beam(docs, golds,
|
return self.update_beam(docs, golds,
|
||||||
self.cfg['beam_width'], self.cfg['beam_density'],
|
self.cfg['beam_width'], self.cfg['beam_density'],
|
||||||
|
@ -551,6 +552,8 @@ cdef class Parser:
|
||||||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
golds = [golds]
|
golds = [golds]
|
||||||
|
for multitask in self._multitasks:
|
||||||
|
multitask.update(docs, golds, drop=drop, sgd=sgd)
|
||||||
cuda_stream = util.get_cuda_stream()
|
cuda_stream = util.get_cuda_stream()
|
||||||
states, golds, max_steps = self._init_gold_batch(docs, golds)
|
states, golds, max_steps = self._init_gold_batch(docs, golds)
|
||||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
|
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
|
||||||
|
@ -605,9 +608,7 @@ cdef class Parser:
|
||||||
break
|
break
|
||||||
self._make_updates(d_tokvecs,
|
self._make_updates(d_tokvecs,
|
||||||
bp_tokvecs, backprops, sgd, cuda_stream)
|
bp_tokvecs, backprops, sgd, cuda_stream)
|
||||||
for multitask in self._multitasks:
|
|
||||||
multitask.update(docs, golds, drop=drop, sgd=sgd)
|
|
||||||
|
|
||||||
def update_beam(self, docs, golds, width=None, density=None,
|
def update_beam(self, docs, golds, width=None, density=None,
|
||||||
drop=0., sgd=None, losses=None):
|
drop=0., sgd=None, losses=None):
|
||||||
if not any(self.moves.has_gold(gold) for gold in golds):
|
if not any(self.moves.has_gold(gold) for gold in golds):
|
||||||
|
|
Loading…
Reference in New Issue