From 4cb0494beff4a2f2aaaa51ced291045948442123 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 8 May 2018 13:48:50 +0200 Subject: [PATCH] Bug fixes to beam search after refactor --- spacy/syntax/nn_parser.pyx | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 9fad36f46..5a74115e5 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -231,7 +231,7 @@ cdef class Parser: weights, sizes) return batch - def beam_parse(self, docs, int beam_width=3, float drop=0.): + def beam_parse(self, docs, int beam_width, float drop=0.): cdef Beam beam cdef Doc doc cdef np.ndarray token_ids @@ -351,10 +351,13 @@ cdef class Parser: if len(docs) != len(golds): raise ValueError(Errors.E077.format(value='update', n_docs=len(docs), n_golds=len(golds))) + if losses is None: + losses = {} + losses.setdefault(self.name, 0.) # The probability we use beam update, instead of falling back to # a greedy update - beam_update_prob = 1-self.cfg.get('beam_update_prob', 0.5) - if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= beam_update_prob: + beam_update_prob = self.cfg.get('beam_update_prob', 0.5) + if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob: return self.update_beam(docs, golds, self.cfg['beam_width'], drop=drop, sgd=sgd, losses=losses) @@ -383,13 +386,15 @@ cdef class Parser: def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None): lengths = [len(d) for d in docs] - cut_gold = numpy.random.choice(range(20, 100)) - states, golds, max_steps = self._init_gold_batch(docs, golds, max_length=cut_gold) + states = self.moves.init_batch(docs) + for gold in golds: + self.moves.preprocess_gold(gold) model, finish_update = self.model.begin_update(docs, drop=drop) states_d_scores, backprops, beams = _beam_utils.update_beam( - self.moves, self.nr_feature, max_steps, states, golds, model.state2vec, + self.moves, self.nr_feature, 500, states, golds, model.state2vec, model.vec2scores, width, drop=drop, losses=losses) for i, d_scores in enumerate(states_d_scores): + losses[self.name] += (d_scores**2).sum() ids, bp_vectors, bp_scores = backprops[i] d_vector = bp_scores(d_scores, sgd=sgd) if isinstance(model.ops, CupyOps) \