* Revise greedy_parse/beam_parse ownership goof

This commit is contained in:
Matthew Honnibal 2015-06-02 01:34:19 +02:00
parent 70a7ad89ca
commit 66dfa95847
2 changed files with 18 additions and 19 deletions

View File

@ -14,6 +14,5 @@ cdef class Parser:
cdef readonly Model model cdef readonly Model model
cdef readonly TransitionSystem moves cdef readonly TransitionSystem moves
cdef int _greedy_parse(self, Tokens tokens) except -1
cdef State* _greedy_parse(self, Tokens tokens) except NULL cdef int _beam_parse(self, Tokens tokens) except -1
cdef State* _beam_parse(self, Tokens tokens) except NULL

View File

@ -81,15 +81,19 @@ cdef class Parser:
def __call__(self, Tokens tokens): def __call__(self, Tokens tokens):
if tokens.length == 0: if tokens.length == 0:
return 0 return 0
cdef State* state
if self.cfg.beam_width == 1: if self.cfg.beam_width == 1:
state = self._greedy_parse(tokens) self._greedy_parse(tokens)
else: else:
state = self._beam_parse(tokens) self._beam_parse(tokens)
self.moves.finalize_state(state)
tokens.set_parse(state.sent)
cdef State* _greedy_parse(self, Tokens tokens) except NULL: def train(self, Tokens tokens, GoldParse gold):
self.moves.preprocess_gold(gold)
if self.cfg.beam_width == 1:
return self._greedy_train(tokens, gold)
else:
return self._beam_train(tokens, gold)
cdef int _greedy_parse(self, Tokens tokens) except -1:
cdef atom_t[CONTEXT_SIZE] context cdef atom_t[CONTEXT_SIZE] context
cdef int n_feats cdef int n_feats
cdef Pool mem = Pool() cdef Pool mem = Pool()
@ -101,21 +105,17 @@ cdef class Parser:
scores = self.model.score(context) scores = self.model.score(context)
guess = self.moves.best_valid(scores, state) guess = self.moves.best_valid(scores, state)
guess.do(&guess, state) guess.do(&guess, state)
return state self.moves.finalize_state(state)
tokens.set_parse(state.sent)
cdef State* _beam_parse(self, Tokens tokens) except NULL: cdef int _beam_parse(self, Tokens tokens) except -1:
cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width) cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width)
beam.initialize(_init_state, tokens.length, tokens.data) beam.initialize(_init_state, tokens.length, tokens.data)
while not beam.is_done: while not beam.is_done:
self._advance_beam(beam, None, False) self._advance_beam(beam, None, False)
return <State*>beam.at(0) state = <State*>beam.at(0)
self.moves.finalize_state(state)
def train(self, Tokens tokens, GoldParse gold): tokens.set_parse(state.sent)
self.moves.preprocess_gold(gold)
if self.beam_width == 1:
return self._greedy_train(tokens, gold)
else:
return self._beam_train(tokens, gold)
def _greedy_train(self, Tokens tokens, GoldParse gold): def _greedy_train(self, Tokens tokens, GoldParse gold):
cdef Pool mem = Pool() cdef Pool mem = Pool()