diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py
index 03a98d32f..fc9229867 100644
--- a/spacy/tests/test_language.py
+++ b/spacy/tests/test_language.py
@@ -58,6 +58,29 @@ def nlp():
     return nlp
 
 
+@pytest.fixture
+def nlp_multi():
+    nlp = Language(Vocab())
+    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
+    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+    return nlp
+
+
+@pytest.fixture
+def nlp_both():
+    nlp = Language(Vocab())
+    textcat = nlp.add_pipe("textcat")
+    for label in ("POSITIVE", "NEGATIVE"):
+        textcat.add_label(label)
+    textcat_multilabel = nlp.add_pipe("textcat_multilabel")
+    for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+    return nlp
+
+
 def test_language_update(nlp):
     text = "hello world"
     annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
@@ -91,6 +114,9 @@ def test_language_evaluate(nlp):
     example = Example.from_dict(doc, annots)
     scores = nlp.evaluate([example])
     assert scores["speed"] > 0
+    assert scores["cats_f_per_type"].get("POSITIVE") is not None
+    assert scores["cats_f_per_type"].get("NEGATIVE") is not None
+    assert scores["cats_f_per_type"].get("BUG") is None
 
     # test with generator
     scores = nlp.evaluate(eg for eg in [example])
@@ -126,6 +152,35 @@ def test_evaluate_no_pipe(nlp):
     nlp.evaluate([Example.from_dict(doc, annots)])
 
 
+def test_evaluate_textcat(nlp_multi):
+    """Test that evaluate works with a multilabel textcat pipe."""
+    text = "hello world"
+    annots = {"doc_annotation": {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}}
+    doc = Doc(nlp_multi.vocab, words=text.split(" "))
+    example = Example.from_dict(doc, annots)
+    scores = nlp_multi.evaluate([example])
+    assert scores["cats_f_per_type"].get("FEATURE") is not None
+    assert scores["cats_f_per_type"].get("QUESTION") is not None
+    assert scores["cats_f_per_type"].get("REQUEST") is not None
+    assert scores["cats_f_per_type"].get("BUG") is not None
+    assert scores["cats_f_per_type"].get("POSITIVE") is None
+    assert scores["cats_f_per_type"].get("NEGATIVE") is None
+
+
+def test_evaluate_both(nlp_both):
+    """Test that evaluate works with two textcat pipes."""
+    text = "hello world"
+    annots = {"doc_annotation": {"cats": {"FEATURE": 1.0, "QUESTION": 1.0, "POSITIVE": 1.0, "NEGATIVE": 0.0}}}
+    doc = Doc(nlp_both.vocab, words=text.split(" "))
+    example = Example.from_dict(doc, annots)
+    scores = nlp_both.evaluate([example])
+    assert scores["cats_f_per_type"].get("FEATURE") is not None
+    assert scores["cats_f_per_type"].get("QUESTION") is not None
+    assert scores["cats_f_per_type"].get("BUG") is not None
+    assert scores["cats_f_per_type"].get("POSITIVE") is not None
+    assert scores["cats_f_per_type"].get("NEGATIVE") is not None
+
+
 def vector_modification_pipe(doc):
     doc.vector += 1
     return doc