diff --git a/spacy/syntax/parser.pxd b/spacy/syntax/parser.pxd index fc15ac2df..1b4bf15fd 100644 --- a/spacy/syntax/parser.pxd +++ b/spacy/syntax/parser.pxd @@ -14,6 +14,5 @@ cdef class Parser: cdef readonly Model model cdef readonly TransitionSystem moves - - cdef State* _greedy_parse(self, Tokens tokens) except NULL - cdef State* _beam_parse(self, Tokens tokens) except NULL + cdef int _greedy_parse(self, Tokens tokens) except -1 + cdef int _beam_parse(self, Tokens tokens) except -1 diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index b308aa2e2..7813be51d 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -81,15 +81,19 @@ cdef class Parser: def __call__(self, Tokens tokens): if tokens.length == 0: return 0 - cdef State* state if self.cfg.beam_width == 1: - state = self._greedy_parse(tokens) + self._greedy_parse(tokens) else: - state = self._beam_parse(tokens) - self.moves.finalize_state(state) - tokens.set_parse(state.sent) + self._beam_parse(tokens) - cdef State* _greedy_parse(self, Tokens tokens) except NULL: + def train(self, Tokens tokens, GoldParse gold): + self.moves.preprocess_gold(gold) + if self.cfg.beam_width == 1: + return self._greedy_train(tokens, gold) + else: + return self._beam_train(tokens, gold) + + cdef int _greedy_parse(self, Tokens tokens) except -1: cdef atom_t[CONTEXT_SIZE] context cdef int n_feats cdef Pool mem = Pool() @@ -101,21 +105,17 @@ cdef class Parser: scores = self.model.score(context) guess = self.moves.best_valid(scores, state) guess.do(&guess, state) - return state + self.moves.finalize_state(state) + tokens.set_parse(state.sent) - cdef State* _beam_parse(self, Tokens tokens) except NULL: + cdef int _beam_parse(self, Tokens tokens) except -1: cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width) beam.initialize(_init_state, tokens.length, tokens.data) while not beam.is_done: self._advance_beam(beam, None, False) - return beam.at(0) - - def train(self, Tokens tokens, GoldParse gold): - self.moves.preprocess_gold(gold) - if self.beam_width == 1: - return self._greedy_train(tokens, gold) - else: - return self._beam_train(tokens, gold) + state = beam.at(0) + self.moves.finalize_state(state) + tokens.set_parse(state.sent) def _greedy_train(self, Tokens tokens, GoldParse gold): cdef Pool mem = Pool()