mirror of https://github.com/explosion/spaCy.git
Use FTRL training in parser
This commit is contained in:
parent
d108534dc2
commit
40703988bc
|
@ -124,6 +124,8 @@ cdef class Parser:
|
||||||
elif 'features' not in cfg:
|
elif 'features' not in cfg:
|
||||||
cfg['features'] = self.feature_templates
|
cfg['features'] = self.feature_templates
|
||||||
self.model = ParserModel(cfg['features'])
|
self.model = ParserModel(cfg['features'])
|
||||||
|
self.model.l1_penalty = 1e-7
|
||||||
|
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
|
@ -258,15 +260,20 @@ cdef class Parser:
|
||||||
self.model.set_featuresC(&eg.c, stcls.c)
|
self.model.set_featuresC(&eg.c, stcls.c)
|
||||||
self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
|
self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
|
||||||
self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
|
self.model.set_scoresC(eg.c.scores, eg.c.features, eg.c.nr_feat)
|
||||||
self.model.updateC(&eg.c)
|
self.model.time += 1
|
||||||
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
|
guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class)
|
||||||
|
if eg.c.costs[guess] > 0:
|
||||||
action = self.moves.c[eg.guess]
|
best = VecVec.arg_max_if_zero(eg.c.scores, eg.c.costs, eg.c.nr_class)
|
||||||
|
for feat in eg.c.features[:eg.c.nr_feat]:
|
||||||
|
self.model.update_weight_ftrl(feat.key, best, -feat.value * eg.costs[guess])
|
||||||
|
self.model.update_weight_ftrl(feat.key, guess, feat.value * eg.costs[guess])
|
||||||
|
|
||||||
|
action = self.moves.c[guess]
|
||||||
action.do(stcls.c, action.label)
|
action.do(stcls.c, action.label)
|
||||||
loss += eg.costs[eg.guess]
|
loss += eg.costs[guess]
|
||||||
eg.fill_scores(0, eg.nr_class)
|
eg.fill_scores(0, eg.c.nr_class)
|
||||||
eg.fill_costs(0, eg.nr_class)
|
eg.fill_costs(0, eg.c.nr_class)
|
||||||
eg.fill_is_valid(1, eg.nr_class)
|
eg.fill_is_valid(1, eg.c.nr_class)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def step_through(self, Doc doc):
|
def step_through(self, Doc doc):
|
||||||
|
@ -296,7 +303,7 @@ cdef class Parser:
|
||||||
# Doesn't set label into serializer -- subclasses override it to do that.
|
# Doesn't set label into serializer -- subclasses override it to do that.
|
||||||
for action in self.moves.action_types:
|
for action in self.moves.action_types:
|
||||||
self.moves.add_action(action, label)
|
self.moves.add_action(action, label)
|
||||||
|
|
||||||
|
|
||||||
cdef class StepwiseState:
|
cdef class StepwiseState:
|
||||||
cdef readonly StateClass stcls
|
cdef readonly StateClass stcls
|
||||||
|
|
Loading…
Reference in New Issue