diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index b860425cd..4be1046bc 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -107,8 +107,9 @@ cdef class Parser: cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width) beam.initialize(_init_state, tokens.length, tokens.data) beam.check_done(_check_final_state, NULL) + words = [w.orth_ for w in tokens] while not beam.is_done: - self._advance_beam(beam, None, False) + self._advance_beam(beam, None, False, words) state = beam.at(0) self.moves.finalize_state(state) tokens.set_parse(state._sent) @@ -147,9 +148,10 @@ cdef class Parser: gold.check_done(_check_final_state, NULL) violn = MaxViolation() + words = [w.orth_ for w in tokens] while not pred.is_done and not gold.is_done: - self._advance_beam(pred, gold_parse, False) - self._advance_beam(gold, gold_parse, True) + self._advance_beam(pred, gold_parse, False, words) + self._advance_beam(gold, gold_parse, True, words) violn.check(pred, gold) if pred.loss >= 1: counts = {clas: {} for clas in range(self.model.n_classes)} @@ -162,7 +164,7 @@ cdef class Parser: _cleanup(gold) return pred.loss - def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold): + def _advance_beam(self, Beam beam, GoldParse gold, bint follow_gold, words): cdef atom_t[CONTEXT_SIZE] context cdef int i, j, cost cdef bint is_valid @@ -176,13 +178,11 @@ cdef class Parser: if gold is not None: for i in range(beam.size): stcls = beam.at(i) - self.moves.set_costs(beam.costs[i], stcls, gold) - if follow_gold: - n_true = 0 - for j in range(self.moves.n_moves): - beam.is_valid[i][j] *= beam.costs[i][j] == 0 - n_true += beam.is_valid[i][j] - assert n_true >= 1 + if not stcls.is_final(): + self.moves.set_costs(beam.costs[i], stcls, gold) + if follow_gold: + for j in range(self.moves.n_moves): + beam.is_valid[i][j] *= beam.costs[i][j] == 0 beam.advance(_transition_state, _hash_state, self.moves.c) beam.check_done(_check_final_state, NULL) @@ -213,6 +213,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef StateClass st = StateClass.init(tokens, length) + st.push() Py_INCREF(st) return st