mirror of https://github.com/explosion/spaCy.git
Update name of 'train' function in BeamParser
This commit is contained in:
parent
0ed2afde89
commit
b0d80dc9ae
|
@ -96,7 +96,7 @@ cdef class BeamParser(Parser):
|
||||||
tokens[i] = state.c._sent[i]
|
tokens[i] = state.c._sent[i]
|
||||||
_cleanup(beam)
|
_cleanup(beam)
|
||||||
|
|
||||||
def train(self, Doc tokens, GoldParse gold_parse, itn=0):
|
def update(self, Doc tokens, GoldParse gold_parse, itn=0):
|
||||||
self.moves.preprocess_gold(gold_parse)
|
self.moves.preprocess_gold(gold_parse)
|
||||||
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
|
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
|
||||||
pred.initialize(_init_state, tokens.length, tokens.c)
|
pred.initialize(_init_state, tokens.length, tokens.c)
|
||||||
|
@ -133,7 +133,7 @@ cdef class BeamParser(Parser):
|
||||||
random.shuffle(histories)
|
random.shuffle(histories)
|
||||||
for grad, hist in histories:
|
for grad, hist in histories:
|
||||||
assert not math.isnan(grad) and not math.isinf(grad), hist
|
assert not math.isnan(grad) and not math.isinf(grad), hist
|
||||||
self.model._update_from_history(self.moves, tokens, hist, grad)
|
self.model.update_from_history(self.moves, tokens, hist, grad)
|
||||||
_cleanup(pred)
|
_cleanup(pred)
|
||||||
_cleanup(gold)
|
_cleanup(gold)
|
||||||
return pred.loss
|
return pred.loss
|
||||||
|
@ -169,7 +169,7 @@ cdef class BeamParser(Parser):
|
||||||
self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold)
|
self.moves.set_costs(beam.is_valid[i], beam.costs[i], stcls, gold)
|
||||||
if follow_gold:
|
if follow_gold:
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
beam.is_valid[i][j] *= beam.costs[i][j] < 1
|
beam.is_valid[i][j] *= beam.costs[i][j] <= 0
|
||||||
if follow_gold:
|
if follow_gold:
|
||||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
||||||
else:
|
else:
|
||||||
|
@ -189,6 +189,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
|
||||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||||
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
|
cdef StateClass st = StateClass.init(<const TokenC*>tokens, length)
|
||||||
# Ensure sent_start is set to 0 throughout
|
# Ensure sent_start is set to 0 throughout
|
||||||
|
# Ensure sent_start is set to 0 throughout
|
||||||
for i in range(st.c.length):
|
for i in range(st.c.length):
|
||||||
st.c._sent[i].sent_start = False
|
st.c._sent[i].sent_start = False
|
||||||
st.c._sent[i].l_edge = i
|
st.c._sent[i].l_edge = i
|
||||||
|
|
Loading…
Reference in New Issue