Ensemble textcat with listener (#8012)

* add unit test for two listeners, with a textcat ensemble in the middle

* return zero gradients instead of None in accumulate_gradient
This commit is contained in:
Sofie Van Landeghem 2021-05-31 10:21:06 +02:00 committed by GitHub
parent ff91e6dac7
commit fff662e41f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 92 additions and 2 deletions

View File

@ -173,6 +173,7 @@ class Tok2Vec(TrainablePipe):
for i in range(len(one_d_tokvecs)): for i in range(len(one_d_tokvecs)):
d_tokvecs[i] += one_d_tokvecs[i] d_tokvecs[i] += one_d_tokvecs[i]
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum()) losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
def backprop(one_d_tokvecs): def backprop(one_d_tokvecs):
"""Callback to actually do the backprop. Passed to last listener.""" """Callback to actually do the backprop. Passed to last listener."""

View File

@ -129,8 +129,8 @@ cfg_string = """
""" """
TRAIN_DATA = [ TRAIN_DATA = [
("I like green eggs", {"tags": ["N", "V", "J", "N"]}), ("I like green eggs", {"tags": ["N", "V", "J", "N"], "cats": {"preference": 1.0, "imperative": 0.0}}),
("Eat blue ham", {"tags": ["V", "J", "N"]}), ("Eat blue ham", {"tags": ["V", "J", "N"], "cats": {"preference": 0.0, "imperative": 1.0}}),
] ]
@ -318,3 +318,92 @@ def test_replace_listeners_from_config():
new_nlp.config["components"]["ner"]["model"]["tok2vec"]["@architectures"] new_nlp.config["components"]["ner"]["model"]["tok2vec"]["@architectures"]
== "spacy.Tok2VecListener.v1" == "spacy.Tok2VecListener.v1"
) )
cfg_string_multi_textcat = """
[nlp]
lang = "en"
pipeline = ["tok2vec","textcat_multilabel","tagger"]
[components]
[components.textcat_multilabel]
factory = "textcat_multilabel"
[components.textcat_multilabel.model]
@architectures = "spacy.TextCatEnsemble.v2"
nO = null
[components.textcat_multilabel.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
[components.textcat_multilabel.model.linear_model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
[components.tagger]
factory = "tagger"
[components.tagger.model]
@architectures = "spacy.Tagger.v1"
nO = null
[components.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
[components.tok2vec]
factory = "tok2vec"
[components.tok2vec.model]
@architectures = "spacy.Tok2Vec.v2"
[components.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = ${components.tok2vec.model.encode.width}
rows = [2000, 1000, 1000, 1000]
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
include_static_vectors = false
[components.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = 96
depth = 4
window_size = 1
maxout_pieces = 3
"""
def test_tok2vec_listeners_textcat():
orig_config = Config().from_str(cfg_string_multi_textcat)
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
assert nlp.pipe_names == ["tok2vec", "textcat_multilabel", "tagger"]
tagger = nlp.get_pipe("tagger")
textcat = nlp.get_pipe("textcat_multilabel")
tok2vec = nlp.get_pipe("tok2vec")
tagger_tok2vec = tagger.model.get_ref("tok2vec")
textcat_tok2vec = textcat.model.get_ref("tok2vec")
assert isinstance(tok2vec, Tok2Vec)
assert isinstance(tagger_tok2vec, Tok2VecListener)
assert isinstance(textcat_tok2vec, Tok2VecListener)
train_examples = []
for t in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
optimizer = nlp.initialize(lambda: train_examples)
for i in range(50):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
docs = list(nlp.pipe(["Eat blue ham", "I like green eggs"]))
cats0 = docs[0].cats
assert cats0["preference"] < 0.1
assert cats0["imperative"] > 0.9
cats1 = docs[1].cats
assert cats1["preference"] > 0.1
assert cats1["imperative"] < 0.9
assert([t.tag_ for t in docs[0]] == ["V", "J", "N"])
assert([t.tag_ for t in docs[1]] == ["N", "V", "J", "N"])