Update name of 'train' function in BeamParser

This commit is contained in:
Matthew Honnibal 2017-03-10 14:35:43 -06:00
parent 0ed2afde89
commit b0d80dc9ae
1 changed files with 4 additions and 3 deletions

View File

@ -96,7 +96,7 @@ cdef class BeamParser(Parser):
tokens[i] = state.c._sent[i]
_cleanup(beam)
def train(self, Doc tokens, GoldParse gold_parse, itn=0):
def update(self, Doc tokens, GoldParse gold_parse, itn=0):
self.moves.preprocess_gold(gold_parse)
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
pred.initialize(_init_state, tokens.length, tokens.c)
@ -133,7 +133,7 @@ cdef class BeamParser(Parser):
random.shuffle(histories)
for grad, hist in histories:
assert not math.isnan(grad) and not math.isinf(grad), hist
self.model._update_from_history(self.moves, tokens, hist, grad)
self.model.update_from_history(self.moves, tokens, hist, grad)
_cleanup(pred)
_cleanup(gold)
return pred.loss
@ -169,7 +169,7 @@ cdef class BeamParser(Parser):
self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold)
if follow_gold:
for j in range(self.moves.n_moves):
beam.is_valid[i][j] *= beam.costs[i][j] < 1
beam.is_valid[i][j] *= beam.costs[i][j] <= 0
if follow_gold:
beam.advance(_transition_state, NULL, <void*>self.moves.c)
else:
@ -189,6 +189,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
# Ensure sent_start is set to 0 throughout
# Ensure sent_start is set to 0 throughout
for i in range(st.c.length):
st.c._sent[i].sent_start = False
st.c._sent[i].l_edge = i