Fix parser gold cutting and gradient normalization

This commit is contained in:
Matthew Honnibal 2020-07-01 01:02:58 +02:00
parent 8c5a88e777
commit a1b6add4c8
1 changed file with 46 additions and 22 deletions

@@ -265,11 +265,15 @@ cdef class Parser:
         free(is_valid)
 
     def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
         cdef StateClass state
         if losses is None:
             losses = {}
         losses.setdefault(self.name, 0.)
         for multitask in self._multitasks:
             multitask.update(examples, drop=drop, sgd=sgd)
+        n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
+        if n_examples == 0:
+            return losses
         set_dropout_rate(self.model, drop)
         # Prepare the stepwise model, and get the callback for finishing the batch
         model, backprop_tok2vec = self.model.begin_update(
@@ -280,10 +284,13 @@ cdef class Parser:
             cut_size = self.cfg["update_with_oracle_cut_size"]
             states, golds, max_steps = self._init_gold_batch(
                 examples,
-                max_length=numpy.random.choice(range(20, cut_size))
+                max_length=numpy.random.choice(range(5, cut_size))
             )
         else:
-            states, golds, max_steps = self.moves.init_gold_batch(examples)
+            states, golds, _ = self.moves.init_gold_batch(examples)
+            max_steps = max([len(eg.x) for eg in examples])
         if not states:
             return losses
         all_states = list(states)
         states_golds = zip(states, golds)
         for _ in range(max_steps):
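
Aside: the cut branch above re-samples the cap on transition-sequence length for every update, so batches get chopped at varying points. A minimal sketch of just that sampling; the cut_size value here is invented for illustration:

    import numpy

    cut_size = 100  # assumed value for cfg["update_with_oracle_cut_size"]
    # The lower bound drops from 20 to 5, so much shorter cuts are possible.
    max_length = numpy.random.choice(range(5, cut_size))
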
@@ -292,6 +299,17 @@ cdef class Parser:
             states, golds = zip(*states_golds)
             scores, backprop = model.begin_update(states)
             d_scores = self.get_batch_loss(states, golds, scores, losses)
+            if self.cfg["normalize_gradients_with_batch_size"]:
+                # We have to be very careful how we do this, because of the way we
+                # cut up the batch. We subdivide long sequences. If we normalize
+                # naively, we end up normalizing by sequence length, which
+                # is bad: that would mean that states in long sequences
+                # consistently get smaller gradients. Imagine if we have two
+                # sequences, one length 1000, one length 20. If we cut up
+                # the 1k sequence so that we have a "batch" of 50 subsequences,
+                # we don't want the gradients to get 50 times smaller!
+                d_scores /= n_examples
+
             backprop(d_scores)
             # Follow the predicted action
             self.transition_states(states, scores)
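
Aside: the comment in this hunk is easier to see with numbers. A minimal sketch of the two normalization choices; states_per_doc and the score shape are made up, and only the idea of dividing by n_examples mirrors the code:

    import numpy

    n_examples = 2             # two real documents in the batch
    states_per_doc = [50, 1]   # the 1000-token doc was cut into 50 pieces
    n_states = sum(states_per_doc)
    d_scores = numpy.ones((n_states, 10), dtype="f")  # stand-in gradients

    naive = d_scores / n_states     # scale 1/51: shrinks the more we cut
    fixed = d_scores / n_examples   # scale 1/2: independent of the cutting
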
@@ -389,8 +407,6 @@ cdef class Parser:
             cpu_log_loss(c_d_scores,
                 costs, is_valid, &scores[i, 0], d_scores.shape[1])
             c_d_scores += d_scores.shape[1]
-        if len(states) and self.cfg["normalize_gradients_with_batch_size"]:
-            d_scores /= len(states)
         if losses is not None:
             losses.setdefault(self.name, 0.)
             losses[self.name] += (d_scores**2).sum()
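
Aside: the loss bookkeeping kept at the end of this hunk accumulates the squared gradient as the reported loss. A toy example with invented values, where the "parser" key stands in for self.name:

    import numpy

    losses = {}
    losses.setdefault("parser", 0.)
    d_scores = numpy.array([[0.2, -0.1], [0.0, 0.3]])
    losses["parser"] += (d_scores ** 2).sum()  # 0.04 + 0.01 + 0.09 = 0.14
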
@@ -503,41 +519,49 @@ cdef class Parser:
         return self
 
     def _init_gold_batch(self, examples, min_length=5, max_length=500):
-        """Make a square batch, of length equal to the shortest doc. A long
+        """Make a square batch, of length equal to the shortest transition
+        sequence or a cap. A long
         doc will get multiple states. Let's say we have a doc of length 2*N,
         where N is the shortest doc. We'll make two states, one representing
         long_doc[:N], and another representing long_doc[N:]."""
         cdef:
+            StateClass start_state
             StateClass state
             Transition action
         all_states = self.moves.init_batch([eg.predicted for eg in examples])
         kept = []
+        max_length_seen = 0
         for state, eg in zip(all_states, examples):
             if self.moves.has_gold(eg) and not state.is_final():
                 gold = self.moves.init_gold(state, eg)
-                kept.append((eg, state, gold))
-        max_length = max(min_length, min(max_length, min([len(eg.x) for eg in examples])))
-        max_moves = 0
+                oracle_actions = self.moves.get_oracle_sequence_from_state(
+                    state.copy(), gold)
+                kept.append((eg, state, gold, oracle_actions))
+                min_length = min(min_length, len(oracle_actions))
+                max_length_seen = max(max_length, len(oracle_actions))
+        if not kept:
+            return [], [], 0
+        max_length = max(min_length, min(max_length, max_length_seen))
         states = []
         golds = []
-        for eg, state, gold in kept:
-            oracle_actions = self.moves.get_oracle_sequence_from_state(
-                state, gold)
-            start = 0
-            while start < len(eg.predicted):
-                state = state.copy()
+        cdef int clas
+        max_moves = 0
+        for eg, state, gold, oracle_actions in kept:
+            for i in range(0, len(oracle_actions), max_length):
+                start_state = state.copy()
                 n_moves = 0
-                while state.B(0) < start and not state.is_final():
-                    action = self.moves.c[oracle_actions.pop(0)]
+                for clas in oracle_actions[i:i+max_length]:
+                    action = self.moves.c[clas]
                     action.do(state.c, action.label)
                     state.c.push_hist(action.clas)
                     n_moves += 1
-                has_gold = self.moves.has_gold(eg, start=start,
-                                               end=start+max_length)
-                if not state.is_final() and has_gold:
-                    states.append(state)
+                    if state.is_final():
+                        break
+                max_moves = max(max_moves, n_moves)
+                if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
+                    states.append(start_state)
                     golds.append(gold)
-                    max_moves = max(max_moves, n_moves)
-                start += min(max_length, len(eg.x)-start)
-            max_moves = max(max_moves, len(oracle_actions))
+                if state.is_final():
+                    break
         return states, golds, max_moves
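
Aside: the rewritten loop cuts each gold transition sequence into windows of at most max_length moves and keeps the state at the start of each window. A stripped-down sketch of just the windowing; cut_sequence and the action labels are invented for illustration:

    def cut_sequence(oracle_actions, max_length):
        """Yield (offset, window) pairs of at most max_length oracle moves."""
        for i in range(0, len(oracle_actions), max_length):
            yield i, oracle_actions[i:i + max_length]

    actions = ["SHIFT", "LEFT", "SHIFT", "RIGHT", "SHIFT", "REDUCE", "RIGHT"]
    for offset, window in cut_sequence(actions, max_length=3):
        print(offset, window)
    # 0 ['SHIFT', 'LEFT', 'SHIFT']
    # 3 ['RIGHT', 'SHIFT', 'REDUCE']
    # 6 ['RIGHT']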