mirror of https://github.com/explosion/spaCy.git
Ensemble textcat with listener (#8012)
* add unit test for two listeners, with a textcat ensemble in the middle
* return zero gradients instead of None in accumulate_gradient
parent ff91e6dac7
commit fff662e41f
@@ -173,6 +173,7 @@ class Tok2Vec(TrainablePipe):
             for i in range(len(one_d_tokvecs)):
                 d_tokvecs[i] += one_d_tokvecs[i]
                 losses[self.name] += float((one_d_tokvecs[i] ** 2).sum())
+            return [self.model.ops.alloc2f(*t2v.shape) for t2v in tokvecs]
 
         def backprop(one_d_tokvecs):
             """Callback to actually do the backprop. Passed to last listener."""
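Note: the following is a minimal, self-contained sketch (not spaCy's actual pipeline code) of why the added return line matters. A non-final listener's backprop callback is expected to hand back one gradient array per doc, matching the shape of the tok2vec output it received; returning zero-filled arrays via thinc's Ops.alloc2f avoids special-casing None downstream. The toy shapes and the name accumulate_gradient_sketch are illustrative assumptions.

# Illustrative sketch only; Ops.alloc2f returns a zero-filled float32 array.
from thinc.api import NumpyOps

ops = NumpyOps()
# Pretend the shared tok2vec produced vectors for two docs (4 and 3 tokens, width 96).
tokvecs = [ops.alloc2f(4, 96), ops.alloc2f(3, 96)]
d_tokvecs = [ops.alloc2f(*t2v.shape) for t2v in tokvecs]
losses = {"tok2vec": 0.0}

def accumulate_gradient_sketch(one_d_tokvecs):
    # Sum the gradients from one listener into the shared buffer ...
    for i in range(len(one_d_tokvecs)):
        d_tokvecs[i] += one_d_tokvecs[i]
        losses["tok2vec"] += float((one_d_tokvecs[i] ** 2).sum())
    # ... and hand back zero gradients of the same shapes instead of None.
    return [ops.alloc2f(*t2v.shape) for t2v in tokvecs]

grads = accumulate_gradient_sketch([ops.alloc2f(4, 96), ops.alloc2f(3, 96)])
assert all(float(g.sum()) == 0.0 for g in grads)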
@@ -129,8 +129,8 @@ cfg_string = """
     """
 
 TRAIN_DATA = [
-    ("I like green eggs", {"tags": ["N", "V", "J", "N"]}),
-    ("Eat blue ham", {"tags": ["V", "J", "N"]}),
+    ("I like green eggs", {"tags": ["N", "V", "J", "N"], "cats": {"preference": 1.0, "imperative": 0.0}}),
+    ("Eat blue ham", {"tags": ["V", "J", "N"], "cats": {"preference": 0.0, "imperative": 1.0}}),
 ]
 
 
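Note: a brief, hedged sketch of what the extended TRAIN_DATA entries carry. Example.from_dict accepts both the "tags" and the "cats" annotations, so the same two examples can drive the tagger listener and the textcat ensemble listener. The blank "en" pipeline below is an assumption for illustration only.

# Illustrative only: build one Example from the second TRAIN_DATA entry.
import spacy
from spacy.training import Example

nlp = spacy.blank("en")
text, annots = "Eat blue ham", {"tags": ["V", "J", "N"], "cats": {"preference": 0.0, "imperative": 1.0}}
example = Example.from_dict(nlp.make_doc(text), annots)
# The reference doc now carries gold data for both components.
assert [t.tag_ for t in example.reference] == ["V", "J", "N"]
assert example.reference.cats == {"preference": 0.0, "imperative": 1.0}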
@@ -318,3 +318,92 @@ def test_replace_listeners_from_config():
         new_nlp.config["components"]["ner"]["model"]["tok2vec"]["@architectures"]
         == "spacy.Tok2VecListener.v1"
     )
+
+
+cfg_string_multi_textcat = """
+    [nlp]
+    lang = "en"
+    pipeline = ["tok2vec","textcat_multilabel","tagger"]
+
+    [components]
+
+    [components.textcat_multilabel]
+    factory = "textcat_multilabel"
+
+    [components.textcat_multilabel.model]
+    @architectures = "spacy.TextCatEnsemble.v2"
+    nO = null
+
+    [components.textcat_multilabel.model.tok2vec]
+    @architectures = "spacy.Tok2VecListener.v1"
+    width = ${components.tok2vec.model.encode.width}
+
+    [components.textcat_multilabel.model.linear_model]
+    @architectures = "spacy.TextCatBOW.v1"
+    exclusive_classes = false
+    ngram_size = 1
+    no_output_layer = false
+
+    [components.tagger]
+    factory = "tagger"
+
+    [components.tagger.model]
+    @architectures = "spacy.Tagger.v1"
+    nO = null
+
+    [components.tagger.model.tok2vec]
+    @architectures = "spacy.Tok2VecListener.v1"
+    width = ${components.tok2vec.model.encode.width}
+
+    [components.tok2vec]
+    factory = "tok2vec"
+
+    [components.tok2vec.model]
+    @architectures = "spacy.Tok2Vec.v2"
+
+    [components.tok2vec.model.embed]
+    @architectures = "spacy.MultiHashEmbed.v1"
+    width = ${components.tok2vec.model.encode.width}
+    rows = [2000, 1000, 1000, 1000]
+    attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+    include_static_vectors = false
+
+    [components.tok2vec.model.encode]
+    @architectures = "spacy.MaxoutWindowEncoder.v2"
+    width = 96
+    depth = 4
+    window_size = 1
+    maxout_pieces = 3
+    """
+
+
+def test_tok2vec_listeners_textcat():
+    orig_config = Config().from_str(cfg_string_multi_textcat)
+    nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
+    assert nlp.pipe_names == ["tok2vec", "textcat_multilabel", "tagger"]
+    tagger = nlp.get_pipe("tagger")
+    textcat = nlp.get_pipe("textcat_multilabel")
+    tok2vec = nlp.get_pipe("tok2vec")
+    tagger_tok2vec = tagger.model.get_ref("tok2vec")
+    textcat_tok2vec = textcat.model.get_ref("tok2vec")
+    assert isinstance(tok2vec, Tok2Vec)
+    assert isinstance(tagger_tok2vec, Tok2VecListener)
+    assert isinstance(textcat_tok2vec, Tok2VecListener)
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+
+    optimizer = nlp.initialize(lambda: train_examples)
+    for i in range(50):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+    docs = list(nlp.pipe(["Eat blue ham", "I like green eggs"]))
+    cats0 = docs[0].cats
+    assert cats0["preference"] < 0.1
+    assert cats0["imperative"] > 0.9
+    cats1 = docs[1].cats
+    assert cats1["preference"] > 0.1
+    assert cats1["imperative"] < 0.9
+    assert([t.tag_ for t in docs[0]] == ["V", "J", "N"])
+    assert([t.tag_ for t in docs[1]] == ["N", "V", "J", "N"])
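Note: both listeners in the config above pick up the shared encoder width through config interpolation of ${components.tok2vec.model.encode.width}. A small hedged check, assuming cfg_string_multi_textcat from the diff is in scope:

# Interpolating the config resolves the width variable to 96 for the tagger
# listener and the textcat ensemble listener alike.
from thinc.api import Config

cfg = Config().from_str(cfg_string_multi_textcat)  # interpolates by default
assert cfg["components"]["tok2vec"]["model"]["encode"]["width"] == 96
assert cfg["components"]["tagger"]["model"]["tok2vec"]["width"] == 96
assert cfg["components"]["textcat_multilabel"]["model"]["tok2vec"]["width"] == 96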
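Note: a possible follow-up usage, not part of this commit, sketched under the assumption that cfg_string_multi_textcat and TRAIN_DATA from the diff are in scope. Mirroring the replace_listeners tests referenced in the hunk header, nlp.replace_listeners can detach a listening component from the shared tok2vec by copying the layer inline; shown here for the tagger, and in principle applicable to the textcat ensemble as well.

# Hedged sketch: detach the tagger from the shared tok2vec after initialization.
from thinc.api import Config
from spacy import util
from spacy.training import Example

nlp = util.load_model_from_config(
    Config().from_str(cfg_string_multi_textcat), auto_fill=True, validate=True
)
examples = [Example.from_dict(nlp.make_doc(t[0]), t[1]) for t in TRAIN_DATA]
nlp.initialize(lambda: examples)
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
# The tagger's config should now reference an inline tok2vec architecture
# instead of spacy.Tok2VecListener.v1.
print(nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"])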