diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 34ee920c6..093186518 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -68,7 +68,7 @@ def get_templates(name): cdef class ParserModel(AveragedPerceptron): - cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil: + cdef void set_featuresC(self, ExampleC* eg, const StateC* state) nogil: fill_context(eg.atoms, state) eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms) @@ -124,7 +124,7 @@ cdef class Parser: elif 'features' not in cfg: cfg['features'] = self.feature_templates self.model = ParserModel(cfg['features']) - self.model.l1_penalty = 1e-7 + self.model.l1_penalty = cfg.get('L1', 0.0) self.cfg = cfg @@ -234,7 +234,7 @@ cdef class Parser: free(eg.scores) free(eg.is_valid) return 0 - + def update(self, Doc tokens, GoldParse gold): """Update the statistical model. @@ -263,11 +263,11 @@ cdef class Parser: self.model.time += 1 guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class) if eg.c.costs[guess] > 0: - best = VecVec.arg_max_if_zero(eg.c.scores, eg.c.costs, eg.c.nr_class) + best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class) for feat in eg.c.features[:eg.c.nr_feat]: - self.model.update_weight_ftrl(feat.key, best, -feat.value * eg.costs[guess]) - self.model.update_weight_ftrl(feat.key, guess, feat.value * eg.costs[guess]) - + self.model.update_weight_ftrl(feat.key, best, -feat.value * eg.c.costs[guess]) + self.model.update_weight_ftrl(feat.key, guess, feat.value * eg.c.costs[guess]) + action = self.moves.c[guess] action.do(stcls.c, action.label) loss += eg.costs[guess] @@ -392,6 +392,14 @@ class ParserStateError(ValueError): "Please include the text that the parser failed on, which is:\n" "%s" % repr(doc.text)) +cdef int arg_max_if_gold(const weight_t* scores, const weight_t* costs, int n) nogil: + cdef int best = -1 + for i in range(n): + if costs[i] <= 0: + if best == -1 or scores[i] > scores[best]: + best = i + return best + cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions, int nr_class) except -1: