From 3d5a536eaa49a46a17156ea8ba996f43179a2e13 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 26 May 2017 11:31:23 -0500 Subject: [PATCH] Improve efficiency of parser batching --- spacy/syntax/_state.pxd | 1 + spacy/syntax/arc_eager.pyx | 9 ++++- spacy/syntax/ner.pyx | 9 ++++- spacy/syntax/nn_parser.pyx | 55 ++++++++++++------------------ spacy/syntax/stateclass.pyx | 5 +++ spacy/syntax/transition_system.pyx | 28 +++++++++++++++ 6 files changed, 72 insertions(+), 35 deletions(-) diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 829779dc1..4b2b47270 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -345,6 +345,7 @@ cdef cppclass StateC: this._s_i = src._s_i this._e_i = src._e_i this._break = src._break + this.offset = src.offset void fast_forward() nogil: # space token attachement policy: diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 0a1422088..f7c1c7922 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -350,8 +350,15 @@ cdef class ArcEager(TransitionSystem): def __get__(self): return (SHIFT, REDUCE, LEFT, RIGHT, BREAK) + def has_gold(self, GoldParse gold, start=0, end=None): + end = end or len(gold.heads) + if all([tag is None for tag in gold.heads[start:end]]): + return False + else: + return True + def preprocess_gold(self, GoldParse gold): - if all([h is None for h in gold.heads]): + if not self.has_gold(gold): return None for i in range(gold.length): if gold.heads[i] is None: # Missing values diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 74ab9c26c..af42eded4 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -95,8 +95,15 @@ cdef class BiluoPushDown(TransitionSystem): else: return MOVE_NAMES[move] + '-' + self.strings[label] + def has_gold(self, GoldParse gold, start=0, end=None): + end = end or len(gold.ner) + if all([tag == '-' for tag in gold.ner[start:end]]): + return False + else: + return True + def preprocess_gold(self, GoldParse gold): - if all([tag == '-' for tag in gold.ner]): + if not self.has_gold(gold): return None for i in range(gold.length): gold.c.ner[i] = self.lookup_transition(gold.ner[i]) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 341b8c041..35966d536 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -427,8 +427,7 @@ cdef class Parser: cuda_stream = get_cuda_stream() - states, golds = self._init_gold_batch(docs, golds) - max_length = min([len(doc) for doc in docs]) + states, golds, max_length = self._init_gold_batch(docs, golds) state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, 0.0) todo = [(s, g) for (s, g) in zip(states, golds) @@ -472,46 +471,36 @@ cdef class Parser: backprops, sgd, cuda_stream) return self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs]) - def _init_gold_batch(self, docs, golds): + def _init_gold_batch(self, whole_docs, whole_golds): """Make a square batch, of length equal to the shortest doc. A long doc will get multiple states. Let's say we have a doc of length 2*N, where N is the shortest doc. We'll make two states, one representing long_doc[:N], and another representing long_doc[N:].""" - cdef StateClass state - lengths = [len(doc) for doc in docs] - min_length = min(lengths) - offset = 0 + cdef: + StateClass state + Transition action + whole_states = self.moves.init_batch(whole_docs) + max_length = max(5, min(20, min([len(doc) for doc in whole_docs]))) states = [] - extra_golds = [] - cdef Pool mem = Pool() - costs = mem.alloc(self.moves.n_moves, sizeof(float)) - is_valid = mem.alloc(self.moves.n_moves, sizeof(int)) - for doc, gold in zip(docs, golds): + golds = [] + for doc, state, gold in zip(whole_docs, whole_states, whole_golds): gold = self.moves.preprocess_gold(gold) - state = StateClass(doc, offset=offset) - self.moves.initialize_state(state.c) - if not state.is_final(): - states.append(state) - extra_golds.append(gold) - start = min(min_length, len(doc)) + if gold is None: + continue + oracle_actions = self.moves.get_oracle_sequence(doc, gold) + start = 0 while start < len(doc): - length = min(min_length, len(doc)-start) - state = StateClass(doc, offset=offset) - self.moves.initialize_state(state.c) + state = state.copy() while state.B(0) < start and not state.is_final(): - self.moves.set_costs(is_valid, costs, state, gold) - for i in range(self.moves.n_moves): - if is_valid[i] and costs[i] <= 0: - self.moves.c[i].do(state.c, self.moves.c[i].label) - break - else: - raise ValueError("Could not find gold move") - start += length - if not state.is_final(): + action = self.moves.c[oracle_actions.pop(0)] + action.do(state.c, action.label) + has_gold = self.moves.has_gold(gold, start=start, + end=start+max_length) + if not state.is_final() and has_gold: states.append(state) - extra_golds.append(gold) - offset += len(doc) - return states, extra_golds + golds.append(gold) + start += min(max_length, len(doc)-start) + return states, golds, max_length def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None): # Tells CUDA to block, so our async copies complete. diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index fd38710e7..228a3ff91 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -41,6 +41,11 @@ cdef class StateClass: def is_final(self): return self.c.is_final() + def copy(self): + cdef StateClass new_state = StateClass.init(self.c._sent, self.c.length) + new_state.c.clone(self.c) + return new_state + def print_state(self, words): words = list(words) + ['_'] top = words[self.S(0)] + '_%d' % self.S_(0).head diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index d6750d09c..07102aeb0 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -61,6 +61,24 @@ cdef class TransitionSystem: offset += len(doc) return states + def get_oracle_sequence(self, doc, GoldParse gold): + cdef Pool mem = Pool() + costs = mem.alloc(self.n_moves, sizeof(float)) + is_valid = mem.alloc(self.n_moves, sizeof(int)) + + cdef StateClass state = StateClass(doc, offset=0) + self.initialize_state(state.c) + history = [] + while not state.is_final(): + self.set_costs(is_valid, costs, state, gold) + for i in range(self.n_moves): + if is_valid[i] and costs[i] <= 0: + action = self.c[i] + history.append(i) + action.do(state.c, action.label) + break + return history + cdef int initialize_state(self, StateC* state) nogil: pass @@ -92,11 +110,21 @@ cdef class TransitionSystem: StateClass stcls, GoldParse gold) except -1: cdef int i self.set_valid(is_valid, stcls.c) + cdef int n_gold = 0 for i in range(self.n_moves): if is_valid[i]: costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label) + n_gold += costs[i] <= 0 else: costs[i] = 9000 + if n_gold <= 0: + print(gold.words) + print(gold.ner) + raise ValueError( + "Could not find a gold-standard action to supervise " + "the entity recognizer\n" + "The transition system has %d actions.\n" + "%s" % (self.n_moves)) def add_action(self, int action, label): if not isinstance(label, int):