diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx index 09738b584..48030b72a 100644 --- a/spacy/syntax/_beam_utils.pyx +++ b/spacy/syntax/_beam_utils.pyx @@ -216,12 +216,13 @@ def get_states(pbeams, gbeams, beam_map, nr_update): for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)): p_indices.append([]) g_indices.append([]) - if pbeam.loss > 0 and pbeam.min_score > gbeam.score: + if pbeam.loss > 0 and pbeam.min_score > (gbeam.score + nr_update): continue for i in range(pbeam.size): state = pbeam.at(i) if not state.is_final(): key = tuple([eg_id] + pbeam.histories[i]) + assert key not in seen, (key, seen) seen[key] = len(states) p_indices[-1].append(len(states)) states.append(state) @@ -257,12 +258,18 @@ def get_gradient(nr_class, beam_maps, histories, losses): """ nr_step = len(beam_maps) grads = [] - for beam_map in beam_maps: - if beam_map: - grads.append(numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f')) + nr_step = 0 + for eg_id, hists in enumerate(histories): + for loss, hist in zip(losses[eg_id], hists): + if abs(loss) >= 0.0001 and not numpy.isnan(loss): + nr_step = max(nr_step, len(hist)) + for i in range(nr_step): + grads.append(numpy.zeros((max(beam_maps[i].values())+1, nr_class), dtype='f')) assert len(histories) == len(losses) for eg_id, hists in enumerate(histories): for loss, hist in zip(losses[eg_id], hists): + if abs(loss) < 0.0001 or numpy.isnan(loss): + continue key = tuple([eg_id]) for j, clas in enumerate(hist): i = beam_maps[j][key]