Bug fixes to beam search after refactor

Matthew Honnibal 2018-05-08 13:48:50 +02:00
parent 5ed71973b3
commit 4cb0494bef
1 changed file with 11 additions and 6 deletions
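The behavioral fix in the second hunk is easier to see in isolation: when a beam width of 2 or more is configured, the update should take the beam path with probability `beam_update_prob` (default 0.5) and otherwise fall back to a greedy update. A minimal sketch of that decision, using a hypothetical `choose_update` helper and placeholder update callbacks standing in for the parser's own methods (not part of this commit):

import numpy

def choose_update(cfg, docs, golds, beam_update, greedy_update):
    # With probability cfg['beam_update_prob'] (default 0.5), and only when a
    # beam width >= 2 is configured, run the beam update; otherwise fall back
    # to the greedy update. Mirrors the rewritten condition in the diff below.
    beam_update_prob = cfg.get('beam_update_prob', 0.5)
    if cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob:
        return beam_update(docs, golds, cfg['beam_width'])
    return greedy_update(docs, golds)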


@@ -231,7 +231,7 @@ cdef class Parser:
                                      weights, sizes)
         return batch
 
-    def beam_parse(self, docs, int beam_width=3, float drop=0.):
+    def beam_parse(self, docs, int beam_width, float drop=0.):
         cdef Beam beam
         cdef Doc doc
         cdef np.ndarray token_ids
@@ -351,10 +351,13 @@ cdef class Parser:
         if len(docs) != len(golds):
             raise ValueError(Errors.E077.format(value='update', n_docs=len(docs),
                                                 n_golds=len(golds)))
+        if losses is None:
+            losses = {}
+        losses.setdefault(self.name, 0.)
         # The probability we use beam update, instead of falling back to
         # a greedy update
-        beam_update_prob = 1-self.cfg.get('beam_update_prob', 0.5)
-        if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= beam_update_prob:
+        beam_update_prob = self.cfg.get('beam_update_prob', 0.5)
+        if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob:
             return self.update_beam(docs, golds,
                                     self.cfg['beam_width'],
                                     drop=drop, sgd=sgd, losses=losses)
@@ -383,13 +386,15 @@ cdef class Parser:
 
     def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None):
         lengths = [len(d) for d in docs]
-        cut_gold = numpy.random.choice(range(20, 100))
-        states, golds, max_steps = self._init_gold_batch(docs, golds, max_length=cut_gold)
+        states = self.moves.init_batch(docs)
+        for gold in golds:
+            self.moves.preprocess_gold(gold)
         model, finish_update = self.model.begin_update(docs, drop=drop)
         states_d_scores, backprops, beams = _beam_utils.update_beam(
-            self.moves, self.nr_feature, max_steps, states, golds, model.state2vec,
+            self.moves, self.nr_feature, 500, states, golds, model.state2vec,
             model.vec2scores, width, drop=drop, losses=losses)
         for i, d_scores in enumerate(states_d_scores):
+            losses[self.name] += (d_scores**2).sum()
             ids, bp_vectors, bp_scores = backprops[i]
             d_vector = bp_scores(d_scores, sgd=sgd)
             if isinstance(model.ops, CupyOps) \