From 278a4c17c642b366b71dccc7cec202dc22cfcb93 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 3 Oct 2017 13:27:10 +0200 Subject: [PATCH] Fix history features --- spacy/syntax/nn_parser.pyx | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 2277e568e..4a874e834 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -70,6 +70,8 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG from . import _beam_utils USE_HISTORY = True +HIST_SIZE = 2 +HIST_DIMS = 16 def get_templates(*args, **kwargs): return [] @@ -262,13 +264,11 @@ cdef class Parser: with Model.use_device('cpu'): if depth == 0: - hist_size = 8 - nr_dim = 8 if USE_HISTORY: upper = chain( - HistoryFeatures(nr_class=nr_class, hist_size=hist_size, - nr_dim=nr_dim), - zero_init(Affine(nr_class, nr_class+hist_size*nr_dim, + HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE, + nr_dim=HIST_DIMS), + zero_init(Affine(nr_class, nr_class+HIST_SIZE*HIST_DIMS, drop_factor=0.0))) upper.is_noop = False else: @@ -736,15 +736,13 @@ cdef class Parser: cdef StateClass state cdef int[500] is_valid # TODO: Unhack cdef float* c_scores = &scores[0, 0] - hists = [] for state in states: self.moves.set_valid(is_valid, state.c) guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1]) action = self.moves.c[guess] action.do(state.c, action.label) c_scores += scores.shape[1] - hists.append(guess) - return hists + state.c.push_hist(guess) def get_batch_loss(self, states, golds, float[:, ::1] scores): cdef StateClass state