From fff662e41fc680a26887e178653bf8bfe1f87f9d Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Mon, 31 May 2021 10:21:06 +0200 Subject: [PATCH] Ensemble textcat with listener (#8012) * add unit test for two listeners, with a textcat ensemble in the middle * return zero gradients instead of None in accumulate_gradient --- spacy/pipeline/tok2vec.py | 1 + spacy/tests/pipeline/test_tok2vec.py | 93 +++++++++++++++++++++++++++- 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index 3ee324d50..00d9548a4 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -173,6 +173,7 @@ class Tok2Vec(TrainablePipe): for i in range(len(one_d_tokvecs)): d_tokvecs[i] += one_d_tokvecs[i] losses[self.name] += float((one_d_tokvecs[i] ** 2).sum()) + return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs] def backprop(one_d_tokvecs): """Callback to actually do the backprop. Passed to last listener.""" diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py index 7a9e96b14..809a79dd6 100644 --- a/spacy/tests/pipeline/test_tok2vec.py +++ b/spacy/tests/pipeline/test_tok2vec.py @@ -129,8 +129,8 @@ cfg_string = """ """ TRAIN_DATA = [ - ("I like green eggs", {"tags": ["N", "V", "J", "N"]}), - ("Eat blue ham", {"tags": ["V", "J", "N"]}), + ("I like green eggs", {"tags": ["N", "V", "J", "N"], "cats": {"preference": 1.0, "imperative": 0.0}}), + ("Eat blue ham", {"tags": ["V", "J", "N"], "cats": {"preference": 0.0, "imperative": 1.0}}), ] @@ -318,3 +318,92 @@ def test_replace_listeners_from_config(): new_nlp.config["components"]["ner"]["model"]["tok2vec"]["@architectures"] == "spacy.Tok2VecListener.v1" ) + + +cfg_string_multi_textcat = """ + [nlp] + lang = "en" + pipeline = ["tok2vec","textcat_multilabel","tagger"] + + [components] + + [components.textcat_multilabel] + factory = "textcat_multilabel" + + [components.textcat_multilabel.model] + @architectures = "spacy.TextCatEnsemble.v2" + nO = null + + [components.textcat_multilabel.model.tok2vec] + @architectures = "spacy.Tok2VecListener.v1" + width = ${components.tok2vec.model.encode.width} + + [components.textcat_multilabel.model.linear_model] + @architectures = "spacy.TextCatBOW.v1" + exclusive_classes = false + ngram_size = 1 + no_output_layer = false + + [components.tagger] + factory = "tagger" + + [components.tagger.model] + @architectures = "spacy.Tagger.v1" + nO = null + + [components.tagger.model.tok2vec] + @architectures = "spacy.Tok2VecListener.v1" + width = ${components.tok2vec.model.encode.width} + + [components.tok2vec] + factory = "tok2vec" + + [components.tok2vec.model] + @architectures = "spacy.Tok2Vec.v2" + + [components.tok2vec.model.embed] + @architectures = "spacy.MultiHashEmbed.v1" + width = ${components.tok2vec.model.encode.width} + rows = [2000, 1000, 1000, 1000] + attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"] + include_static_vectors = false + + [components.tok2vec.model.encode] + @architectures = "spacy.MaxoutWindowEncoder.v2" + width = 96 + depth = 4 + window_size = 1 + maxout_pieces = 3 + """ + + +def test_tok2vec_listeners_textcat(): + orig_config = Config().from_str(cfg_string_multi_textcat) + nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True) + assert nlp.pipe_names == ["tok2vec", "textcat_multilabel", "tagger"] + tagger = nlp.get_pipe("tagger") + textcat = nlp.get_pipe("textcat_multilabel") + tok2vec = nlp.get_pipe("tok2vec") + tagger_tok2vec = tagger.model.get_ref("tok2vec") + textcat_tok2vec = textcat.model.get_ref("tok2vec") + assert isinstance(tok2vec, Tok2Vec) + assert isinstance(tagger_tok2vec, Tok2VecListener) + assert isinstance(textcat_tok2vec, Tok2VecListener) + train_examples = [] + for t in TRAIN_DATA: + train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1])) + + optimizer = nlp.initialize(lambda: train_examples) + for i in range(50): + losses = {} + nlp.update(train_examples, sgd=optimizer, losses=losses) + + docs = list(nlp.pipe(["Eat blue ham", "I like green eggs"])) + cats0 = docs[0].cats + assert cats0["preference"] < 0.1 + assert cats0["imperative"] > 0.9 + cats1 = docs[1].cats + assert cats1["preference"] > 0.1 + assert cats1["imperative"] < 0.9 + assert([t.tag_ for t in docs[0]] == ["V", "J", "N"]) + assert([t.tag_ for t in docs[1]] == ["N", "V", "J", "N"])