Don't hold gradient updates in language -- let the parser decide how to batch the updates.

This commit is contained in:
Matthew Honnibal 2017-05-23 04:29:10 -05:00
parent 6b918cc58e
commit 9adfe9e8fc
1 changed files with 5 additions and 17 deletions

View File

@ -209,29 +209,17 @@ class Language(object):
>>> for docs, golds in epoch: >>> for docs, golds in epoch:
>>> state = nlp.update(docs, golds, sgd=optimizer) >>> state = nlp.update(docs, golds, sgd=optimizer)
""" """
grads = {}
def get_grads(W, dW, key=None):
grads[key] = (W, dW)
tok2vec = self.pipeline[0] tok2vec = self.pipeline[0]
feats = tok2vec.doc2feats(docs) feats = tok2vec.doc2feats(docs)
for proc in self.pipeline[1:]: for proc in self.pipeline[1:]:
if not hasattr(proc, 'update'): if not hasattr(proc, 'update'):
continue continue
grads = {}
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop) tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
d_tokvecses = proc.update((docs, tokvecses), golds, sgd=get_grads, drop=drop) d_tokvecses = proc.update((docs, tokvecses), golds, sgd=sgd, drop=drop)
bp_tokvecses(d_tokvecses, sgd=get_grads) bp_tokvecses(d_tokvecses, sgd=sgd)
if sgd is not None: # Clear the tensor variable, to free GPU memory.
for key, (W, dW) in grads.items(): # If we don't do this, the memory leak gets pretty
# TODO: Unhack this when thinc improves # bad, because we may be holding part of a batch.
if isinstance(W, numpy.ndarray):
sgd.ops = NumpyOps()
else:
sgd.ops = CupyOps()
sgd(W, dW, key=key)
for key in list(grads.keys()):
grads.pop(key)
for doc in docs: for doc in docs:
doc.tensor = None doc.tensor = None