Enable history features for beam parser

This commit is contained in:
Matthew Honnibal 2017-10-05 21:53:29 -05:00
parent fc06b0a333
commit ca12764772
2 changed files with 17 additions and 5 deletions

View File

@ -21,6 +21,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
moves = <const Transition*>_moves
dest.clone(src)
moves[clas].do(dest.c, moves[clas].label)
dest.c.push_hist(clas)
cdef int _check_final_state(void* _state, void* extra_args) except -1:
@ -149,7 +150,7 @@ nr_update = 0
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
states, golds,
state2vec, vec2scores,
int width, float density,
int width, float density, int hist_feats,
losses=None, drop=0.):
global nr_update
cdef MaxViolation violn
@ -180,7 +181,11 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
# Now that we have our flat list of states, feed them through the model
token_ids = get_token_ids(states, nr_feature)
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
if hist_feats:
hists = numpy.asarray([st.history[:hist_feats] for st in states], dtype='i')
scores, bp_scores = vec2scores.begin_update((vectors, hists), drop=drop)
else:
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
# Store the callbacks for the backward pass
backprops.append((token_ids, bp_vectors, bp_scores))

View File

@ -505,7 +505,12 @@ cdef class Parser:
states.append(stcls)
token_ids = self.get_token_ids(states)
vectors = state2vec(token_ids)
scores = vec2scores(vectors)
if self.cfg.get('hist_size', 0):
hists = numpy.asarray([st.history[:self.cfg['hist_size']]
for st in states], dtype='i')
scores = vec2scores(vectors, drop=drop)
else:
scores = vec2scores(vectors, drop=drop)
j = 0
c_scores = <float*>scores.data
for i in range(beam.size):
@ -537,6 +542,7 @@ cdef class Parser:
guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece)
action = self.moves.c[guess]
action.do(state, action.label)
state.push_hist(guess)
free(is_valid)
free(scores)
@ -634,7 +640,7 @@ cdef class Parser:
states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500,
states, golds,
state2vec, vec2scores,
width, density,
width, density, self.cfg.get('hist_size', 0),
drop=drop, losses=losses)
backprop_lower = []
cdef float batch_size = len(docs)
@ -967,6 +973,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
moves = <const Transition*>_moves
dest.clone(src)
moves[clas].do(dest.c, moves[clas].label)
dest.c.push_hist(clas)
cdef int _check_final_state(void* _state, void* extra_args) except -1: