mirror of https://github.com/explosion/spaCy.git
Only run backprop once when sharing tok2vec weights (#5331)
Previously, pipelines with shared tok2vec weights would call the tok2vec backprop callback multiple times, once for each listening pipeline component. This caused errors for PyTorch and was inefficient. Instead, accumulate the gradients from all components and call the backprop callback only once, from the last listener.
This commit is contained in:
parent
6918d99b6c
commit
b2ef6100af
@@ -103,20 +103,30 @@ class Tok2Vec(Pipe):
         set_dropout_rate(self.model, drop)
         tokvecs, bp_tokvecs = self.model.begin_update(docs)
+        d_tokvecs = [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
+        losses.setdefault(self.name, 0.0)
 
-        def capture_losses(d_tokvecs):
-            """Accumulate tok2vec loss before doing backprop."""
-            l2_loss = sum((d_t2v ** 2).sum() for d_t2v in d_tokvecs)
-            if self.name in losses:
-                losses[self.name] += l2_loss / len(d_tokvecs)
-            else:
-                losses[self.name] = l2_loss / len(d_tokvecs)
-            return bp_tokvecs(d_tokvecs)
+        def accumulate_gradient(one_d_tokvecs):
+            """Accumulate tok2vec loss and gradient. This is passed as a callback
+            to all but the last listener. Only the last one does the backprop.
+            """
+            nonlocal d_tokvecs
+            for i in range(len(one_d_tokvecs)):
+                d_tokvecs[i] += one_d_tokvecs[i]
+                losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
+
+        def backprop(one_d_tokvecs):
+            """Callback to actually do the backprop. Passed to last listener."""
+            accumulate_gradient(one_d_tokvecs)
+            d_docs = bp_tokvecs(d_tokvecs)
+            if sgd is not None:
+                self.model.finish_update(sgd)
+            return d_docs
 
         batch_id = Tok2VecListener.get_batch_id(docs)
-        for listener in self.listeners:
-            listener.receive(batch_id, tokvecs, capture_losses)
-        if sgd is not None:
-            self.model.finish_update(sgd)
+        for listener in self.listeners[:-1]:
+            listener.receive(batch_id, tokvecs, accumulate_gradient)
+        self.listeners[-1].receive(batch_id, tokvecs, backprop)
         if set_annotations:
             self.set_annotations(docs, tokvecs)
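The pattern in this diff is worth seeing in isolation: a shared encoder's backprop callback must be treated as single-use, so downstream gradients are summed first and the callback is invoked exactly once at the end. Below is a minimal, self-contained Python sketch of that idea; it is not spaCy's actual API, and begin_update, backprop_once, and the toy gradient math are illustrative stand-ins.

import numpy as np

def begin_update(inputs):
    # Stand-in for the shared tok2vec forward pass: returns outputs plus a
    # backprop callback that, like a PyTorch-backed layer, may only run once.
    state = {"done": False}

    def backprop_once(d_outputs):
        assert not state["done"], "backprop callback called more than once"
        state["done"] = True
        return d_outputs * 2.0  # toy gradient: forward pass was inputs * 2.0

    return inputs * 2.0, backprop_once

outputs, bp_outputs = begin_update(np.ones((4, 3)))

# Each downstream component contributes a gradient for the *same* shared
# outputs. Accumulate them all first, mirroring accumulate_gradient() above...
component_grads = [np.full_like(outputs, 0.1) for _ in range(3)]
d_outputs = np.zeros_like(outputs)
for grad in component_grads:
    d_outputs += grad

# ...then run the single-use callback exactly once, mirroring backprop().
d_inputs = bp_outputs(d_outputs)
print(d_inputs[0])  # [0.6 0.6 0.6]

Calling bp_outputs once per component, as the old capture_losses did, would trip the single-use assertion here, much as calling backward a second time through an already-freed graph raises an error in PyTorch.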