From 8f06903e0915f9d7eda31617e4fc72d24b0872a3 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 17 Feb 2018 18:41:18 +0100 Subject: [PATCH] Fix multitask objectives --- spacy/pipeline.pyx | 14 ++++++++++---- spacy/syntax/nn_parser.pyx | 7 ++++--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index b233f2071..cbd58281e 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -681,13 +681,19 @@ class MultitaskObjective(Tagger): return tokvecs, scores def get_loss(self, docs, golds, scores): + assert len(docs) == len(golds) cdef int idx = 0 correct = numpy.zeros((scores.shape[0],), dtype='i') guesses = scores.argmax(axis=1) - for gold in golds: - for i in range(len(gold.labels)): - label = self.make_label(i, gold.words, gold.tags, gold.heads, - gold.labels, gold.ents) + for i, gold in enumerate(golds): + for j in range(len(docs[i])): + # Handles alignment for tokenization differences + gold_idx = gold.cand_to_gold[j] + if gold_idx is None: + idx += 1 + continue + label = self.make_label(gold_idx, gold.words, gold.tags, + gold.heads, gold.labels, gold.ents) if label is None or label not in self.labels: correct[idx] = guesses[idx] else: diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 804838b66..b4b8d4779 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -542,6 +542,7 @@ cdef class Parser: def update(self, docs, golds, drop=0., sgd=None, losses=None): if not any(self.moves.has_gold(gold) for gold in golds): return None + assert len(docs) == len(golds) if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0: return self.update_beam(docs, golds, self.cfg['beam_width'], self.cfg['beam_density'], @@ -551,6 +552,8 @@ cdef class Parser: if isinstance(docs, Doc) and isinstance(golds, GoldParse): docs = [docs] golds = [golds] + for multitask in self._multitasks: + multitask.update(docs, golds, drop=drop, sgd=sgd) cuda_stream = 
util.get_cuda_stream() states, golds, max_steps = self._init_gold_batch(docs, golds) (tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream, @@ -605,9 +608,7 @@ cdef class Parser: break self._make_updates(d_tokvecs, bp_tokvecs, backprops, sgd, cuda_stream) - for multitask in self._multitasks: - multitask.update(docs, golds, drop=drop, sgd=sgd) - + def update_beam(self, docs, golds, width=None, density=None, drop=0., sgd=None, losses=None): if not any(self.moves.has_gold(gold) for gold in golds):