Only run backprop once when shared tok2vec weights (#5331)

Previously, pipelines with shared tok2vec weights would call the
tok2vec backprop callback multiple times, once for each pipeline
component. This caused errors for PyTorch, and was inefficient.

Instead, accumulate the gradient for all but one component, and just
call the callback once.
This commit is contained in:
Matthew Honnibal 2020-04-21 19:30:41 +02:00 committed by GitHub
parent 6918d99b6c
commit b2ef6100af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 22 additions and 12 deletions

View File

@ -103,20 +103,30 @@ class Tok2Vec(Pipe):
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
tokvecs, bp_tokvecs = self.model.begin_update(docs) tokvecs, bp_tokvecs = self.model.begin_update(docs)
def capture_losses(d_tokvecs): d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
"""Accumulate tok2vec loss before doing backprop.""" losses.setdefault(self.name, 0.0)
l2_loss = sum((d_t2v ** 2).sum() for d_t2v in d_tokvecs)
if self.name in losses: def accumulate_gradient(one_d_tokvecs):
losses[self.name] += l2_loss / len(d_tokvecs) """Accumulate tok2vec loss and gradient. This is passed as a callback
else: to all but the last listener. Only the last one does the backprop.
losses[self.name] = l2_loss / len(d_tokvecs) """
return bp_tokvecs(d_tokvecs) nonlocal d_tokvecs
for i in range(len(one_d_tokvecs)):
d_tokvecs[i] += one_d_tokvecs[i]
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
def backprop(one_d_tokvecs):
"""Callback to actually do the backprop. Passed to last listener."""
accumulate_gradient(one_d_tokvecs)
d_docs = bp_tokvecs(d_tokvecs)
if sgd is not None:
self.model.finish_update(sgd)
return d_docs
batch_id = Tok2VecListener.get_batch_id(docs) batch_id = Tok2VecListener.get_batch_id(docs)
for listener in self.listeners: for listener in self.listeners[:-1]:
listener.receive(batch_id, tokvecs, capture_losses) listener.receive(batch_id, tokvecs, accumulate_gradient)
if sgd is not None: self.listeners[-1].receive(batch_id, tokvecs, backprop)
self.model.finish_update(sgd)
if set_annotations: if set_annotations:
self.set_annotations(docs, tokvecs) self.set_annotations(docs, tokvecs)