mirror of https://github.com/explosion/spaCy.git
Fix significant performance bug in parser training (#6010)
The parser training makes use of a trick for long documents, where we use the oracle to cut up the document into sections, so that we can have batch items in the middle of a document. For instance, if we have one document of 600 words, we might make 6 states, starting at words 0, 100, 200, 300, 400 and 500. The problem is for v3, I screwed this up and didn't stop parsing! So instead of a batch of [100, 100, 100, 100, 100, 100], we'd have a batch of [600, 500, 400, 300, 200, 100]. Oops. The implementation here could probably be improved, it's annoying to have this extra variable in the state. But this'll do. This makes the v3 parser training 5-10 times faster, depending on document lengths. This problem wasn't in v2.
This commit is contained in:
parent
6bfb1b3a29
commit
c1bf3a5602
|
@ -42,6 +42,7 @@ cdef cppclass StateC:
|
|||
RingBufferC _hist
|
||||
int length
|
||||
int offset
|
||||
int n_pushes
|
||||
int _s_i
|
||||
int _b_i
|
||||
int _e_i
|
||||
|
@ -49,6 +50,7 @@ cdef cppclass StateC:
|
|||
|
||||
__init__(const TokenC* sent, int length) nogil:
|
||||
cdef int PADDING = 5
|
||||
this.n_pushes = 0
|
||||
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
||||
|
@ -335,6 +337,7 @@ cdef cppclass StateC:
|
|||
this.set_break(this.B_(0).l_edge)
|
||||
if this._b_i > this._break:
|
||||
this._break = -1
|
||||
this.n_pushes += 1
|
||||
|
||||
void pop() nogil:
|
||||
if this._s_i >= 1:
|
||||
|
@ -351,6 +354,7 @@ cdef cppclass StateC:
|
|||
this._buffer[this._b_i] = this.S(0)
|
||||
this._s_i -= 1
|
||||
this.shifted[this.B(0)] = True
|
||||
this.n_pushes -= 1
|
||||
|
||||
void add_arc(int head, int child, attr_t label) nogil:
|
||||
if this.has_head(child):
|
||||
|
@ -431,6 +435,7 @@ cdef cppclass StateC:
|
|||
this._break = src._break
|
||||
this.offset = src.offset
|
||||
this._empty_token = src._empty_token
|
||||
this.n_pushes = src.n_pushes
|
||||
|
||||
void fast_forward() nogil:
|
||||
# space token attachement policy:
|
||||
|
|
|
@ -36,6 +36,10 @@ cdef class StateClass:
|
|||
hist[i] = self.c.get_hist(i+1)
|
||||
return hist
|
||||
|
||||
@property
|
||||
def n_pushes(self):
|
||||
return self.c.n_pushes
|
||||
|
||||
def is_final(self):
|
||||
return self.c.is_final()
|
||||
|
||||
|
|
|
@ -279,14 +279,14 @@ cdef class Parser(Pipe):
|
|||
# Chop sequences into lengths of this many transitions, to make the
|
||||
# batch uniform length.
|
||||
# We used to randomize this, but it's not clear that actually helps?
|
||||
cut_size = self.cfg["update_with_oracle_cut_size"]
|
||||
states, golds, max_steps = self._init_gold_batch(
|
||||
max_pushes = self.cfg["update_with_oracle_cut_size"]
|
||||
states, golds, _ = self._init_gold_batch(
|
||||
examples,
|
||||
max_length=cut_size
|
||||
max_length=max_pushes
|
||||
)
|
||||
else:
|
||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||
max_steps = max([len(eg.x) for eg in examples])
|
||||
max_pushes = max([len(eg.x) for eg in examples])
|
||||
if not states:
|
||||
return losses
|
||||
all_states = list(states)
|
||||
|
@ -302,7 +302,8 @@ cdef class Parser(Pipe):
|
|||
backprop(d_scores)
|
||||
# Follow the predicted action
|
||||
self.transition_states(states, scores)
|
||||
states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
|
||||
states_golds = [(s, g) for (s, g) in zip(states, golds)
|
||||
if s.n_pushes < max_pushes and not s.is_final()]
|
||||
|
||||
backprop_tok2vec(golds)
|
||||
if sgd not in (None, False):
|
||||
|
|
Loading…
Reference in New Issue