From 249b97184d12664dde53a3c5b8c658ad7b8cf0ca Mon Sep 17 00:00:00 2001
From: kadarakos
Date: Wed, 23 Feb 2022 16:10:05 +0100
Subject: [PATCH] Bugfixes and test for rehearse (#10347)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fixing argument order for rehearse

* rehearse test for ner and tagger

* rehearse bugfix

* added test for parser

* test for multilabel textcat

* rehearse fix

* remove debug line

* Update spacy/tests/training/test_rehearse.py

Co-authored-by: Sofie Van Landeghem

* Update spacy/tests/training/test_rehearse.py

Co-authored-by: Sofie Van Landeghem

Co-authored-by: Kádár Ákos
Co-authored-by: Sofie Van Landeghem
---
 spacy/language.py                     |   5 +-
 spacy/pipeline/tagger.pyx             |  11 +-
 spacy/pipeline/textcat.py             |   2 +-
 spacy/tests/training/test_rehearse.py | 168 ++++++++++++++++++++++++++
 4 files changed, 178 insertions(+), 8 deletions(-)
 create mode 100644 spacy/tests/training/test_rehearse.py

diff --git a/spacy/language.py b/spacy/language.py
index e8fd2720c..bab403f0e 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -1222,8 +1222,9 @@ class Language:
             component_cfg = {}
         grads = {}
 
-        def get_grads(W, dW, key=None):
+        def get_grads(key, W, dW):
             grads[key] = (W, dW)
+            return W, dW
 
         get_grads.learn_rate = sgd.learn_rate  # type: ignore[attr-defined, union-attr]
         get_grads.b1 = sgd.b1  # type: ignore[attr-defined, union-attr]
@@ -1236,7 +1237,7 @@ class Language:
                 examples, sgd=get_grads, losses=losses, **component_cfg.get(name, {})
             )
         for key, (W, dW) in grads.items():
-            sgd(W, dW, key=key)  # type: ignore[call-arg, misc]
+            sgd(key, W, dW)  # type: ignore[call-arg, misc]
         return losses
 
     def begin_training(
diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx
index a2bec888e..e21a9096e 100644
--- a/spacy/pipeline/tagger.pyx
+++ b/spacy/pipeline/tagger.pyx
@@ -225,6 +225,7 @@ class Tagger(TrainablePipe):
 
         DOCS: https://spacy.io/api/tagger#rehearse
         """
+        loss_func = SequenceCategoricalCrossentropy()
         if losses is None:
             losses = {}
         losses.setdefault(self.name, 0.0)
@@ -236,12 +237,12 @@ class Tagger(TrainablePipe):
             # Handle cases where there are no tokens in any docs.
             return losses
         set_dropout_rate(self.model, drop)
-        guesses, backprop = self.model.begin_update(docs)
-        target = self._rehearsal_model(examples)
-        gradient = guesses - target
-        backprop(gradient)
+        tag_scores, bp_tag_scores = self.model.begin_update(docs)
+        tutor_tag_scores, _ = self._rehearsal_model.begin_update(docs)
+        grads, loss = loss_func(tag_scores, tutor_tag_scores)
+        bp_tag_scores(grads)
         self.finish_update(sgd)
-        losses[self.name] += (gradient**2).sum()
+        losses[self.name] += loss
         return losses
 
     def get_loss(self, examples, scores):
diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py
index 690c350fa..bc3f127fc 100644
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -283,7 +283,7 @@ class TextCategorizer(TrainablePipe):
             return losses
         set_dropout_rate(self.model, drop)
         scores, bp_scores = self.model.begin_update(docs)
-        target = self._rehearsal_model(examples)
+        target, _ = self._rehearsal_model.begin_update(docs)
         gradient = scores - target
         bp_scores(gradient)
         if sgd is not None:
diff --git a/spacy/tests/training/test_rehearse.py b/spacy/tests/training/test_rehearse.py
new file mode 100644
index 000000000..1bb8fac86
--- /dev/null
+++ b/spacy/tests/training/test_rehearse.py
@@ -0,0 +1,168 @@
+import pytest
+import spacy
+
+from typing import List
+from spacy.training import Example
+
+
+TRAIN_DATA = [
+    (
+        'Who is Kofi Annan?',
+        {
+            'entities': [(7, 18, 'PERSON')],
+            'tags': ['PRON', 'AUX', 'PROPN', 'PRON', 'PUNCT'],
+            'heads': [1, 1, 3, 1, 1],
+            'deps': ['attr', 'ROOT', 'compound', 'nsubj', 'punct'],
+            'morphs': ['', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', 'Number=Sing', 'Number=Sing', 'PunctType=Peri'],
+            'cats': {'question': 1.0}
+        }
+    ),
+    (
+        'Who is Steve Jobs?',
+        {
+            'entities': [(7, 17, 'PERSON')],
+            'tags': ['PRON', 'AUX', 'PROPN', 'PRON', 'PUNCT'],
+            'heads': [1, 1, 3, 1, 1],
+            'deps': ['attr', 'ROOT', 'compound', 'nsubj', 'punct'],
+            'morphs': ['', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', 'Number=Sing', 'Number=Sing', 'PunctType=Peri'],
+            'cats': {'question': 1.0}
+        }
+    ),
+    (
+        'Bob is a nice person.',
+        {
+            'entities': [(0, 3, 'PERSON')],
+            'tags': ['PROPN', 'AUX', 'DET', 'ADJ', 'NOUN', 'PUNCT'],
+            'heads': [1, 1, 4, 4, 1, 1],
+            'deps': ['nsubj', 'ROOT', 'det', 'amod', 'attr', 'punct'],
+            'morphs': ['Number=Sing', 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin', 'Definite=Ind|PronType=Art', 'Degree=Pos', 'Number=Sing', 'PunctType=Peri'],
+            'cats': {'statement': 1.0}
+        },
+    ),
+    (
+        'Hi Anil, how are you?',
+        {
+            'entities': [(3, 7, 'PERSON')],
+            'tags': ['INTJ', 'PROPN', 'PUNCT', 'ADV', 'AUX', 'PRON', 'PUNCT'],
+            'deps': ['intj', 'npadvmod', 'punct', 'advmod', 'ROOT', 'nsubj', 'punct'],
+            'heads': [4, 0, 4, 4, 4, 4, 4],
+            'morphs': ['', 'Number=Sing', 'PunctType=Comm', '', 'Mood=Ind|Tense=Pres|VerbForm=Fin', 'Case=Nom|Person=2|PronType=Prs', 'PunctType=Peri'],
+            'cats': {'greeting': 1.0, 'question': 1.0}
+        }
+    ),
+    (
+        'I like London and Berlin.',
+        {
+            'entities': [(7, 13, 'LOC'), (18, 24, 'LOC')],
+            'tags': ['PROPN', 'VERB', 'PROPN', 'CCONJ', 'PROPN', 'PUNCT'],
+            'deps': ['nsubj', 'ROOT', 'dobj', 'cc', 'conj', 'punct'],
+            'heads': [1, 1, 1, 2, 2, 1],
+            'morphs': ['Case=Nom|Number=Sing|Person=1|PronType=Prs', 'Tense=Pres|VerbForm=Fin', 'Number=Sing', 'ConjType=Cmp', 'Number=Sing', 'PunctType=Peri'],
+            'cats': {'statement': 1.0}
+        }
+    )
+]
+
+REHEARSE_DATA = [
+    (
+        'Hi Anil',
+        {
+            'entities': [(3, 7, 'PERSON')],
+            'tags': ['INTJ', 'PROPN'],
+            'deps': ['ROOT', 'npadvmod'],
+            'heads': [0, 0],
+            'morphs': ['', 'Number=Sing'],
+            'cats': {'greeting': 1.0}
+        }
+    ),
+    (
+        'Hi Ravish, how you doing?',
+        {
+            'entities': [(3, 9, 'PERSON')],
+            'tags': ['INTJ', 'PROPN', 'PUNCT', 'ADV', 'AUX', 'PRON', 'PUNCT'],
+            'deps': ['intj', 'ROOT', 'punct', 'advmod', 'nsubj', 'advcl', 'punct'],
+            'heads': [1, 1, 1, 5, 5, 1, 1],
+            'morphs': ['', 'VerbForm=Inf', 'PunctType=Comm', '', 'Case=Nom|Person=2|PronType=Prs', 'Aspect=Prog|Tense=Pres|VerbForm=Part', 'PunctType=Peri'],
+            'cats': {'greeting': 1.0, 'question': 1.0}
+        }
+    ),
+    # UTENSIL new label
+    (
+        'Natasha bought new forks.',
+        {
+            'entities': [(0, 7, 'PERSON'), (19, 24, 'UTENSIL')],
+            'tags': ['PROPN', 'VERB', 'ADJ', 'NOUN', 'PUNCT'],
+            'deps': ['nsubj', 'ROOT', 'amod', 'dobj', 'punct'],
+            'heads': [1, 1, 3, 1, 1],
+            'morphs': ['Number=Sing', 'Tense=Past|VerbForm=Fin', 'Degree=Pos', 'Number=Plur', 'PunctType=Peri'],
+            'cats': {'statement': 1.0}
+        }
+    )
+]
+
+
+def _add_ner_label(ner, data):
+    for _, annotations in data:
+        for ent in annotations['entities']:
+            ner.add_label(ent[2])
+
+
+def _add_tagger_label(tagger, data):
+    for _, annotations in data:
+        for tag in annotations['tags']:
+            tagger.add_label(tag)
+
+
+def _add_parser_label(parser, data):
+    for _, annotations in data:
+        for dep in annotations['deps']:
+            parser.add_label(dep)
+
+
+def _add_textcat_label(textcat, data):
+    for _, annotations in data:
+        for cat in annotations['cats']:
+            textcat.add_label(cat)
+
+
+def _optimize(
+    nlp,
+    component: str,
+    data: List,
+    rehearse: bool
+):
+    """Run either train or rehearse."""
+    pipe = nlp.get_pipe(component)
+    if component == 'ner':
+        _add_ner_label(pipe, data)
+    elif component == 'tagger':
+        _add_tagger_label(pipe, data)
+    elif component == 'parser':
+        _add_parser_label(pipe, data)
+    elif component == 'textcat_multilabel':
+        _add_textcat_label(pipe, data)
+    else:
+        raise NotImplementedError
+
+    if rehearse:
+        optimizer = nlp.resume_training()
+    else:
+        optimizer = nlp.initialize()
+
+    for _ in range(5):
+        for text, annotation in data:
+            doc = nlp.make_doc(text)
+            example = Example.from_dict(doc, annotation)
+            if rehearse:
+                nlp.rehearse([example], sgd=optimizer)
+            else:
+                nlp.update([example], sgd=optimizer)
+    return nlp
+
+
+@pytest.mark.parametrize("component", ['ner', 'tagger', 'parser', 'textcat_multilabel'])
+def test_rehearse(component):
+    nlp = spacy.blank("en")
+    nlp.add_pipe(component)
+    nlp = _optimize(nlp, component, TRAIN_DATA, False)
+    _optimize(nlp, component, REHEARSE_DATA, True)
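
Note on the workflow the new test exercises: rehearsal in spaCy v3 means training a pipeline first, then calling resume_training() and rehearse() so the updated model is nudged towards the predictions of the frozen "tutor" copy, which limits catastrophic forgetting. The following is a minimal sketch and is not part of the patch; it reuses two examples from the test data above and assumes a blank English pipeline with only the ner component.

import spacy
from spacy.training import Example

# Train a blank English NER model on one labelled example (mirrors the test above).
nlp = spacy.blank("en")
ner = nlp.add_pipe("ner")
ner.add_label("PERSON")
train_example = Example.from_dict(
    nlp.make_doc("Who is Steve Jobs?"), {"entities": [(7, 17, "PERSON")]}
)
optimizer = nlp.initialize()
for _ in range(5):
    nlp.update([train_example], sgd=optimizer)

# Rehearse: resume training and update the model towards its own earlier
# predictions on held-back data instead of towards gold labels.
rehearse_example = Example.from_dict(
    nlp.make_doc("Hi Anil"), {"entities": [(3, 7, "PERSON")]}
)
optimizer = nlp.resume_training()
losses = {}
nlp.rehearse([rehearse_example], sgd=optimizer, losses=losses)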