mirror of https://github.com/explosion/spaCy.git
Improve correctness of minibatching
This commit is contained in:
parent
dc07d72d80
commit
a1d4c97fb7
|
@ -427,7 +427,7 @@ cdef class Parser:
|
||||||
|
|
||||||
cuda_stream = get_cuda_stream()
|
cuda_stream = get_cuda_stream()
|
||||||
|
|
||||||
states, golds, max_length = self._init_gold_batch(docs, golds)
|
states, golds, max_steps = self._init_gold_batch(docs, golds)
|
||||||
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
|
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
|
||||||
0.0)
|
0.0)
|
||||||
todo = [(s, g) for (s, g) in zip(states, golds)
|
todo = [(s, g) for (s, g) in zip(states, golds)
|
||||||
|
@ -438,6 +438,7 @@ cdef class Parser:
|
||||||
backprops = []
|
backprops = []
|
||||||
d_tokvecs = state2vec.ops.allocate(tokvecs.shape)
|
d_tokvecs = state2vec.ops.allocate(tokvecs.shape)
|
||||||
cdef float loss = 0.
|
cdef float loss = 0.
|
||||||
|
n_steps = 0
|
||||||
while todo:
|
while todo:
|
||||||
states, golds = zip(*todo)
|
states, golds = zip(*todo)
|
||||||
|
|
||||||
|
@ -467,7 +468,8 @@ cdef class Parser:
|
||||||
todo = [st for st in todo if not st[0].is_final()]
|
todo = [st for st in todo if not st[0].is_final()]
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses[self.name] += (d_scores**2).sum()
|
losses[self.name] += (d_scores**2).sum()
|
||||||
if len(backprops) >= (max_length * 2):
|
n_steps += 1
|
||||||
|
if n_steps >= max_steps:
|
||||||
break
|
break
|
||||||
self._make_updates(d_tokvecs,
|
self._make_updates(d_tokvecs,
|
||||||
backprops, sgd, cuda_stream)
|
backprops, sgd, cuda_stream)
|
||||||
|
@ -482,7 +484,8 @@ cdef class Parser:
|
||||||
StateClass state
|
StateClass state
|
||||||
Transition action
|
Transition action
|
||||||
whole_states = self.moves.init_batch(whole_docs)
|
whole_states = self.moves.init_batch(whole_docs)
|
||||||
max_length = max(5, min(20, min([len(doc) for doc in whole_docs])))
|
max_length = max(5, min(50, min([len(doc) for doc in whole_docs])))
|
||||||
|
max_moves = 0
|
||||||
states = []
|
states = []
|
||||||
golds = []
|
golds = []
|
||||||
for doc, state, gold in zip(whole_docs, whole_states, whole_golds):
|
for doc, state, gold in zip(whole_docs, whole_states, whole_golds):
|
||||||
|
@ -493,16 +496,20 @@ cdef class Parser:
|
||||||
start = 0
|
start = 0
|
||||||
while start < len(doc):
|
while start < len(doc):
|
||||||
state = state.copy()
|
state = state.copy()
|
||||||
|
n_moves = 0
|
||||||
while state.B(0) < start and not state.is_final():
|
while state.B(0) < start and not state.is_final():
|
||||||
action = self.moves.c[oracle_actions.pop(0)]
|
action = self.moves.c[oracle_actions.pop(0)]
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
|
n_moves += 1
|
||||||
has_gold = self.moves.has_gold(gold, start=start,
|
has_gold = self.moves.has_gold(gold, start=start,
|
||||||
end=start+max_length)
|
end=start+max_length)
|
||||||
if not state.is_final() and has_gold:
|
if not state.is_final() and has_gold:
|
||||||
states.append(state)
|
states.append(state)
|
||||||
golds.append(gold)
|
golds.append(gold)
|
||||||
|
max_moves = max(max_moves, n_moves)
|
||||||
start += min(max_length, len(doc)-start)
|
start += min(max_length, len(doc)-start)
|
||||||
return states, golds, max_length
|
max_moves = max(max_moves, len(oracle_actions))
|
||||||
|
return states, golds, max_moves
|
||||||
|
|
||||||
def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None):
|
def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None):
|
||||||
# Tells CUDA to block, so our async copies complete.
|
# Tells CUDA to block, so our async copies complete.
|
||||||
|
|
Loading…
Reference in New Issue