diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx index a26900f6b..da4efefbc 100644 --- a/spacy/syntax/_beam_utils.pyx +++ b/spacy/syntax/_beam_utils.pyx @@ -21,6 +21,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) moves = _moves dest.clone(src) moves[clas].do(dest.c, moves[clas].label) + dest.c.push_hist(clas) cdef int _check_final_state(void* _state, void* extra_args) except -1: @@ -148,8 +149,8 @@ def get_token_ids(states, int n_tokens): nr_update = 0 def update_beam(TransitionSystem moves, int nr_feature, int max_steps, states, golds, - state2vec, vec2scores, - int width, float density, + state2vec, vec2scores, + int width, float density, int hist_feats, losses=None, drop=0.): global nr_update cdef MaxViolation violn @@ -180,7 +181,11 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps, # Now that we have our flat list of states, feed them through the model token_ids = get_token_ids(states, nr_feature) vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop) - scores, bp_scores = vec2scores.begin_update(vectors, drop=drop) + if hist_feats: + hists = numpy.asarray([st.history[:hist_feats] for st in states], dtype='i') + scores, bp_scores = vec2scores.begin_update((vectors, hists), drop=drop) + else: + scores, bp_scores = vec2scores.begin_update(vectors, drop=drop) # Store the callbacks for the backward pass backprops.append((token_ids, bp_vectors, bp_scores)) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index b57e8b466..9a071ae14 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -505,7 +505,12 @@ cdef class Parser: states.append(stcls) token_ids = self.get_token_ids(states) vectors = state2vec(token_ids) - scores = vec2scores(vectors) + if self.cfg.get('hist_size', 0): + hists = numpy.asarray([st.history[:self.cfg['hist_size']] + for st in states], dtype='i') + scores = vec2scores(vectors, drop=drop) + else: + scores = vec2scores(vectors, drop=drop) j = 0 c_scores = scores.data for i in range(beam.size): @@ -537,6 +542,7 @@ cdef class Parser: guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece) action = self.moves.c[guess] action.do(state, action.label) + state.push_hist(guess) free(is_valid) free(scores) @@ -634,7 +640,7 @@ cdef class Parser: states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500, states, golds, state2vec, vec2scores, - width, density, + width, density, self.cfg.get('hist_size', 0), drop=drop, losses=losses) backprop_lower = [] cdef float batch_size = len(docs) @@ -967,6 +973,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) moves = _moves dest.clone(src) moves[clas].do(dest.c, moves[clas].label) + dest.c.push_hist(clas) cdef int _check_final_state(void* _state, void* extra_args) except -1: