diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index c5f8065de..d86b2f3c0 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -144,7 +144,8 @@ class Pipe(object): return create_default_optimizer(self.model.ops, **self.cfg.get('optimizer', {})) - def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None): + def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None, + **kwargs): """Initialize the pipe for training, using data exampes if available. If no model has been initialized yet, the model is added.""" if self.model is True: @@ -344,7 +345,8 @@ class Tensorizer(Pipe): loss = (d_scores**2).sum() return loss, d_scores - def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None): + def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None, + **kwargs): """Allocate models, pre-process training data and acquire an optimizer. @@ -467,7 +469,8 @@ class Tagger(Pipe): d_scores = self.model.ops.unflatten(d_scores, [len(d) for d in docs]) return float(loss), d_scores - def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None): + def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None, + **kwargs): orig_tag_map = dict(self.vocab.morphology.tag_map) new_tag_map = OrderedDict() for raw_text, annots_brackets in gold_tuples: @@ -641,7 +644,7 @@ class MultitaskObjective(Tagger): pass def begin_training(self, gold_tuples=tuple(), pipeline=None, tok2vec=None, - sgd=None): + sgd=None, **kwargs): gold_tuples = nonproj.preprocess_training_data(gold_tuples) for raw_text, annots_brackets in gold_tuples: for annots, brackets in annots_brackets: @@ -766,7 +769,7 @@ class SimilarityHook(Pipe): def update(self, doc1_doc2, golds, sgd=None, drop=0.): sims, bp_sims = self.model.begin_update(doc1_doc2, drop=drop) - def begin_training(self, _=tuple(), pipeline=None, sgd=None): + def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs): """Allocate model, using width from tensorizer in pipeline. gold_tuples (iterable): Gold-standard training data. @@ -887,6 +890,7 @@ cdef class DependencyParser(Parser): self._multitasks.append(labeller) def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg): + self.add_multitask_objective('tag') for labeller in self._multitasks: tok2vec = self.model[0] labeller.begin_training(gold_tuples, pipeline=pipeline,