diff --git a/spacy/syntax/beam_parser.pyx b/spacy/syntax/beam_parser.pyx index 447bbc811..164cda605 100644 --- a/spacy/syntax/beam_parser.pyx +++ b/spacy/syntax/beam_parser.pyx @@ -96,7 +96,7 @@ cdef class BeamParser(Parser): tokens[i] = state.c._sent[i] _cleanup(beam) - def train(self, Doc tokens, GoldParse gold_parse, itn=0): + def update(self, Doc tokens, GoldParse gold_parse, itn=0): self.moves.preprocess_gold(gold_parse) cdef Beam pred = Beam(self.moves.n_moves, self.beam_width) pred.initialize(_init_state, tokens.length, tokens.c) @@ -133,7 +133,7 @@ cdef class BeamParser(Parser): random.shuffle(histories) for grad, hist in histories: assert not math.isnan(grad) and not math.isinf(grad), hist - self.model._update_from_history(self.moves, tokens, hist, grad) + self.model.update_from_history(self.moves, tokens, hist, grad) _cleanup(pred) _cleanup(gold) return pred.loss @@ -169,7 +169,7 @@ cdef class BeamParser(Parser): self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold) if follow_gold: for j in range(self.moves.n_moves): - beam.is_valid[i][j] *= beam.costs[i][j] < 1 + beam.is_valid[i][j] *= beam.costs[i][j] <= 0 if follow_gold: beam.advance(_transition_state, NULL, self.moves.c) else: @@ -189,6 +189,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef StateClass st = StateClass.init(tokens, length) # Ensure sent_start is set to 0 throughout + # Ensure sent_start is set to 0 throughout for i in range(st.c.length): st.c._sent[i].sent_start = False st.c._sent[i].l_edge = i