Pass kwargs into pipeline components during begin_training

This commit is contained in:
Matthew Honnibal 2018-02-12 10:18:39 +01:00
parent ab35ac4e6f
commit d7c9b53120
1 changed files with 9 additions and 5 deletions

View File

@ -144,7 +144,8 @@ class Pipe(object):
return create_default_optimizer(self.model.ops,
**self.cfg.get('optimizer', {}))
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
**kwargs):
"""Initialize the pipe for training, using data exampes if available.
If no model has been initialized yet, the model is added."""
if self.model is True:
@ -344,7 +345,8 @@ class Tensorizer(Pipe):
loss = (d_scores**2).sum()
return loss, d_scores
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
**kwargs):
"""Allocate models, pre-process training data and acquire an
optimizer.
@ -467,7 +469,8 @@ class Tagger(Pipe):
d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs])
return float(loss), d_scores
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
**kwargs):
orig_tag_map = dict(self.vocab.morphology.tag_map)
new_tag_map = OrderedDict()
for raw_text, annots_brackets in gold_tuples:
@ -641,7 +644,7 @@ class MultitaskObjective(Tagger):
pass
def begin_training(self, gold_tuples=tuple(), pipeline=None, tok2vec=None,
sgd=None):
sgd=None, **kwargs):
gold_tuples = nonproj.preprocess_training_data(gold_tuples)
for raw_text, annots_brackets in gold_tuples:
for annots, brackets in annots_brackets:
@ -766,7 +769,7 @@ class SimilarityHook(Pipe):
def update(self, doc1_doc2, golds, sgd=None, drop=0.):
sims, bp_sims = self.model.begin_update(doc1_doc2, drop=drop)
def begin_training(self, _=tuple(), pipeline=None, sgd=None):
def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs):
"""Allocate model, using width from tensorizer in pipeline.
gold_tuples (iterable): Gold-standard training data.
@ -887,6 +890,7 @@ cdef class DependencyParser(Parser):
self._multitasks.append(labeller)
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
self.add_multitask_objective('tag')
for labeller in self._multitasks:
tok2vec = self.model[0]
labeller.begin_training(gold_tuples, pipeline=pipeline,