From 57e09747dc73065847ee13a5e8d3b757009540c8 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 30 Jun 2020 11:50:48 +0200 Subject: [PATCH] Improve efficiency of get_oracle_sequences --- spacy/syntax/arc_eager.pyx | 26 +++++++++++--------------- spacy/syntax/nn_parser.pyx | 26 ++++++++++++++++---------- spacy/syntax/transition_system.pyx | 21 ++++++++++++++------- 3 files changed, 41 insertions(+), 32 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index f129ee7d1..6e63859f0 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -742,21 +742,14 @@ cdef class ArcEager(TransitionSystem): if n_gold < 1: raise ValueError - def get_oracle_sequence(self, Example example): - cdef StateClass state - cdef ArcEagerGold gold - states, golds, n_steps = self.init_gold_batch([example]) - if not golds: - return [] - + def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None): + cdef int i cdef Pool mem = Pool() # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc assert self.n_moves > 0 costs = mem.alloc(self.n_moves, sizeof(float)) is_valid = mem.alloc(self.n_moves, sizeof(int)) - state = states[0] - gold = golds[0] history = [] debug_log = [] failed = False @@ -772,18 +765,21 @@ cdef class ArcEager(TransitionSystem): history.append(i) s0 = state.S(0) b0 = state.B(0) - debug_log.append(" ".join(( - self.get_class_name(i), - "S0=", (example.x[s0].text if s0 >= 0 else "__"), - "B0=", (example.x[b0].text if b0 >= 0 else "__"), - "S0 head?", str(state.has_head(state.S(0))), - ))) + if _debug: + example = _debug + debug_log.append(" ".join(( + self.get_class_name(i), + "S0=", (example.x[s0].text if s0 >= 0 else "__"), + "B0=", (example.x[b0].text if b0 >= 0 else "__"), + "S0 head?", str(state.has_head(state.S(0))), + ))) action.do(state.c, action.label) break else: failed = False break if failed: + example = _debug print("Actions") for i in 
range(self.n_moves): print(self.get_class_name(i)) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 23dca79e3..9949c0ef3 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -63,7 +63,9 @@ cdef class Parser: self.model = model if self.moves.n_moves != 0: self.set_output(self.moves.n_moves) - self.cfg = cfg + self.cfg = dict(cfg) + self.cfg.setdefault("update_with_oracle_cut_size", 100) + self.cfg.setdefault("normalize_gradients_with_batch_size", True) self._multitasks = [] for multitask in cfg.get("multitasks", []): self.add_multitask_objective(multitask) @@ -272,13 +274,16 @@ cdef class Parser: # Prepare the stepwise model, and get the callback for finishing the batch model, backprop_tok2vec = self.model.begin_update( [eg.predicted for eg in examples]) - # Chop sequences into lengths of this many transitions, to make the - # batch uniform length. We randomize this to overfit less. - cut_gold = numpy.random.choice(range(20, 100)) - states, golds, max_steps = self._init_gold_batch( - examples, - max_length=cut_gold - ) + if self.cfg["update_with_oracle_cut_size"] >= 1: + # Chop sequences into lengths of this many transitions, to make the + # batch uniform length. We randomize this to overfit less. + cut_size = self.cfg["update_with_oracle_cut_size"] + states, golds, max_steps = self._init_gold_batch( + examples, + max_length=numpy.random.choice(range(20, cut_size)) + ) + else: + states, golds, max_steps = self.moves.init_gold_batch(examples) all_states = list(states) states_golds = zip(states, golds) for _ in range(max_steps): @@ -384,7 +389,7 @@ cdef class Parser: cpu_log_loss(c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1]) c_d_scores += d_scores.shape[1] - if len(states): + if len(states) and self.cfg["normalize_gradients_with_batch_size"]: d_scores /= len(states) if losses is not None: losses.setdefault(self.name, 0.) 
@@ -516,7 +521,8 @@ cdef class Parser: states = [] golds = [] for eg, state, gold in kept: - oracle_actions = self.moves.get_oracle_sequence(eg) + oracle_actions = self.moves.get_oracle_sequence_from_state( + state, gold) start = 0 while start < len(eg.predicted): state = state.copy() diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index e1ec40e0e..09477f8e0 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -60,20 +60,25 @@ cdef class TransitionSystem: states.append(state) offset += len(doc) return states - + def get_oracle_sequence(self, Example example, _debug=False): + states, golds, _ = self.init_gold_batch([example]) + if not states: + return [] + state = states[0] + gold = golds[0] + if _debug: + return self.get_oracle_sequence_from_state(state, gold, _debug=example) + else: + return self.get_oracle_sequence_from_state(state, gold) + + def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None): cdef Pool mem = Pool() # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc assert self.n_moves > 0 costs = mem.alloc(self.n_moves, sizeof(float)) is_valid = mem.alloc(self.n_moves, sizeof(int)) - cdef StateClass state - states, golds, n_steps = self.init_gold_batch([example]) - if not states: - return [] - state = states[0] - gold = golds[0] history = [] debug_log = [] while not state.is_final(): @@ -85,6 +90,7 @@ cdef class TransitionSystem: s0 = state.S(0) b0 = state.B(0) if _debug: + example = _debug debug_log.append(" ".join(( self.get_class_name(i), "S0=", (example.x[s0].text if s0 >= 0 else "__"), @@ -95,6 +101,7 @@ cdef class TransitionSystem: break else: if _debug: + example = _debug print("Actions") for i in range(self.n_moves): print(self.get_class_name(i))