Improve correctness of minibatching

This commit is contained in:
Matthew Honnibal 2017-05-27 17:59:00 -05:00
parent dc07d72d80
commit a1d4c97fb7
1 changed files with 11 additions and 4 deletions

View File

@ -427,7 +427,7 @@ cdef class Parser:
cuda_stream = get_cuda_stream() cuda_stream = get_cuda_stream()
states, golds, max_length = self._init_gold_batch(docs, golds) states, golds, max_steps = self._init_gold_batch(docs, golds)
state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream,
0.0) 0.0)
todo = [(s, g) for (s, g) in zip(states, golds) todo = [(s, g) for (s, g) in zip(states, golds)
@ -438,6 +438,7 @@ cdef class Parser:
backprops = [] backprops = []
d_tokvecs = state2vec.ops.allocate(tokvecs.shape) d_tokvecs = state2vec.ops.allocate(tokvecs.shape)
cdef float loss = 0. cdef float loss = 0.
n_steps = 0
while todo: while todo:
states, golds = zip(*todo) states, golds = zip(*todo)
@ -467,7 +468,8 @@ cdef class Parser:
todo = [st for st in todo if not st[0].is_final()] todo = [st for st in todo if not st[0].is_final()]
if losses is not None: if losses is not None:
losses[self.name] += (d_scores**2).sum() losses[self.name] += (d_scores**2).sum()
if len(backprops) >= (max_length * 2): n_steps += 1
if n_steps >= max_steps:
break break
self._make_updates(d_tokvecs, self._make_updates(d_tokvecs,
backprops, sgd, cuda_stream) backprops, sgd, cuda_stream)
@ -482,7 +484,8 @@ cdef class Parser:
StateClass state StateClass state
Transition action Transition action
whole_states = self.moves.init_batch(whole_docs) whole_states = self.moves.init_batch(whole_docs)
max_length = max(5, min(20, min([len(doc) for doc in whole_docs]))) max_length = max(5, min(50, min([len(doc) for doc in whole_docs])))
max_moves = 0
states = [] states = []
golds = [] golds = []
for doc, state, gold in zip(whole_docs, whole_states, whole_golds): for doc, state, gold in zip(whole_docs, whole_states, whole_golds):
@ -493,16 +496,20 @@ cdef class Parser:
start = 0 start = 0
while start < len(doc): while start < len(doc):
state = state.copy() state = state.copy()
n_moves = 0
while state.B(0) < start and not state.is_final(): while state.B(0) < start and not state.is_final():
action = self.moves.c[oracle_actions.pop(0)] action = self.moves.c[oracle_actions.pop(0)]
action.do(state.c, action.label) action.do(state.c, action.label)
n_moves += 1
has_gold = self.moves.has_gold(gold, start=start, has_gold = self.moves.has_gold(gold, start=start,
end=start+max_length) end=start+max_length)
if not state.is_final() and has_gold: if not state.is_final() and has_gold:
states.append(state) states.append(state)
golds.append(gold) golds.append(gold)
max_moves = max(max_moves, n_moves)
start += min(max_length, len(doc)-start) start += min(max_length, len(doc)-start)
return states, golds, max_length max_moves = max(max_moves, len(oracle_actions))
return states, golds, max_moves
def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None): def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None):
# Tells CUDA to block, so our async copies complete. # Tells CUDA to block, so our async copies complete.