mirror of https://github.com/explosion/spaCy.git
Bug fixes to beam search after refactor
This commit is contained in:
parent
5ed71973b3
commit
4cb0494bef
|
@ -231,7 +231,7 @@ cdef class Parser:
|
||||||
weights, sizes)
|
weights, sizes)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def beam_parse(self, docs, int beam_width=3, float drop=0.):
|
def beam_parse(self, docs, int beam_width, float drop=0.):
|
||||||
cdef Beam beam
|
cdef Beam beam
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef np.ndarray token_ids
|
cdef np.ndarray token_ids
|
||||||
|
@ -351,10 +351,13 @@ cdef class Parser:
|
||||||
if len(docs) != len(golds):
|
if len(docs) != len(golds):
|
||||||
raise ValueError(Errors.E077.format(value='update', n_docs=len(docs),
|
raise ValueError(Errors.E077.format(value='update', n_docs=len(docs),
|
||||||
n_golds=len(golds)))
|
n_golds=len(golds)))
|
||||||
|
if losses is None:
|
||||||
|
losses = {}
|
||||||
|
losses.setdefault(self.name, 0.)
|
||||||
# The probability we use beam update, instead of falling back to
|
# The probability we use beam update, instead of falling back to
|
||||||
# a greedy update
|
# a greedy update
|
||||||
beam_update_prob = 1-self.cfg.get('beam_update_prob', 0.5)
|
beam_update_prob = self.cfg.get('beam_update_prob', 0.5)
|
||||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= beam_update_prob:
|
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob:
|
||||||
return self.update_beam(docs, golds,
|
return self.update_beam(docs, golds,
|
||||||
self.cfg['beam_width'],
|
self.cfg['beam_width'],
|
||||||
drop=drop, sgd=sgd, losses=losses)
|
drop=drop, sgd=sgd, losses=losses)
|
||||||
|
@ -383,13 +386,15 @@ cdef class Parser:
|
||||||
|
|
||||||
def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None):
|
def update_beam(self, docs, golds, width, drop=0., sgd=None, losses=None):
|
||||||
lengths = [len(d) for d in docs]
|
lengths = [len(d) for d in docs]
|
||||||
cut_gold = numpy.random.choice(range(20, 100))
|
states = self.moves.init_batch(docs)
|
||||||
states, golds, max_steps = self._init_gold_batch(docs, golds, max_length=cut_gold)
|
for gold in golds:
|
||||||
|
self.moves.preprocess_gold(gold)
|
||||||
model, finish_update = self.model.begin_update(docs, drop=drop)
|
model, finish_update = self.model.begin_update(docs, drop=drop)
|
||||||
states_d_scores, backprops, beams = _beam_utils.update_beam(
|
states_d_scores, backprops, beams = _beam_utils.update_beam(
|
||||||
self.moves, self.nr_feature, max_steps, states, golds, model.state2vec,
|
self.moves, self.nr_feature, 500, states, golds, model.state2vec,
|
||||||
model.vec2scores, width, drop=drop, losses=losses)
|
model.vec2scores, width, drop=drop, losses=losses)
|
||||||
for i, d_scores in enumerate(states_d_scores):
|
for i, d_scores in enumerate(states_d_scores):
|
||||||
|
losses[self.name] += (d_scores**2).sum()
|
||||||
ids, bp_vectors, bp_scores = backprops[i]
|
ids, bp_vectors, bp_scores = backprops[i]
|
||||||
d_vector = bp_scores(d_scores, sgd=sgd)
|
d_vector = bp_scores(d_scores, sgd=sgd)
|
||||||
if isinstance(model.ops, CupyOps) \
|
if isinstance(model.ops, CupyOps) \
|
||||||
|
|
Loading…
Reference in New Issue