diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx index 15f1ce59b..09738b584 100644 --- a/spacy/syntax/_beam_utils.pyx +++ b/spacy/syntax/_beam_utils.pyx @@ -97,12 +97,19 @@ cdef class ParserBeam(object): def _set_scores(self, Beam beam, float[:, ::1] scores): cdef float* c_scores = &scores[0, 0] - for i in range(beam.size): + cdef int nr_state = min(scores.shape[0], beam.size) + cdef int nr_class = scores.shape[1] + for i in range(nr_state): state = beam.at(i) if not state.is_final(): - for j in range(beam.nr_class): - beam.scores[i][j] = c_scores[i * beam.nr_class + j] + for j in range(nr_class): + beam.scores[i][j] = c_scores[i * nr_class + j] self.moves.set_valid(beam.is_valid[i], state.c) + else: + for j in range(beam.nr_class): + beam.scores[i][j] = 0 + beam.costs[i][j] = 0 + def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False): for i in range(beam.size): @@ -196,8 +203,7 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps, losses = [((v.p_probs + v.g_probs) if v.p_probs else []) for v in violns] states_d_scores = get_gradient(moves.n_moves, beam_maps, histories, losses) - assert len(states_d_scores) == len(backprops), (len(states_d_scores), len(backprops)) - return states_d_scores, backprops + return states_d_scores, backprops[:len(states_d_scores)] def get_states(pbeams, gbeams, beam_map, nr_update):