From 1d20e21f3e7e68d2a70679941eb20cc5d5e826e1 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Wed, 27 Jan 2021 12:54:47 +1100
Subject: [PATCH] Add labels implicitly for parser and ner

---
 .../pipeline/_parser_internals/arc_eager.pyx | 12 ++++++++
 spacy/pipeline/_parser_internals/ner.pyx     |  7 +++++
 spacy/pipeline/transition_parser.pyx         | 30 +++++++++++++------
 3 files changed, 40 insertions(+), 9 deletions(-)

diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx
index 069b41170..9ca702f9b 100644
--- a/spacy/pipeline/_parser_internals/arc_eager.pyx
+++ b/spacy/pipeline/_parser_internals/arc_eager.pyx
@@ -614,10 +614,22 @@ cdef class ArcEager(TransitionSystem):
         actions[LEFT].setdefault('dep', 0)
         return actions
 
+    @property
+    def builtin_labels(self):
+        return ["ROOT", "dep"]
+
     @property
     def action_types(self):
         return (SHIFT, REDUCE, LEFT, RIGHT, BREAK)
 
+    def get_doc_labels(self, doc):
+        """Get the labels required for a given Doc."""
+        labels = set(self.builtin_labels)
+        for token in doc:
+            if token.dep_:
+                labels.add(token.dep_)
+        return labels
+
     def transition(self, StateClass state, action):
         cdef Transition t = self.lookup_transition(action)
         t.do(state.c, t.label)
diff --git a/spacy/pipeline/_parser_internals/ner.pyx b/spacy/pipeline/_parser_internals/ner.pyx
index d0da6ff70..dd747c08e 100644
--- a/spacy/pipeline/_parser_internals/ner.pyx
+++ b/spacy/pipeline/_parser_internals/ner.pyx
@@ -126,6 +126,13 @@ cdef class BiluoPushDown(TransitionSystem):
     def action_types(self):
         return (BEGIN, IN, LAST, UNIT, OUT)
 
+    def get_doc_labels(self, doc):
+        labels = set()
+        for token in doc:
+            if token.ent_type:
+                labels.add(token.ent_type_)
+        return labels
+
     def move_name(self, int move, attr_t label):
         if move == OUT:
             return 'O'
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index e97d2b020..ad92e41b2 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -132,6 +132,23 @@ cdef class Parser(TrainablePipe):
             return 1
         return 0
 
+    def _ensure_labels_are_added(self, docs):
+        """Ensure that all labels for a batch of docs are added."""
+        resized = False
+        labels = set()
+        for doc in docs:
+            labels.update(self.moves.get_doc_labels(doc))
+        for label in labels:
+            for action in self.moves.action_types:
+                added = self.moves.add_action(action, label)
+                if added:
+                    self.vocab.strings.add(label)
+                    resized = True
+        if resized:
+            self._resize()
+            return 1
+        return 0
+
     def _resize(self):
         self.model.attrs["resize_output"](self.model, self.moves.n_moves)
         if self._rehearsal_model not in (True, False, None):
@@ -188,9 +205,9 @@ cdef class Parser(TrainablePipe):
     def predict(self, docs):
         if isinstance(docs, Doc):
             docs = [docs]
+        self._ensure_labels_are_added(docs)
         if not any(len(doc) for doc in docs):
             result = self.moves.init_batch(docs)
-            self._resize()
             return result
         if self.cfg["beam_width"] == 1:
             return self.greedy_parse(docs, drop=0.0)
@@ -207,10 +224,6 @@
         cdef StateClass state
         set_dropout_rate(self.model, drop)
         batch = self.moves.init_batch(docs)
-        # This is pretty dirty, but the NER can resize itself in init_batch,
-        # if labels are missing. We therefore have to check whether we need to
-        # expand our model output.
-        self._resize()
         model = self.model.predict(docs)
         weights = get_c_weights(model)
         for state in batch:
@@ -234,10 +247,6 @@
             beam_width,
             density=beam_density
         )
-        # This is pretty dirty, but the NER can resize itself in init_batch,
-        # if labels are missing. We therefore have to check whether we need to
-        # expand our model output.
-        self._resize()
         model = self.model.predict(docs)
         while not batch.is_done:
             states = batch.get_unfinished_states()
@@ -314,6 +323,9 @@
             losses = {}
         losses.setdefault(self.name, 0.)
         validate_examples(examples, "Parser.update")
+        self._ensure_labels_are_added(
+            [eg.x for eg in examples] + [eg.y for eg in examples]
+        )
         for multitask in self._multitasks:
             multitask.update(examples, drop=drop, sgd=sgd)
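
Not part of the patch itself: a minimal usage sketch of the behaviour this
commit enables, assuming a spaCy v3-style training setup. The pipeline, the
sample text, and the (10, 14, "ORG") span below are invented for
illustration. Because Parser.update() now calls _ensure_labels_are_added()
on both the predicted (eg.x) and reference (eg.y) docs, labels seen in the
training data are added on the fly, so an explicit ner.add_label("ORG")
call should no longer be required:

    # Illustrative sketch (not from the commit); assumes spaCy v3 APIs.
    import spacy
    from spacy.training import Example

    nlp = spacy.blank("en")
    ner = nlp.add_pipe("ner")      # note: no ner.add_label() calls
    optimizer = nlp.initialize()

    doc = nlp.make_doc("I work at Acme.")
    # "Acme" spans characters 10-14; "ORG" is a label the pipe has not seen.
    example = Example.from_dict(doc, {"entities": [(10, 14, "ORG")]})

    # update() collects "ORG" from the reference doc, adds the missing
    # actions for each action type, and resizes the model output.
    losses = nlp.update([example], sgd=optimizer, losses={})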