diff --git a/spacy/tests/test_textcat.py b/spacy/tests/test_textcat.py new file mode 100644 index 000000000..20f21131a --- /dev/null +++ b/spacy/tests/test_textcat.py @@ -0,0 +1,41 @@ +import random + +from ..pipeline import TextCategorizer +from ..lang.en import English +from ..vocab import Vocab +from ..tokens import Doc +from ..gold import GoldParse + + +def test_textcat_learns_multilabel(): + docs = [] + nlp = English() + vocab = nlp.vocab + letters = ['a', 'b', 'c'] + for w1 in letters: + for w2 in letters: + cats = {letter: float(w2==letter) for letter in letters} + docs.append((Doc(vocab, words=['d']*3 + [w1, w2] + ['d']*3), cats)) + random.shuffle(docs) + model = TextCategorizer(vocab, width=8) + for letter in letters: + model.add_label(letter) + optimizer = model.begin_training() + for i in range(20): + losses = {} + Ys = [GoldParse(doc, cats=cats) for doc, cats in docs] + Xs = [doc for doc, cats in docs] + model.update(Xs, Ys, sgd=optimizer, losses=losses) + random.shuffle(docs) + for w1 in letters: + for w2 in letters: + doc = Doc(vocab, words=['d']*3 + [w1, w2] + ['d']*3) + truth = {letter: w2==letter for letter in letters} + model(doc) + for cat, score in doc.cats.items(): + print(doc, cat, score) + if not truth[cat]: + assert score < 0.5 + else: + assert score > 0.5 +