Determine labels by factory name in debug data (#10079)

* Determine labels by factory name in debug data

For all components, return labels for all components with the
corresponding factory name rather than for only the default name.

For `spancat`, return labels as a dict keyed by `spans_key`.

* Refactor for typing

* Add test

* Use assert instead of cast, removed unneeded arg

* Mark test as slow
This commit is contained in:
Adriane Boyd 2022-01-20 11:42:52 +01:00 committed by GitHub
parent e9c6314539
commit a55212fca0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 6 deletions

View File

@ -14,7 +14,7 @@ from ..training.initialize import get_sourced_components
from ..schemas import ConfigSchemaTraining from ..schemas import ConfigSchemaTraining
from ..pipeline._parser_internals import nonproj from ..pipeline._parser_internals import nonproj
from ..pipeline._parser_internals.nonproj import DELIMITER from ..pipeline._parser_internals.nonproj import DELIMITER
from ..pipeline import Morphologizer from ..pipeline import Morphologizer, SpanCategorizer
from ..morphology import Morphology from ..morphology import Morphology
from ..language import Language from ..language import Language
from ..util import registry, resolve_dot_names from ..util import registry, resolve_dot_names
@ -699,8 +699,34 @@ def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
return count return count
def _get_labels_from_model(nlp: Language, pipe_name: str) -> Set[str]: def _get_labels_from_model(
if pipe_name not in nlp.pipe_names: nlp: Language, factory_name: str
return set() ) -> Set[str]:
pipe = nlp.get_pipe(pipe_name) pipe_names = [
return set(pipe.labels) pipe_name
for pipe_name in nlp.pipe_names
if nlp.get_pipe_meta(pipe_name).factory == factory_name
]
labels: Set[str] = set()
for pipe_name in pipe_names:
pipe = nlp.get_pipe(pipe_name)
labels.update(pipe.labels)
return labels
def _get_labels_from_spancat(
nlp: Language
) -> Dict[str, Set[str]]:
pipe_names = [
pipe_name
for pipe_name in nlp.pipe_names
if nlp.get_pipe_meta(pipe_name).factory == "spancat"
]
labels: Dict[str, Set[str]] = {}
for pipe_name in pipe_names:
pipe = nlp.get_pipe(pipe_name)
assert isinstance(pipe, SpanCategorizer)
if pipe.key not in labels:
labels[pipe.key] = set()
labels[pipe.key].update(pipe.labels)
return labels

View File

@ -12,6 +12,8 @@ from spacy.cli._util import is_subpath_of, load_project_config
from spacy.cli._util import parse_config_overrides, string_to_list from spacy.cli._util import parse_config_overrides, string_to_list
from spacy.cli._util import substitute_project_variables from spacy.cli._util import substitute_project_variables
from spacy.cli._util import validate_project_commands from spacy.cli._util import validate_project_commands
from spacy.cli.debug_data import _get_labels_from_model
from spacy.cli.debug_data import _get_labels_from_spancat
from spacy.cli.download import get_compatibility, get_version from spacy.cli.download import get_compatibility, get_version
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
from spacy.cli.package import get_third_party_dependencies from spacy.cli.package import get_third_party_dependencies
@ -665,3 +667,28 @@ def test_get_third_party_dependencies():
) )
def test_is_subpath_of(parent, child, expected): def test_is_subpath_of(parent, child, expected):
assert is_subpath_of(parent, child) == expected assert is_subpath_of(parent, child) == expected
@pytest.mark.slow
@pytest.mark.parametrize(
"factory_name,pipe_name",
[
("ner", "ner"),
("ner", "my_ner"),
("spancat", "spancat"),
("spancat", "my_spancat"),
],
)
def test_get_labels_from_model(factory_name, pipe_name):
labels = ("A", "B")
nlp = English()
pipe = nlp.add_pipe(factory_name, name=pipe_name)
for label in labels:
pipe.add_label(label)
nlp.initialize()
assert nlp.get_pipe(pipe_name).labels == labels
if factory_name == "spancat":
assert _get_labels_from_spancat(nlp)[pipe.key] == set(labels)
else:
assert _get_labels_from_model(nlp, factory_name) == set(labels)