mirror of https://github.com/explosion/spaCy.git
Allow multitask objectives to be added to the parser and NER more easily
This commit is contained in:
parent
4a7d524efb
commit
203d2ea830
|
@ -882,14 +882,16 @@ cdef class DependencyParser(Parser):
|
|||
def postprocesses(self):
    '''Return a one-element list holding `nonproj.deprojectivize`.'''
    steps = [nonproj.deprojectivize]
    return steps
|
||||
|
||||
def add_multitask_objective(self, target):
    '''Create a MultitaskObjective for `target` and register it.

    The objective is built over this component's vocab and appended to
    `self._multitasks`. Nothing is returned.
    '''
    objective = MultitaskObjective(self.vocab, target=target)
    self._multitasks.append(objective)
|
||||
|
||||
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
    '''Begin training for each registered multitask objective and add it
    to `pipeline`.

    gold_tuples: Passed through unchanged to each objective's
        `begin_training`.
    pipeline (list): Passed to each objective's `begin_training`, then
        extended in place with one `(labeller.name, labeller)` tuple per
        objective.
    sgd: Forwarded to each objective's `begin_training`. Defaults to None.
    **cfg: Accepted for interface compatibility; not used here.

    Objectives are taken from `self._multitasks` only. This method never
    creates objectives and never appends to `self._multitasks`, so the list
    is not modified while it is being iterated and each objective is added
    to `pipeline` exactly once.
    '''
    for labeller in self._multitasks:
        # Looked up inside the loop so that `self.model` is only indexed
        # when there is at least one objective to initialise.
        tok2vec = self.model[0]
        labeller.begin_training(gold_tuples, pipeline=pipeline,
                                tok2vec=tok2vec, sgd=sgd)
        pipeline.append((labeller.name, labeller))
|
||||
|
||||
def __reduce__(self):
|
||||
return (DependencyParser, (self.vocab, self.moves, self.model),
|
||||
|
@ -902,14 +904,16 @@ cdef class EntityRecognizer(Parser):
|
|||
|
||||
nr_feature = 6
|
||||
|
||||
def add_multitask_objective(self, target):
    '''Register a secondary objective for `target` on this component.

    A MultitaskObjective is constructed from `self.vocab` and `target`,
    then appended to `self._multitasks`. Nothing is returned.
    '''
    new_objective = MultitaskObjective(self.vocab, target=target)
    self._multitasks.append(new_objective)
|
||||
|
||||
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
    '''Begin training for each registered multitask objective and add it
    to `pipeline`.

    gold_tuples: Passed through unchanged to each objective's
        `begin_training`.
    pipeline (list): Passed to each objective's `begin_training`, then
        extended in place with one `(labeller.name, labeller)` tuple per
        objective.
    sgd: Forwarded to each objective's `begin_training`, so a caller-supplied
        value is no longer silently dropped. Defaults to None.
    **cfg: Accepted for interface compatibility; not used here.

    Objectives are taken from `self._multitasks` only. This method never
    creates objectives and never appends to `self._multitasks`, so the list
    is not modified while it is being iterated and each objective is added
    to `pipeline` exactly once.
    '''
    for labeller in self._multitasks:
        # Looked up inside the loop so that `self.model` is only indexed
        # when there is at least one objective to initialise.
        tok2vec = self.model[0]
        labeller.begin_training(gold_tuples, pipeline=pipeline,
                                tok2vec=tok2vec, sgd=sgd)
        pipeline.append((labeller.name, labeller))
|
||||
|
||||
def __reduce__(self):
|
||||
return (EntityRecognizer, (self.vocab, self.moves, self.model),
|
||||
|
|
|
@ -269,9 +269,6 @@ cdef class Parser:
|
|||
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
|
||||
)
|
||||
|
||||
# TODO: This is an unfortunate hack atm!
|
||||
# Used to set input dimensions in network.
|
||||
lower.begin_training(lower.ops.allocate((500, token_vector_width)))
|
||||
cfg = {
|
||||
'nr_class': nr_class,
|
||||
'hidden_depth': depth,
|
||||
|
@ -840,8 +837,14 @@ cdef class Parser:
|
|||
self.cfg.update(cfg)
|
||||
elif sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
self.model[1].begin_training(
|
||||
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
|
||||
return sgd
|
||||
|
||||
def add_multitask_objective(self, target):
    '''Hook for registering a multitask objective for `target`.

    The base class provides no implementation: concrete subclasses define
    this method themselves, which avoids a circular import here.

    Raises NotImplementedError whenever it is called on a class that has
    not overridden it.
    '''
    # Defined in subclasses, to avoid circular import
    raise NotImplementedError(
        "%s does not implement add_multitask_objective"
        % type(self).__name__)
|
||||
|
||||
def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
|
||||
'''Setup models for secondary objectives, to benefit from multi-task
|
||||
learning. This method is intended to be overridden by subclasses.
|
||||
|
|
Loading…
Reference in New Issue