Add tests for Pipe.label_data

This commit is contained in:
Ines Montani 2020-10-03 19:12:46 +02:00
parent 80603f0fa5
commit c2401fca41
1 changed files with 32 additions and 1 deletions

View File

@ -1,6 +1,6 @@
import pytest import pytest
from spacy.language import Language from spacy.language import Language
from spacy.util import SimpleFrozenList from spacy.util import SimpleFrozenList, get_arg_names
@pytest.fixture @pytest.fixture
@ -346,3 +346,34 @@ def test_pipe_methods_frozen():
nlp.components.sort() nlp.components.sort()
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
nlp.component_names.clear() nlp.component_names.clear()
@pytest.mark.parametrize(
"pipe",
[
"tagger",
"parser",
"ner",
"textcat",
pytest.param("morphologizer", marks=pytest.mark.xfail),
],
)
def test_pipe_label_data_exports_labels(pipe):
nlp = Language()
pipe = nlp.add_pipe(pipe)
# Make sure pipe has pipe labels
assert getattr(pipe, "label_data", None) is not None
# Make sure pipe can be initialized with labels
initialize = getattr(pipe, "initialize", None)
assert initialize is not None
assert "labels" in get_arg_names(initialize)
@pytest.mark.parametrize("pipe", ["senter", "entity_linker"])
def test_pipe_label_data_no_labels(pipe):
nlp = Language()
pipe = nlp.add_pipe(pipe)
assert getattr(pipe, "label_data", None) is None
initialize = getattr(pipe, "initialize", None)
if initialize is not None:
assert "labels" not in get_arg_names(initialize)