diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py
index ef744a5da..83a4454e3 100644
--- a/spacy/pipeline/tok2vec.py
+++ b/spacy/pipeline/tok2vec.py
@@ -103,20 +103,30 @@ class Tok2Vec(Pipe):
         set_dropout_rate(self.model, drop)
         tokvecs, bp_tokvecs = self.model.begin_update(docs)
 
-        def capture_losses(d_tokvecs):
-            """Accumulate tok2vec loss before doing backprop."""
-            l2_loss = sum((d_t2v ** 2).sum() for d_t2v in d_tokvecs)
-            if self.name in losses:
-                losses[self.name] += l2_loss / len(d_tokvecs)
-            else:
-                losses[self.name] = l2_loss / len(d_tokvecs)
-            return bp_tokvecs(d_tokvecs)
+        d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
+        losses.setdefault(self.name, 0.0)
+
+        def accumulate_gradient(one_d_tokvecs):
+            """Accumulate tok2vec loss and gradient. This is passed as a callback
+            to all but the last listener. Only the last one does the backprop.
+            """
+            nonlocal d_tokvecs
+            for i in range(len(one_d_tokvecs)):
+                d_tokvecs[i] += one_d_tokvecs[i]
+                losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
+
+        def backprop(one_d_tokvecs):
+            """Callback to actually do the backprop. Passed to last listener."""
+            accumulate_gradient(one_d_tokvecs)
+            d_docs = bp_tokvecs(d_tokvecs)
+            if sgd is not None:
+                self.model.finish_update(sgd)
+            return d_docs
 
         batch_id = Tok2VecListener.get_batch_id(docs)
-        for listener in self.listeners:
-            listener.receive(batch_id, tokvecs, capture_losses)
-        if sgd is not None:
-            self.model.finish_update(sgd)
+        for listener in self.listeners[:-1]:
+            listener.receive(batch_id, tokvecs, accumulate_gradient)
+        self.listeners[-1].receive(batch_id, tokvecs, backprop)
         if set_annotations:
             self.set_annotations(docs, tokvecs)
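
For readers who want to see the new control flow in isolation, below is a minimal, self-contained sketch of the pattern this diff introduces: every listener except the last only accumulates its gradient into a shared buffer, and the last listener's callback performs the single backward pass and optimizer step. The names ToyListener and update_shared_trunk are hypothetical stand-ins for Tok2VecListener and Tok2Vec.update, and the fake forward pass uses plain NumPy rather than Thinc; this illustrates the callback wiring under those assumptions, it is not the actual implementation.

    import numpy as np


    class ToyListener:
        """Stand-in for a downstream component's listener; stores what it receives."""

        def receive(self, batch_id, outputs, backprop):
            self.batch_id = batch_id
            self.outputs = outputs
            self.backprop = backprop


    def update_shared_trunk(listeners, docs, losses):
        """Mimic the updated control flow with a fake forward pass."""
        # Fake forward pass: one 4-dim vector per token in each doc.
        outputs = [np.ones((len(doc), 4)) for doc in docs]
        d_outputs = [np.zeros_like(t) for t in outputs]
        losses.setdefault("tok2vec", 0.0)

        def accumulate_gradient(one_d_outputs):
            # All but the last listener only add their gradient into the buffer.
            for i, d in enumerate(one_d_outputs):
                d_outputs[i] += d
                losses["tok2vec"] += float((d ** 2).sum())

        def backprop(one_d_outputs):
            # The last listener also triggers the single backward pass.
            accumulate_gradient(one_d_outputs)
            total = sum(float((d ** 2).sum()) for d in d_outputs)
            print(f"backprop once with accumulated gradient (squared norm {total:.2f})")

        for listener in listeners[:-1]:
            listener.receive(0, outputs, accumulate_gradient)
        listeners[-1].receive(0, outputs, backprop)


    losses = {}
    listeners = [ToyListener() for _ in range(3)]
    update_shared_trunk(listeners, docs=[["a", "b"], ["c"]], losses=losses)
    # Later, each downstream component calls the callback it was handed:
    for listener in listeners:
        listener.backprop([np.full_like(t, 0.1) for t in listener.outputs])
    print(losses)  # loss accumulated across all three listeners

As in the real change, this relies on the last listener's callback being invoked after the others have contributed their gradients, which in spaCy follows from pipeline component order.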