mirror of https://github.com/explosion/spaCy.git
Allow multitask objectives to be added to the parser and NER more easily
This commit is contained in:
parent
4a7d524efb
commit
203d2ea830
|
@@ -882,14 +882,16 @@ cdef class DependencyParser(Parser):
|
||||||
def postprocesses(self):
|
def postprocesses(self):
|
||||||
return [nonproj.deprojectivize]
|
return [nonproj.deprojectivize]
|
||||||
|
|
||||||
|
def add_multitask_objective(self, target):
|
||||||
|
labeller = MultitaskObjective(self.vocab, target=target)
|
||||||
|
self._multitasks.append(labeller)
|
||||||
|
|
||||||
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
|
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
|
||||||
for target in []:
|
for labeller in self._multitasks:
|
||||||
labeller = MultitaskObjective(self.vocab, target=target)
|
|
||||||
tok2vec = self.model[0]
|
tok2vec = self.model[0]
|
||||||
labeller.begin_training(gold_tuples, pipeline=pipeline,
|
labeller.begin_training(gold_tuples, pipeline=pipeline,
|
||||||
tok2vec=tok2vec, sgd=sgd)
|
tok2vec=tok2vec, sgd=sgd)
|
||||||
pipeline.append(labeller)
|
pipeline.append((labeller.name, labeller))
|
||||||
self._multitasks.append(labeller)
|
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (DependencyParser, (self.vocab, self.moves, self.model),
|
return (DependencyParser, (self.vocab, self.moves, self.model),
|
||||||
|
@@ -902,14 +904,16 @@ cdef class EntityRecognizer(Parser):
|
||||||
|
|
||||||
nr_feature = 6
|
nr_feature = 6
|
||||||
|
|
||||||
|
def add_multitask_objective(self, target):
|
||||||
|
labeller = MultitaskObjective(self.vocab, target=target)
|
||||||
|
self._multitasks.append(labeller)
|
||||||
|
|
||||||
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
|
def init_multitask_objectives(self, gold_tuples, pipeline, sgd=None, **cfg):
|
||||||
for target in []:
|
for labeller in self._multitasks:
|
||||||
labeller = MultitaskObjective(self.vocab, target=target)
|
|
||||||
tok2vec = self.model[0]
|
tok2vec = self.model[0]
|
||||||
labeller.begin_training(gold_tuples, pipeline=pipeline,
|
labeller.begin_training(gold_tuples, pipeline=pipeline,
|
||||||
tok2vec=tok2vec)
|
tok2vec=tok2vec)
|
||||||
pipeline.append(labeller)
|
pipeline.append((labeller.name, labeller))
|
||||||
self._multitasks.append(labeller)
|
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (EntityRecognizer, (self.vocab, self.moves, self.model),
|
return (EntityRecognizer, (self.vocab, self.moves, self.model),
|
||||||
|
|
|
@@ -269,9 +269,6 @@ cdef class Parser:
|
||||||
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
|
zero_init(Affine(nr_class, hidden_width, drop_factor=0.0))
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: This is an unfortunate hack atm!
|
|
||||||
# Used to set input dimensions in network.
|
|
||||||
lower.begin_training(lower.ops.allocate((500, token_vector_width)))
|
|
||||||
cfg = {
|
cfg = {
|
||||||
'nr_class': nr_class,
|
'nr_class': nr_class,
|
||||||
'hidden_depth': depth,
|
'hidden_depth': depth,
|
||||||
|
@@ -840,8 +837,14 @@ cdef class Parser:
|
||||||
self.cfg.update(cfg)
|
self.cfg.update(cfg)
|
||||||
elif sgd is None:
|
elif sgd is None:
|
||||||
sgd = self.create_optimizer()
|
sgd = self.create_optimizer()
|
||||||
|
self.model[1].begin_training(
|
||||||
|
self.model[1].ops.allocate((5, cfg['token_vector_width'])))
|
||||||
return sgd
|
return sgd
|
||||||
|
|
||||||
|
def add_multitask_objective(self, target):
|
||||||
|
# Defined in subclasses, to avoid circular import
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
|
def init_multitask_objectives(self, gold_tuples, pipeline, **cfg):
|
||||||
'''Setup models for secondary objectives, to benefit from multi-task
|
'''Setup models for secondary objectives, to benefit from multi-task
|
||||||
learning. This method is intended to be overridden by subclasses.
|
learning. This method is intended to be overridden by subclasses.
|
||||||
|
|
Loading…
Reference in New Issue