mirror of https://github.com/explosion/spaCy.git
Update name of 'train' function in BeamParser
This commit is contained in:
parent
0ed2afde89
commit
b0d80dc9ae
|
@ -96,7 +96,7 @@ cdef class BeamParser(Parser):
|
|||
tokens[i] = state.c._sent[i]
|
||||
_cleanup(beam)
|
||||
|
||||
def train(self, Doc tokens, GoldParse gold_parse, itn=0):
|
||||
def update(self, Doc tokens, GoldParse gold_parse, itn=0):
|
||||
self.moves.preprocess_gold(gold_parse)
|
||||
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
|
||||
pred.initialize(_init_state, tokens.length, tokens.c)
|
||||
|
@ -133,7 +133,7 @@ cdef class BeamParser(Parser):
|
|||
random.shuffle(histories)
|
||||
for grad, hist in histories:
|
||||
assert not math.isnan(grad) and not math.isinf(grad), hist
|
||||
self.model._update_from_history(self.moves, tokens, hist, grad)
|
||||
self.model.update_from_history(self.moves, tokens, hist, grad)
|
||||
_cleanup(pred)
|
||||
_cleanup(gold)
|
||||
return pred.loss
|
||||
|
@ -169,7 +169,7 @@ cdef class BeamParser(Parser):
|
|||
self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold)
|
||||
if follow_gold:
|
||||
for j in range(self.moves.n_moves):
|
||||
beam.is_valid[i][j] *= beam.costs[i][j] < 1
|
||||
beam.is_valid[i][j] *= beam.costs[i][j] <= 0
|
||||
if follow_gold:
|
||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||
else:
|
||||
|
@ -189,6 +189,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
|
|||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
|
||||
# Ensure sent_start is set to 0 throughout
|
||||
# Ensure sent_start is set to 0 throughout
|
||||
for i in range(st.c.length):
|
||||
st.c._sent[i].sent_start = False
|
||||
st.c._sent[i].l_edge = i
|
||||
|
|
Loading…
Reference in New Issue