Fix multitask objectives

This commit is contained in:
Matthew Honnibal 2018-02-17 18:41:18 +01:00
parent d1246c95fb
commit 8f06903e09
2 changed files with 14 additions and 7 deletions

View File

@ -681,13 +681,19 @@ class MultitaskObjective(Tagger):
return tokvecs, scores
def get_loss(self, docs, golds, scores):
assert len(docs) == len(golds)
cdef int idx = 0
correct = numpy.zeros((scores.shape[0],), dtype='i')
guesses = scores.argmax(axis=1)
for gold in golds:
for i in range(len(gold.labels)):
label = self.make_label(i, gold.words, gold.tags, gold.heads,
gold.labels, gold.ents)
for i, gold in enumerate(golds):
for j in range(len(docs[i])):
# Handes alignment for tokenization differences
gold_idx = gold.cand_to_gold[j]
if gold_idx is None:
idx += 1
continue
label = self.make_label(gold_idx, gold.words, gold.tags,
gold.heads, gold.labels, gold.ents)
if label is None or label not in self.labels:
correct[idx] = guesses[idx]
else:

View File

@ -542,6 +542,7 @@ cdef class Parser:
def update(self, docs, golds, drop=0., sgd=None, losses=None):
if not any(self.moves.has_gold(gold) for gold in golds):
return None
assert len(docs) == len(golds)
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0:
return self.update_beam(docs, golds,
self.cfg['beam_width'], self.cfg['beam_density'],
@ -551,6 +552,8 @@ cdef class Parser:
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
docs = [docs]
golds = [golds]
for multitask in self._multitasks:
multitask.update(docs, golds, drop=drop, sgd=sgd)
cuda_stream = util.get_cuda_stream()
states, golds, max_steps = self._init_gold_batch(docs, golds)
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
@ -605,9 +608,7 @@ cdef class Parser:
break
self._make_updates(d_tokvecs,
bp_tokvecs, backprops, sgd, cuda_stream)
for multitask in self._multitasks:
multitask.update(docs, golds, drop=drop, sgd=sgd)
def update_beam(self, docs, golds, width=None, density=None,
drop=0., sgd=None, losses=None):
if not any(self.moves.has_gold(gold) for gold in golds):