From a9b1f23c7def351a4274370991a3bbae37b78f6f Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 26 Mar 2017 09:26:30 -0500 Subject: [PATCH] Enable regression loss for parser --- spacy/syntax/parser.pyx | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index c94d4ebee..764efea8b 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -52,7 +52,7 @@ from ._parse_features cimport fill_context from .stateclass cimport StateClass from ._state cimport StateC -USE_FTRL = False +USE_FTRL = True DEBUG = False def set_debug(val): global DEBUG @@ -82,14 +82,19 @@ cdef class ParserModel(AveragedPerceptron): def update(self, Example eg, itn=0): '''Does regression on negative cost. Sort of cute?''' self.time += 1 - best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class) - guess = eg.guess + cdef int best = arg_max_if_gold(eg.c.scores, eg.c.costs, eg.c.nr_class) + cdef int guess = eg.guess if guess == best or best == -1: return 0.0 + cdef FeatureC feat + cdef int clas + cdef weight_t gradient if USE_FTRL: for feat in eg.c.features[:eg.c.nr_feat]: - self.update_weight_ftrl(feat.key, guess, feat.value * eg.c.costs[guess]) - self.update_weight_ftrl(feat.key, best, -feat.value * eg.c.costs[guess]) + for clas in range(eg.c.nr_class): + if eg.c.is_valid[clas] and eg.c.scores[clas] >= eg.c.scores[best]: + gradient = eg.c.scores[clas] + eg.c.costs[clas] + self.update_weight_ftrl(feat.key, clas, feat.value * gradient) else: for feat in eg.c.features[:eg.c.nr_feat]: self.update_weight(feat.key, guess, feat.value * eg.c.costs[guess])