mirror of https://github.com/explosion/spaCy.git
Enable history features for beam parser
This commit is contained in:
parent
fc06b0a333
commit
ca12764772
|
@ -21,6 +21,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
|
||||||
moves = <const Transition*>_moves
|
moves = <const Transition*>_moves
|
||||||
dest.clone(src)
|
dest.clone(src)
|
||||||
moves[clas].do(dest.c, moves[clas].label)
|
moves[clas].do(dest.c, moves[clas].label)
|
||||||
|
dest.c.push_hist(clas)
|
||||||
|
|
||||||
|
|
||||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||||
|
@ -149,7 +150,7 @@ nr_update = 0
|
||||||
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
||||||
states, golds,
|
states, golds,
|
||||||
state2vec, vec2scores,
|
state2vec, vec2scores,
|
||||||
int width, float density,
|
int width, float density, int hist_feats,
|
||||||
losses=None, drop=0.):
|
losses=None, drop=0.):
|
||||||
global nr_update
|
global nr_update
|
||||||
cdef MaxViolation violn
|
cdef MaxViolation violn
|
||||||
|
@ -180,7 +181,11 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
||||||
# Now that we have our flat list of states, feed them through the model
|
# Now that we have our flat list of states, feed them through the model
|
||||||
token_ids = get_token_ids(states, nr_feature)
|
token_ids = get_token_ids(states, nr_feature)
|
||||||
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
|
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
|
||||||
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
|
if hist_feats:
|
||||||
|
hists = numpy.asarray([st.history[:hist_feats] for st in states], dtype='i')
|
||||||
|
scores, bp_scores = vec2scores.begin_update((vectors, hists), drop=drop)
|
||||||
|
else:
|
||||||
|
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
|
||||||
|
|
||||||
# Store the callbacks for the backward pass
|
# Store the callbacks for the backward pass
|
||||||
backprops.append((token_ids, bp_vectors, bp_scores))
|
backprops.append((token_ids, bp_vectors, bp_scores))
|
||||||
|
|
|
@ -505,7 +505,12 @@ cdef class Parser:
|
||||||
states.append(stcls)
|
states.append(stcls)
|
||||||
token_ids = self.get_token_ids(states)
|
token_ids = self.get_token_ids(states)
|
||||||
vectors = state2vec(token_ids)
|
vectors = state2vec(token_ids)
|
||||||
scores = vec2scores(vectors)
|
if self.cfg.get('hist_size', 0):
|
||||||
|
hists = numpy.asarray([st.history[:self.cfg['hist_size']]
|
||||||
|
for st in states], dtype='i')
|
||||||
|
scores = vec2scores(vectors, drop=drop)
|
||||||
|
else:
|
||||||
|
scores = vec2scores(vectors, drop=drop)
|
||||||
j = 0
|
j = 0
|
||||||
c_scores = <float*>scores.data
|
c_scores = <float*>scores.data
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
|
@ -537,6 +542,7 @@ cdef class Parser:
|
||||||
guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece)
|
guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece)
|
||||||
action = self.moves.c[guess]
|
action = self.moves.c[guess]
|
||||||
action.do(state, action.label)
|
action.do(state, action.label)
|
||||||
|
state.push_hist(guess)
|
||||||
|
|
||||||
free(is_valid)
|
free(is_valid)
|
||||||
free(scores)
|
free(scores)
|
||||||
|
@ -634,7 +640,7 @@ cdef class Parser:
|
||||||
states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500,
|
states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500,
|
||||||
states, golds,
|
states, golds,
|
||||||
state2vec, vec2scores,
|
state2vec, vec2scores,
|
||||||
width, density,
|
width, density, self.cfg.get('hist_size', 0),
|
||||||
drop=drop, losses=losses)
|
drop=drop, losses=losses)
|
||||||
backprop_lower = []
|
backprop_lower = []
|
||||||
cdef float batch_size = len(docs)
|
cdef float batch_size = len(docs)
|
||||||
|
@ -967,6 +973,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
|
||||||
moves = <const Transition*>_moves
|
moves = <const Transition*>_moves
|
||||||
dest.clone(src)
|
dest.clone(src)
|
||||||
moves[clas].do(dest.c, moves[clas].label)
|
moves[clas].do(dest.c, moves[clas].label)
|
||||||
|
dest.c.push_hist(clas)
|
||||||
|
|
||||||
|
|
||||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||||
|
|
Loading…
Reference in New Issue