mirror of https://github.com/explosion/spaCy.git
* Revise greedy_parse/beam_parse ownership goof
This commit is contained in:
parent
70a7ad89ca
commit
66dfa95847
|
@ -14,6 +14,5 @@ cdef class Parser:
|
||||||
cdef readonly Model model
|
cdef readonly Model model
|
||||||
cdef readonly TransitionSystem moves
|
cdef readonly TransitionSystem moves
|
||||||
|
|
||||||
|
cdef int _greedy_parse(self, Tokens tokens) except -1
|
||||||
cdef State* _greedy_parse(self, Tokens tokens) except NULL
|
cdef int _beam_parse(self, Tokens tokens) except -1
|
||||||
cdef State* _beam_parse(self, Tokens tokens) except NULL
|
|
||||||
|
|
|
@ -81,15 +81,19 @@ cdef class Parser:
|
||||||
def __call__(self, Tokens tokens):
|
def __call__(self, Tokens tokens):
|
||||||
if tokens.length == 0:
|
if tokens.length == 0:
|
||||||
return 0
|
return 0
|
||||||
cdef State* state
|
|
||||||
if self.cfg.beam_width == 1:
|
if self.cfg.beam_width == 1:
|
||||||
state = self._greedy_parse(tokens)
|
self._greedy_parse(tokens)
|
||||||
else:
|
else:
|
||||||
state = self._beam_parse(tokens)
|
self._beam_parse(tokens)
|
||||||
self.moves.finalize_state(state)
|
|
||||||
tokens.set_parse(state.sent)
|
|
||||||
|
|
||||||
cdef State* _greedy_parse(self, Tokens tokens) except NULL:
|
def train(self, Tokens tokens, GoldParse gold):
|
||||||
|
self.moves.preprocess_gold(gold)
|
||||||
|
if self.cfg.beam_width == 1:
|
||||||
|
return self._greedy_train(tokens, gold)
|
||||||
|
else:
|
||||||
|
return self._beam_train(tokens, gold)
|
||||||
|
|
||||||
|
cdef int _greedy_parse(self, Tokens tokens) except -1:
|
||||||
cdef atom_t[CONTEXT_SIZE] context
|
cdef atom_t[CONTEXT_SIZE] context
|
||||||
cdef int n_feats
|
cdef int n_feats
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
|
@ -101,21 +105,17 @@ cdef class Parser:
|
||||||
scores = self.model.score(context)
|
scores = self.model.score(context)
|
||||||
guess = self.moves.best_valid(scores, state)
|
guess = self.moves.best_valid(scores, state)
|
||||||
guess.do(&guess, state)
|
guess.do(&guess, state)
|
||||||
return state
|
self.moves.finalize_state(state)
|
||||||
|
tokens.set_parse(state.sent)
|
||||||
|
|
||||||
cdef State* _beam_parse(self, Tokens tokens) except NULL:
|
cdef int _beam_parse(self, Tokens tokens) except -1:
|
||||||
cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width)
|
cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width)
|
||||||
beam.initialize(_init_state, tokens.length, tokens.data)
|
beam.initialize(_init_state, tokens.length, tokens.data)
|
||||||
while not beam.is_done:
|
while not beam.is_done:
|
||||||
self._advance_beam(beam, None, False)
|
self._advance_beam(beam, None, False)
|
||||||
return <State*>beam.at(0)
|
state = <State*>beam.at(0)
|
||||||
|
self.moves.finalize_state(state)
|
||||||
def train(self, Tokens tokens, GoldParse gold):
|
tokens.set_parse(state.sent)
|
||||||
self.moves.preprocess_gold(gold)
|
|
||||||
if self.beam_width == 1:
|
|
||||||
return self._greedy_train(tokens, gold)
|
|
||||||
else:
|
|
||||||
return self._beam_train(tokens, gold)
|
|
||||||
|
|
||||||
def _greedy_train(self, Tokens tokens, GoldParse gold):
|
def _greedy_train(self, Tokens tokens, GoldParse gold):
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
|
|
Loading…
Reference in New Issue