mirror of https://github.com/explosion/spaCy.git
Fix parser gold cutting and gradient normalization
This commit is contained in:
parent
8c5a88e777
commit
a1b6add4c8
|
@ -265,11 +265,15 @@ cdef class Parser:
|
||||||
free(is_valid)
|
free(is_valid)
|
||||||
|
|
||||||
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
|
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
|
||||||
|
cdef StateClass state
|
||||||
if losses is None:
|
if losses is None:
|
||||||
losses = {}
|
losses = {}
|
||||||
losses.setdefault(self.name, 0.)
|
losses.setdefault(self.name, 0.)
|
||||||
for multitask in self._multitasks:
|
for multitask in self._multitasks:
|
||||||
multitask.update(examples, drop=drop, sgd=sgd)
|
multitask.update(examples, drop=drop, sgd=sgd)
|
||||||
|
n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
|
||||||
|
if n_examples == 0:
|
||||||
|
return losses
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
# Prepare the stepwise model, and get the callback for finishing the batch
|
# Prepare the stepwise model, and get the callback for finishing the batch
|
||||||
model, backprop_tok2vec = self.model.begin_update(
|
model, backprop_tok2vec = self.model.begin_update(
|
||||||
|
@ -280,10 +284,13 @@ cdef class Parser:
|
||||||
cut_size = self.cfg["update_with_oracle_cut_size"]
|
cut_size = self.cfg["update_with_oracle_cut_size"]
|
||||||
states, golds, max_steps = self._init_gold_batch(
|
states, golds, max_steps = self._init_gold_batch(
|
||||||
examples,
|
examples,
|
||||||
max_length=numpy.random.choice(range(20, cut_size))
|
max_length=numpy.random.choice(range(5, cut_size))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
states, golds, max_steps = self.moves.init_gold_batch(examples)
|
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||||
|
max_steps = max([len(eg.x) for eg in examples])
|
||||||
|
if not states:
|
||||||
|
return losses
|
||||||
all_states = list(states)
|
all_states = list(states)
|
||||||
states_golds = zip(states, golds)
|
states_golds = zip(states, golds)
|
||||||
for _ in range(max_steps):
|
for _ in range(max_steps):
|
||||||
|
@ -292,6 +299,17 @@ cdef class Parser:
|
||||||
states, golds = zip(*states_golds)
|
states, golds = zip(*states_golds)
|
||||||
scores, backprop = model.begin_update(states)
|
scores, backprop = model.begin_update(states)
|
||||||
d_scores = self.get_batch_loss(states, golds, scores, losses)
|
d_scores = self.get_batch_loss(states, golds, scores, losses)
|
||||||
|
if self.cfg["normalize_gradients_with_batch_size"]:
|
||||||
|
# We have to be very careful how we do this, because of the way we
|
||||||
|
# cut up the batch. We subdivide long sequences. If we normalize
|
||||||
|
# naively, we end up normalizing by sequence length, which
|
||||||
|
# is bad: that would mean that states in long sequences
|
||||||
|
# consistently get smaller gradients. Imagine if we have two
|
||||||
|
# sequences, one length 1000, one length 20. If we cut up
|
||||||
|
# the 1k sequence so that we have a "batch" of 50 subsequences,
|
||||||
|
# we don't want the gradients to get 50 times smaller!
|
||||||
|
d_scores /= n_examples
|
||||||
|
|
||||||
backprop(d_scores)
|
backprop(d_scores)
|
||||||
# Follow the predicted action
|
# Follow the predicted action
|
||||||
self.transition_states(states, scores)
|
self.transition_states(states, scores)
|
||||||
|
@ -389,8 +407,6 @@ cdef class Parser:
|
||||||
cpu_log_loss(c_d_scores,
|
cpu_log_loss(c_d_scores,
|
||||||
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
costs, is_valid, &scores[i, 0], d_scores.shape[1])
|
||||||
c_d_scores += d_scores.shape[1]
|
c_d_scores += d_scores.shape[1]
|
||||||
if len(states) and self.cfg["normalize_gradients_with_batch_size"]:
|
|
||||||
d_scores /= len(states)
|
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses.setdefault(self.name, 0.)
|
losses.setdefault(self.name, 0.)
|
||||||
losses[self.name] += (d_scores**2).sum()
|
losses[self.name] += (d_scores**2).sum()
|
||||||
|
@ -503,41 +519,49 @@ cdef class Parser:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _init_gold_batch(self, examples, min_length=5, max_length=500):
|
def _init_gold_batch(self, examples, min_length=5, max_length=500):
|
||||||
"""Make a square batch, of length equal to the shortest doc. A long
|
"""Make a square batch, of length equal to the shortest transition
|
||||||
|
sequence or a cap. A long
|
||||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||||
where N is the shortest doc. We'll make two states, one representing
|
where N is the shortest doc. We'll make two states, one representing
|
||||||
long_doc[:N], and another representing long_doc[N:]."""
|
long_doc[:N], and another representing long_doc[N:]."""
|
||||||
cdef:
|
cdef:
|
||||||
|
StateClass start_state
|
||||||
StateClass state
|
StateClass state
|
||||||
Transition action
|
Transition action
|
||||||
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
all_states = self.moves.init_batch([eg.predicted for eg in examples])
|
||||||
kept = []
|
kept = []
|
||||||
|
max_length_seen = 0
|
||||||
for state, eg in zip(all_states, examples):
|
for state, eg in zip(all_states, examples):
|
||||||
if self.moves.has_gold(eg) and not state.is_final():
|
if self.moves.has_gold(eg) and not state.is_final():
|
||||||
gold = self.moves.init_gold(state, eg)
|
gold = self.moves.init_gold(state, eg)
|
||||||
kept.append((eg, state, gold))
|
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
||||||
max_length = max(min_length, min(max_length, min([len(eg.x) for eg in examples])))
|
state.copy(), gold)
|
||||||
max_moves = 0
|
kept.append((eg, state, gold, oracle_actions))
|
||||||
|
min_length = min(min_length, len(oracle_actions))
|
||||||
|
max_length_seen = max(max_length, len(oracle_actions))
|
||||||
|
if not kept:
|
||||||
|
return [], [], 0
|
||||||
|
max_length = max(min_length, min(max_length, max_length_seen))
|
||||||
states = []
|
states = []
|
||||||
golds = []
|
golds = []
|
||||||
for eg, state, gold in kept:
|
cdef int clas
|
||||||
oracle_actions = self.moves.get_oracle_sequence_from_state(
|
max_moves = 0
|
||||||
state, gold)
|
for eg, state, gold, oracle_actions in kept:
|
||||||
start = 0
|
for i in range(0, len(oracle_actions), max_length):
|
||||||
while start < len(eg.predicted):
|
start_state = state.copy()
|
||||||
state = state.copy()
|
|
||||||
n_moves = 0
|
n_moves = 0
|
||||||
while state.B(0) < start and not state.is_final():
|
for clas in oracle_actions[i:i+max_length]:
|
||||||
action = self.moves.c[oracle_actions.pop(0)]
|
action = self.moves.c[clas]
|
||||||
action.do(state.c, action.label)
|
action.do(state.c, action.label)
|
||||||
state.c.push_hist(action.clas)
|
state.c.push_hist(action.clas)
|
||||||
n_moves += 1
|
n_moves += 1
|
||||||
has_gold = self.moves.has_gold(eg, start=start,
|
if state.is_final():
|
||||||
end=start+max_length)
|
break
|
||||||
if not state.is_final() and has_gold:
|
max_moves = max(max_moves, n_moves)
|
||||||
states.append(state)
|
if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
|
||||||
|
states.append(start_state)
|
||||||
golds.append(gold)
|
golds.append(gold)
|
||||||
max_moves = max(max_moves, n_moves)
|
max_moves = max(max_moves, n_moves)
|
||||||
start += min(max_length, len(eg.x)-start)
|
if state.is_final():
|
||||||
max_moves = max(max_moves, len(oracle_actions))
|
break
|
||||||
return states, golds, max_moves
|
return states, golds, max_moves
|
||||||
|
|
Loading…
Reference in New Issue