diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py
index 688b07a9b..b9831fe0c 100644
--- a/spacy/cli/debug_data.py
+++ b/spacy/cli/debug_data.py
@@ -14,7 +14,7 @@ from ..training.initialize import get_sourced_components
 from ..schemas import ConfigSchemaTraining
 from ..pipeline._parser_internals import nonproj
 from ..pipeline._parser_internals.nonproj import DELIMITER
-from ..pipeline import Morphologizer
+from ..pipeline import Morphologizer, SpanCategorizer
 from ..morphology import Morphology
 from ..language import Language
 from ..util import registry, resolve_dot_names
@@ -699,8 +699,32 @@ def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
     return count
 
 
-def _get_labels_from_model(nlp: Language, pipe_name: str) -> Set[str]:
-    if pipe_name not in nlp.pipe_names:
-        return set()
-    pipe = nlp.get_pipe(pipe_name)
-    return set(pipe.labels)
+def _get_labels_from_model(nlp: Language, factory_name: str) -> Set[str]:
+    """Return the union of labels of all pipes built by `factory_name`.
+
+    Pipes are matched on the factory recorded in their meta rather than on
+    their name, so pipes added under a custom name are included. The result
+    is empty when no pipe in `nlp` matches.
+    """
+    collected: Set[str] = set()
+    for name in nlp.pipe_names:
+        if nlp.get_pipe_meta(name).factory != factory_name:
+            continue
+        collected.update(nlp.get_pipe(name).labels)
+    return collected
+
+
+def _get_labels_from_spancat(nlp: Language) -> Dict[str, Set[str]]:
+    """Return the labels of all "spancat" pipes, grouped by `pipe.key`.
+
+    Pipes sharing the same `key` have their labels merged into one set.
+    """
+    by_key: Dict[str, Set[str]] = {}
+    for name in nlp.pipe_names:
+        if nlp.get_pipe_meta(name).factory != "spancat":
+            continue
+        spancat = nlp.get_pipe(name)
+        # Narrow the type so that `key` is known to exist on the pipe.
+        assert isinstance(spancat, SpanCategorizer)
+        by_key.setdefault(spancat.key, set()).update(spancat.labels)
+    return by_key
diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py
index b0862eab6..253469909 100644
--- a/spacy/tests/test_cli.py
+++ b/spacy/tests/test_cli.py
@@ -12,6 +12,8 @@ from spacy.cli._util import is_subpath_of, load_project_config
 from spacy.cli._util import parse_config_overrides, string_to_list
 from spacy.cli._util import substitute_project_variables
 from spacy.cli._util import validate_project_commands
+from spacy.cli.debug_data import _get_labels_from_model
+from spacy.cli.debug_data import _get_labels_from_spancat
 from spacy.cli.download import get_compatibility, get_version
 from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
 from spacy.cli.package import get_third_party_dependencies
@@ -665,3 +667,31 @@ def test_get_third_party_dependencies():
 )
 def test_is_subpath_of(parent, child, expected):
     assert is_subpath_of(parent, child) == expected
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "factory_name,pipe_name",
+    [
+        ("ner", "ner"),
+        ("ner", "my_ner"),
+        ("spancat", "spancat"),
+        ("spancat", "my_spancat"),
+    ],
+)
+def test_get_labels_from_model(factory_name, pipe_name):
+    """Labels are found via the factory, whatever the pipe is named."""
+    labels = ("A", "B")
+
+    nlp = English()
+    pipe = nlp.add_pipe(factory_name, name=pipe_name)
+    for label in labels:
+        pipe.add_label(label)
+    nlp.initialize()
+    assert nlp.get_pipe(pipe_name).labels == labels
+    expected = set(labels)
+    if factory_name == "spancat":
+        found = _get_labels_from_spancat(nlp)[pipe.key]
+    else:
+        found = _get_labels_from_model(nlp, factory_name)
+    assert found == expected