mirror of https://github.com/explosion/spaCy.git
Bug fixes to beam search after refactor
This commit is contained in:
parent
5ed71973b3
commit
4cb0494bef
|
@@ -231,7 +231,7 @@ cdef class Parser:
|
|||
weights, sizes)
|
||||
return batch
|
||||
|
||||
def beam_parse(self, docs, int beam_width=3, float drop=0.):
|
||||
def beam_parse(self, docs, int beam_width, float drop=0.):
|
||||
cdef Beam beam
|
||||
cdef Doc doc
|
||||
cdef np.ndarray token_ids
|
||||
|
@@ -351,10 +351,13 @@ cdef class Parser:
|
|||
if len(docs) != len(golds):
|
||||
raise ValueError(Errors.E077.format(value='update', n_docs=len(docs),
|
||||
n_golds=len(golds)))
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.)
|
||||
# The probability we use beam update, instead of falling back to
|
||||
# a greedy update
|
||||
beam_update_prob = 1-self.cfg.get('beam_update_prob', 0.5)
|
||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= beam_update_prob:
|
||||
beam_update_prob = self.cfg.get('beam_update_prob', 0.5)
|
||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob:
|
||||
return self.update_beam(docs, golds,
|
||||
self.cfg['beam_width'],
|
||||
drop=drop, sgd=sgd, losses=losses)
|
||||
|
@@ -383,13 +386,15 @@ cdef class Parser:
|
|||
|
||||
def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None):
|
||||
lengths = [len(d) for d in docs]
|
||||
cut_gold = numpy.random.choice(range(20, 100))
|
||||
states, golds, max_steps = self._init_gold_batch(docs, golds, max_length=cut_gold)
|
||||
states = self.moves.init_batch(docs)
|
||||
for gold in golds:
|
||||
self.moves.preprocess_gold(gold)
|
||||
model, finish_update = self.model.begin_update(docs, drop=drop)
|
||||
states_d_scores, backprops, beams = _beam_utils.update_beam(
|
||||
self.moves, self.nr_feature, max_steps, states, golds, model.state2vec,
|
||||
self.moves, self.nr_feature, 500, states, golds, model.state2vec,
|
||||
model.vec2scores, width, drop=drop, losses=losses)
|
||||
for i, d_scores in enumerate(states_d_scores):
|
||||
losses[self.name] += (d_scores**2).sum()
|
||||
ids, bp_vectors, bp_scores = backprops[i]
|
||||
d_vector = bp_scores(d_scores, sgd=sgd)
|
||||
if isinstance(model.ops, CupyOps) \
|
||||
|
|
Loading…
Reference in New Issue