mirror of https://github.com/explosion/spaCy.git
Ensemble textcat with listener (#8012)
* add unit test for two listeners, with a textcat ensemble in the middle * return zero gradients instead of None in accumulate_gradient
This commit is contained in:
parent
ff91e6dac7
commit
fff662e41f
|
@ -173,6 +173,7 @@ class Tok2Vec(TrainablePipe):
|
|||
for i in range(len(one_d_tokvecs)):
|
||||
d_tokvecs[i] += one_d_tokvecs[i]
|
||||
losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
|
||||
return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
|
||||
|
||||
def backprop(one_d_tokvecs):
|
||||
"""Callback to actually do the backprop. Passed to last listener."""
|
||||
|
|
|
@ -129,8 +129,8 @@ cfg_string = """
|
|||
"""
|
||||
|
||||
TRAIN_DATA = [
|
||||
("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
|
||||
("Eat blue ham", {"tags": ["V", "J", "N"]}),
|
||||
("I like green eggs", {"tags": ["N", "V", "J", "N"], "cats": {"preference": 1.0, "imperative": 0.0}}),
|
||||
("Eat blue ham", {"tags": ["V", "J", "N"], "cats": {"preference": 0.0, "imperative": 1.0}}),
|
||||
]
|
||||
|
||||
|
||||
|
@ -318,3 +318,92 @@ def test_replace_listeners_from_config():
|
|||
new_nlp.config["components"]["ner"]["model"]["tok2vec"]["@architectures"]
|
||||
== "spacy.Tok2VecListener.v1"
|
||||
)
|
||||
|
||||
|
||||
cfg_string_multi_textcat = """
|
||||
[nlp]
|
||||
lang = "en"
|
||||
pipeline = ["tok2vec","textcat_multilabel","tagger"]
|
||||
|
||||
[components]
|
||||
|
||||
[components.textcat_multilabel]
|
||||
factory = "textcat_multilabel"
|
||||
|
||||
[components.textcat_multilabel.model]
|
||||
@architectures = "spacy.TextCatEnsemble.v2"
|
||||
nO = null
|
||||
|
||||
[components.textcat_multilabel.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.encode.width}
|
||||
|
||||
[components.textcat_multilabel.model.linear_model]
|
||||
@architectures = "spacy.TextCatBOW.v1"
|
||||
exclusive_classes = false
|
||||
ngram_size = 1
|
||||
no_output_layer = false
|
||||
|
||||
[components.tagger]
|
||||
factory = "tagger"
|
||||
|
||||
[components.tagger.model]
|
||||
@architectures = "spacy.Tagger.v1"
|
||||
nO = null
|
||||
|
||||
[components.tagger.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.encode.width}
|
||||
|
||||
[components.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
||||
[components.tok2vec.model]
|
||||
@architectures = "spacy.Tok2Vec.v2"
|
||||
|
||||
[components.tok2vec.model.embed]
|
||||
@architectures = "spacy.MultiHashEmbed.v1"
|
||||
width = ${components.tok2vec.model.encode.width}
|
||||
rows = [2000, 1000, 1000, 1000]
|
||||
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
|
||||
include_static_vectors = false
|
||||
|
||||
[components.tok2vec.model.encode]
|
||||
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||
width = 96
|
||||
depth = 4
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
"""
|
||||
|
||||
|
||||
def test_tok2vec_listeners_textcat():
|
||||
orig_config = Config().from_str(cfg_string_multi_textcat)
|
||||
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||
assert nlp.pipe_names == ["tok2vec", "textcat_multilabel", "tagger"]
|
||||
tagger = nlp.get_pipe("tagger")
|
||||
textcat = nlp.get_pipe("textcat_multilabel")
|
||||
tok2vec = nlp.get_pipe("tok2vec")
|
||||
tagger_tok2vec = tagger.model.get_ref("tok2vec")
|
||||
textcat_tok2vec = textcat.model.get_ref("tok2vec")
|
||||
assert isinstance(tok2vec, Tok2Vec)
|
||||
assert isinstance(tagger_tok2vec, Tok2VecListener)
|
||||
assert isinstance(textcat_tok2vec, Tok2VecListener)
|
||||
train_examples = []
|
||||
for t in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||
|
||||
optimizer = nlp.initialize(lambda: train_examples)
|
||||
for i in range(50):
|
||||
losses = {}
|
||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||
|
||||
docs = list(nlp.pipe(["Eat blue ham", "I like green eggs"]))
|
||||
cats0 = docs[0].cats
|
||||
assert cats0["preference"] < 0.1
|
||||
assert cats0["imperative"] > 0.9
|
||||
cats1 = docs[1].cats
|
||||
assert cats1["preference"] > 0.1
|
||||
assert cats1["imperative"] < 0.9
|
||||
assert([t.tag_ for t in docs[0]] == ["V", "J", "N"])
|
||||
assert([t.tag_ for t in docs[1]] == ["N", "V", "J", "N"])
|
||||
|
|
Loading…
Reference in New Issue