Determine labels by factory name in debug data (#10079)

* Determine labels by factory name in debug data

For all components, return labels for all components with the
corresponding factory name rather than for only the default name.

For `spancat`, return labels as a dict keyed by `spans_key`.

* Refactor for typing

* Add test

* Use assert instead of cast, removed unneeded arg

* Mark test as slow
This commit is contained in:
Adriane Boyd 2022-01-20 11:42:52 +01:00 committed by GitHub
parent e9c6314539
commit a55212fca0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 6 deletions

View File

@ -14,7 +14,7 @@ from ..training.initialize import get_sourced_components
from ..schemas import ConfigSchemaTraining
from ..pipeline._parser_internals import nonproj
from ..pipeline._parser_internals.nonproj import DELIMITER
from ..pipeline import Morphologizer
from ..pipeline import Morphologizer, SpanCategorizer
from ..morphology import Morphology
from ..language import Language
from ..util import registry, resolve_dot_names
@ -699,8 +699,34 @@ def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
return count
def _get_labels_from_model(nlp: Language, pipe_name: str) -> Set[str]:
if pipe_name not in nlp.pipe_names:
return set()
def _get_labels_from_model(
nlp: Language, factory_name: str
) -> Set[str]:
pipe_names = [
pipe_name
for pipe_name in nlp.pipe_names
if nlp.get_pipe_meta(pipe_name).factory == factory_name
]
labels: Set[str] = set()
for pipe_name in pipe_names:
pipe = nlp.get_pipe(pipe_name)
return set(pipe.labels)
labels.update(pipe.labels)
return labels
def _get_labels_from_spancat(
nlp: Language
) -> Dict[str, Set[str]]:
pipe_names = [
pipe_name
for pipe_name in nlp.pipe_names
if nlp.get_pipe_meta(pipe_name).factory == "spancat"
]
labels: Dict[str, Set[str]] = {}
for pipe_name in pipe_names:
pipe = nlp.get_pipe(pipe_name)
assert isinstance(pipe, SpanCategorizer)
if pipe.key not in labels:
labels[pipe.key] = set()
labels[pipe.key].update(pipe.labels)
return labels

View File

@ -12,6 +12,8 @@ from spacy.cli._util import is_subpath_of, load_project_config
from spacy.cli._util import parse_config_overrides, string_to_list
from spacy.cli._util import substitute_project_variables
from spacy.cli._util import validate_project_commands
from spacy.cli.debug_data import _get_labels_from_model
from spacy.cli.debug_data import _get_labels_from_spancat
from spacy.cli.download import get_compatibility, get_version
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
from spacy.cli.package import get_third_party_dependencies
@ -665,3 +667,28 @@ def test_get_third_party_dependencies():
)
def test_is_subpath_of(parent, child, expected):
assert is_subpath_of(parent, child) == expected
@pytest.mark.slow
@pytest.mark.parametrize(
"factory_name,pipe_name",
[
("ner", "ner"),
("ner", "my_ner"),
("spancat", "spancat"),
("spancat", "my_spancat"),
],
)
def test_get_labels_from_model(factory_name, pipe_name):
labels = ("A", "B")
nlp = English()
pipe = nlp.add_pipe(factory_name, name=pipe_name)
for label in labels:
pipe.add_label(label)
nlp.initialize()
assert nlp.get_pipe(pipe_name).labels == labels
if factory_name == "spancat":
assert _get_labels_from_spancat(nlp)[pipe.key] == set(labels)
else:
assert _get_labels_from_model(nlp, factory_name) == set(labels)