mirror of https://github.com/explosion/spaCy.git
Tweaks to beam parser
This commit is contained in:
parent
500e92553d
commit
23537a011d
|
@ -216,12 +216,13 @@ def get_states(pbeams, gbeams, beam_map, nr_update):
|
||||||
for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
|
for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
|
||||||
p_indices.append([])
|
p_indices.append([])
|
||||||
g_indices.append([])
|
g_indices.append([])
|
||||||
if pbeam.loss > 0 and pbeam.min_score > gbeam.score:
|
if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + nr_update):
|
||||||
continue
|
continue
|
||||||
for i in range(pbeam.size):
|
for i in range(pbeam.size):
|
||||||
state = <StateClass>pbeam.at(i)
|
state = <StateClass>pbeam.at(i)
|
||||||
if not state.is_final():
|
if not state.is_final():
|
||||||
key = tuple([eg_id] + pbeam.histories[i])
|
key = tuple([eg_id] + pbeam.histories[i])
|
||||||
|
assert key not in seen, (key, seen)
|
||||||
seen[key] = len(states)
|
seen[key] = len(states)
|
||||||
p_indices[-1].append(len(states))
|
p_indices[-1].append(len(states))
|
||||||
states.append(state)
|
states.append(state)
|
||||||
|
@ -257,12 +258,18 @@ def get_gradient(nr_class, beam_maps, histories, losses):
|
||||||
"""
|
"""
|
||||||
nr_step = len(beam_maps)
|
nr_step = len(beam_maps)
|
||||||
grads = []
|
grads = []
|
||||||
for beam_map in beam_maps:
|
nr_step = 0
|
||||||
if beam_map:
|
for eg_id, hists in enumerate(histories):
|
||||||
grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f'))
|
for loss, hist in zip(losses[eg_id], hists):
|
||||||
|
if abs(loss) >= 0.0001 and not numpy.isnan(loss):
|
||||||
|
nr_step = max(nr_step, len(hist))
|
||||||
|
for i in range(nr_step):
|
||||||
|
grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class), dtype='f'))
|
||||||
assert len(histories) == len(losses)
|
assert len(histories) == len(losses)
|
||||||
for eg_id, hists in enumerate(histories):
|
for eg_id, hists in enumerate(histories):
|
||||||
for loss, hist in zip(losses[eg_id], hists):
|
for loss, hist in zip(losses[eg_id], hists):
|
||||||
|
if abs(loss) < 0.0001 or numpy.isnan(loss):
|
||||||
|
continue
|
||||||
key = tuple([eg_id])
|
key = tuple([eg_id])
|
||||||
for j, clas in enumerate(hist):
|
for j, clas in enumerate(hist):
|
||||||
i = beam_maps[j][key]
|
i = beam_maps[j][key]
|
||||||
|
|
Loading…
Reference in New Issue