mirror of https://github.com/explosion/spaCy.git
Tweaks to beam parser
This commit is contained in:
parent
500e92553d
commit
23537a011d
|
@ -216,12 +216,13 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
|
|||
for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
|
||||
p_indices.append([])
|
||||
g_indices.append([])
|
||||
if pbeam.loss > 0 and pbeam.min_score > gbeam.score:
|
||||
if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + nr_update):
|
||||
continue
|
||||
for i in range(pbeam.size):
|
||||
state = <StateClass>pbeam.at(i)
|
||||
if not state.is_final():
|
||||
key = tuple([eg_id] + pbeam.histories[i])
|
||||
assert key not in seen, (key, seen)
|
||||
seen[key] = len(states)
|
||||
p_indices[-1].append(len(states))
|
||||
states.append(state)
|
||||
|
@ -257,12 +258,18 @@ def get_gradient(nr_class, beam_maps, histories, losses):
|
|||
"""
|
||||
nr_step = len(beam_maps)
|
||||
grads = []
|
||||
for beam_map in beam_maps:
|
||||
if beam_map:
|
||||
grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f'))
|
||||
nr_step = 0
|
||||
for eg_id, hists in enumerate(histories):
|
||||
for loss, hist in zip(losses[eg_id], hists):
|
||||
if abs(loss) >= 0.0001 and not numpy.isnan(loss):
|
||||
nr_step = max(nr_step, len(hist))
|
||||
for i in range(nr_step):
|
||||
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class), dtype='f'))
|
||||
assert len(histories) == len(losses)
|
||||
for eg_id, hists in enumerate(histories):
|
||||
for loss, hist in zip(losses[eg_id], hists):
|
||||
if abs(loss) < 0.0001 or numpy.isnan(loss):
|
||||
continue
|
||||
key = tuple([eg_id])
|
||||
for j, clas in enumerate(hist):
|
||||
i = beam_maps[j][key]
|
||||
|
|
Loading…
Reference in New Issue