mirror of https://github.com/explosion/spaCy.git
Rearrange multi-task learning
This commit is contained in:
parent
135a13790c
commit
e6cc927ab1
|
@ -6,7 +6,8 @@ import dill
|
||||||
import numpy
|
import numpy
|
||||||
from thinc.neural import Model
|
from thinc.neural import Model
|
||||||
from thinc.neural.ops import NumpyOps, CupyOps
|
from thinc.neural.ops import NumpyOps, CupyOps
|
||||||
from thinc.neural.optimizers import Adam
|
from thinc.neural.optimizers import Adam, SGD
|
||||||
|
import random
|
||||||
|
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .vocab import Vocab
|
from .vocab import Vocab
|
||||||
|
@ -194,7 +195,7 @@ class Language(object):
|
||||||
proc(doc)
|
proc(doc)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def update(self, docs, golds, drop=0., sgd=None):
|
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||||
"""Update the models in the pipeline.
|
"""Update the models in the pipeline.
|
||||||
|
|
||||||
docs (iterable): A batch of `Doc` objects.
|
docs (iterable): A batch of `Doc` objects.
|
||||||
|
@ -211,12 +212,20 @@ class Language(object):
|
||||||
"""
|
"""
|
||||||
tok2vec = self.pipeline[0]
|
tok2vec = self.pipeline[0]
|
||||||
feats = tok2vec.doc2feats(docs)
|
feats = tok2vec.doc2feats(docs)
|
||||||
for proc in self.pipeline[1:]:
|
procs = list(self.pipeline[1:])
|
||||||
|
random.shuffle(procs)
|
||||||
|
grads = {}
|
||||||
|
def get_grads(W, dW, key=None):
|
||||||
|
grads[key] = (W, dW)
|
||||||
|
for proc in procs:
|
||||||
if not hasattr(proc, 'update'):
|
if not hasattr(proc, 'update'):
|
||||||
continue
|
continue
|
||||||
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
|
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
|
||||||
d_tokvecses = proc.update((docs, tokvecses), golds, sgd=sgd, drop=drop)
|
d_tokvecses = proc.update((docs, tokvecses), golds,
|
||||||
|
drop=drop, sgd=sgd, losses=losses)
|
||||||
bp_tokvecses(d_tokvecses, sgd=sgd)
|
bp_tokvecses(d_tokvecses, sgd=sgd)
|
||||||
|
for key, (W, dW) in grads.items():
|
||||||
|
sgd(W, dW, key=key)
|
||||||
# Clear the tensor variable, to free GPU memory.
|
# Clear the tensor variable, to free GPU memory.
|
||||||
# If we don't do this, the memory leak gets pretty
|
# If we don't do this, the memory leak gets pretty
|
||||||
# bad, because we may be holding part of a batch.
|
# bad, because we may be holding part of a batch.
|
||||||
|
|
Loading…
Reference in New Issue