Improve efficiency of get_oracle_sequences

This commit is contained in:
Matthew Honnibal 2020-06-30 11:50:48 +02:00
parent 233945bfe0
commit 57e09747dc
3 changed files with 41 additions and 32 deletions

View File

@ -742,21 +742,14 @@ cdef class ArcEager(TransitionSystem):
if n_gold < 1:
raise ValueError
def get_oracle_sequence(self, Example example):
cdef StateClass state
cdef ArcEagerGold gold
states, golds, n_steps = self.init_gold_batch([example])
if not golds:
return []
def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
cdef int i
cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
state = states[0]
gold = golds[0]
history = []
debug_log = []
failed = False
@ -772,18 +765,21 @@ cdef class ArcEager(TransitionSystem):
history.append(i)
s0 = state.S(0)
b0 = state.B(0)
debug_log.append(" ".join((
self.get_class_name(i),
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
"B0=", (example.x[b0].text if b0 >= 0 else "__"),
"S0 head?", str(state.has_head(state.S(0))),
)))
if _debug:
example = _debug
debug_log.append(" ".join((
self.get_class_name(i),
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
"B0=", (example.x[b0].text if b0 >= 0 else "__"),
"S0 head?", str(state.has_head(state.S(0))),
)))
action.do(state.c, action.label)
break
else:
failed = False
break
if failed:
example = _debug
print("Actions")
for i in range(self.n_moves):
print(self.get_class_name(i))

View File

@ -63,7 +63,9 @@ cdef class Parser:
self.model = model
if self.moves.n_moves != 0:
self.set_output(self.moves.n_moves)
self.cfg = cfg
self.cfg = dict(cfg)
self.cfg.setdefault("update_with_oracle_cut_size", 100)
self.cfg.setdefault("normalize_gradients_with_batch_size", True)
self._multitasks = []
for multitask in cfg.get("multitasks", []):
self.add_multitask_objective(multitask)
@ -272,13 +274,16 @@ cdef class Parser:
# Prepare the stepwise model, and get the callback for finishing the batch
model, backprop_tok2vec = self.model.begin_update(
[eg.predicted for eg in examples])
# Chop sequences into lengths of this many transitions, to make the
# batch uniform length. We randomize this to overfit less.
cut_gold = numpy.random.choice(range(20, 100))
states, golds, max_steps = self._init_gold_batch(
examples,
max_length=cut_gold
)
if self.cfg["update_with_oracle_cut_size"] >= 1:
# Chop sequences into lengths of this many transitions, to make the
# batch uniform length. We randomize this to overfit less.
cut_size = self.cfg["update_with_oracle_cut_size"]
states, golds, max_steps = self._init_gold_batch(
examples,
max_length=numpy.random.choice(range(20, cut_size))
)
else:
states, golds, max_steps = self.moves.init_gold_batch(examples)
all_states = list(states)
states_golds = zip(states, golds)
for _ in range(max_steps):
@ -384,7 +389,7 @@ cdef class Parser:
cpu_log_loss(c_d_scores,
costs, is_valid, &scores[i, 0], d_scores.shape[1])
c_d_scores += d_scores.shape[1]
if len(states):
if len(states) and self.cfg["normalize_gradients_with_batch_size"]:
d_scores /= len(states)
if losses is not None:
losses.setdefault(self.name, 0.)
@ -516,7 +521,8 @@ cdef class Parser:
states = []
golds = []
for eg, state, gold in kept:
oracle_actions = self.moves.get_oracle_sequence(eg)
oracle_actions = self.moves.get_oracle_sequence_from_state(
state, gold)
start = 0
while start < len(eg.predicted):
state = state.copy()

View File

@ -60,20 +60,25 @@ cdef class TransitionSystem:
states.append(state)
offset += len(doc)
return states
def get_oracle_sequence(self, Example example, _debug=False):
    """Compute the oracle sequence for a single example.

    Initialises a one-element gold batch from `example` and hands the
    first state/gold pair to `get_oracle_sequence_from_state`, returning
    whatever that method returns.

    example: The example to build the gold state from.
    _debug: When truthy, `example` itself is forwarded as the `_debug`
        argument of `get_oracle_sequence_from_state`; otherwise that
        argument is left at its default.
    RETURNS: An empty list if `init_gold_batch` produced no states,
        else the result of `get_oracle_sequence_from_state`.
    """
    states, golds, _ = self.init_gold_batch([example])
    # Guard clause: nothing to do if no state could be initialised.
    if not states:
        return []
    first_state = states[0]
    first_gold = golds[0]
    # Only pass the keyword when debugging, so the callee's own default
    # applies in the normal case.
    debug_kwargs = {"_debug": example} if _debug else {}
    return self.get_oracle_sequence_from_state(first_state, first_gold, **debug_kwargs)
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
cdef StateClass state
states, golds, n_steps = self.init_gold_batch([example])
if not states:
return []
state = states[0]
gold = golds[0]
history = []
debug_log = []
while not state.is_final():
@ -85,6 +90,7 @@ cdef class TransitionSystem:
s0 = state.S(0)
b0 = state.B(0)
if _debug:
example = _debug
debug_log.append(" ".join((
self.get_class_name(i),
"S0=", (example.x[s0].text if s0 >= 0 else "__"),
@ -95,6 +101,7 @@ cdef class TransitionSystem:
break
else:
if _debug:
example = _debug
print("Actions")
for i in range(self.n_moves):
print(self.get_class_name(i))