From b2ef6100af585942388930a14fa78e9762758f36 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 21 Apr 2020 19:30:41 +0200 Subject: [PATCH] Only run backprop once when shared tok2vec weights (#5331) Previously, pipelines with shared tok2vec weights would call the tok2vec backprop callback multiple times, once for each pipeline component. This caused errors for PyTorch, and was inefficient. Instead, accumulate the gradient for all but one component, and just call the callback once. --- spacy/pipeline/tok2vec.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index ef744a5da..83a4454e3 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -103,20 +103,30 @@ class Tok2Vec(Pipe): set_dropout_rate(self.model, drop) tokvecs, bp_tokvecs = self.model.begin_update(docs) - def capture_losses(d_tokvecs): - """Accumulate tok2vec loss before doing backprop.""" - l2_loss = sum((d_t2v ** 2).sum() for d_t2v in d_tokvecs) - if self.name in losses: - losses[self.name] += l2_loss / len(d_tokvecs) - else: - losses[self.name] = l2_loss / len(d_tokvecs) - return bp_tokvecs(d_tokvecs) + d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs] + losses.setdefault(self.name, 0.0) + + def accumulate_gradient(one_d_tokvecs): + """Accumulate tok2vec loss and gradient. This is passed as a callback + to all but the last listener. Only the last one does the backprop. + """ + nonlocal d_tokvecs + for i in range(len(one_d_tokvecs)): + d_tokvecs[i] += one_d_tokvecs[i] + losses[self.name] += float((one_d_tokvecs[i] ** 2).sum()) + + def backprop(one_d_tokvecs): + """Callback to actually do the backprop. Passed to last listener.""" + accumulate_gradient(one_d_tokvecs) + d_docs = bp_tokvecs(d_tokvecs) + if sgd is not None: + self.model.finish_update(sgd) + return d_docs batch_id = Tok2VecListener.get_batch_id(docs) - for listener in self.listeners: - listener.receive(batch_id, tokvecs, capture_losses) - if sgd is not None: - self.model.finish_update(sgd) + for listener in self.listeners[:-1]: + listener.receive(batch_id, tokvecs, accumulate_gradient) + self.listeners[-1].receive(batch_id, tokvecs, backprop) if set_annotations: self.set_annotations(docs, tokvecs)