mirror of https://github.com/explosion/spaCy.git
Fix multitasks
This commit is contained in:
parent
0b5c72fce2
commit
a4da3120b4
|
@ -126,13 +126,13 @@ cdef class DependencyParser(Parser):
|
|||
def add_multitask_objective(self, mt_component):
|
||||
self._multitasks.append(mt_component)
|
||||
|
||||
def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg):
|
||||
def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
|
||||
# TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
|
||||
for labeller in self._multitasks:
|
||||
labeller.model.set_dim("nO", len(self.labels))
|
||||
if labeller.model.has_ref("output_layer"):
|
||||
labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
|
||||
labeller.initialize(get_examples, pipeline=pipeline)
|
||||
labeller.initialize(get_examples, nlp=nlp)
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
|
|
|
@ -96,14 +96,14 @@ cdef class EntityRecognizer(Parser):
|
|||
"""Register another component as a multi-task objective. Experimental."""
|
||||
self._multitasks.append(mt_component)
|
||||
|
||||
def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg):
|
||||
def init_multitask_objectives(self, get_examples, nlp=None, **cfg):
|
||||
"""Setup multi-task objective components. Experimental and internal."""
|
||||
# TODO: transfer self.model.get_ref("tok2vec") to the multitask's model ?
|
||||
for labeller in self._multitasks:
|
||||
labeller.model.set_dim("nO", len(self.labels))
|
||||
if labeller.model.has_ref("output_layer"):
|
||||
labeller.model.get_ref("output_layer").set_dim("nO", len(self.labels))
|
||||
labeller.initialize(get_examples, pipeline=pipeline)
|
||||
labeller.initialize(get_examples, nlp=nlp)
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
|
|
Loading…
Reference in New Issue