Readd beam search after refactor

This commit is contained in:
Matthew Honnibal 2018-05-08 00:19:52 +02:00
parent 36b2c9bdd5
commit 5a0f26be0c
3 changed files with 54 additions and 18 deletions

View File

@ -31,6 +31,7 @@ MOD_NAMES = [
'spacy.tokenizer',
'spacy.syntax.nn_parser',
'spacy.syntax._parser_model',
'spacy.syntax._beam_utils',
'spacy.syntax.nonproj',
'spacy.syntax.transition_system',
'spacy.syntax.arc_eager',

View File

@ -0,0 +1,6 @@
from thinc.typedefs cimport class_t
# These are passed as callbacks to thinc.search.Beam
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
cdef int check_final_state(void* _state, void* extra_args) except -1

View File

@ -15,7 +15,7 @@ from .stateclass cimport StateC, StateClass
# These are passed as callbacks to thinc.search.Beam
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
dest = <StateC*>_dest
src = <StateC*>_src
moves = <const Transition*>_moves
@ -24,12 +24,12 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
dest.push_hist(clas)
cdef int _check_final_state(void* _state, void* extra_args) except -1:
cdef int check_final_state(void* _state, void* extra_args) except -1:
state = <StateC*>_state
return state.is_final()
cdef hash_t _hash_state(void* _state, void* _) except 0:
cdef hash_t hash_state(void* _state, void* _) except 0:
state = <StateC*>_state
if state.is_final():
return 1
@ -37,6 +37,20 @@ cdef hash_t _hash_state(void* _state, void* _) except 0:
return state.hash()
def collect_states(beams):
cdef StateClass state
cdef Beam beam
states = []
for state_or_beam in beams:
if isinstance(state_or_beam, StateClass):
states.append(state_or_beam)
else:
beam = state_or_beam
state = StateClass.borrow(<StateC*>beam.at(0))
states.append(state)
return states
cdef class ParserBeam(object):
cdef public TransitionSystem moves
cdef public object states
@ -82,8 +96,8 @@ cdef class ParserBeam(object):
self._set_scores(beam, scores[i])
if self.golds is not None:
self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
beam.advance(_transition_state, NULL, <void*>self.moves.c)
beam.check_done(_check_final_state, NULL)
beam.advance(transition_state, NULL, <void*>self.moves.c)
beam.check_done(check_final_state, NULL)
# This handles the non-monotonic stuff for the parser.
if beam.is_done and self.golds is not None:
for j in range(beam.size):
@ -144,15 +158,12 @@ nr_update = 0
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
states, golds,
state2vec, vec2scores,
int width, float density, int hist_feats,
losses=None, drop=0.):
int width, losses=None, drop=0.):
global nr_update
cdef MaxViolation violn
nr_update += 1
pbeam = ParserBeam(moves, states, golds,
width=width, density=density)
gbeam = ParserBeam(moves, states, golds,
width=width, density=density)
pbeam = ParserBeam(moves, states, golds, width=width)
gbeam = ParserBeam(moves, states, golds, width=width)
cdef StateClass state
beam_maps = []
backprops = []
@ -177,13 +188,7 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
# Now that we have our flat list of states, feed them through the model
token_ids = get_token_ids(states, nr_feature)
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
if hist_feats:
hists = numpy.asarray([st.history[:hist_feats] for st in states],
dtype='i')
scores, bp_scores = vec2scores.begin_update((vectors, hists),
drop=drop)
else:
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
# Store the callbacks for the backward pass
backprops.append((token_ids, bp_vectors, bp_scores))
@ -291,3 +296,27 @@ def get_gradient(nr_class, beam_maps, histories, losses):
grads[j][i, clas] += loss
key = key + tuple([clas])
return grads
def cleanup_beam(Beam beam):
cdef StateC* state
# Once parsing has finished, states in beam may not be unique. Is this
# correct?
seen = set()
for i in range(beam.width):
addr = <size_t>beam._parents[i].content
if addr not in seen:
state = <StateC*>addr
del state
seen.add(addr)
else:
raise ValueError(Errors.E023.format(addr=addr, i=i))
addr = <size_t>beam._states[i].content
if addr not in seen:
state = <StateC*>addr
del state
seen.add(addr)
else:
raise ValueError(Errors.E023.format(addr=addr, i=i))