mirror of https://github.com/explosion/spaCy.git
Determine labels by factory name in debug data (#10079)
* Determine labels by factory name in debug data For all components, return labels for all components with the corresponding factory name rather than for only the default name. For `spancat`, return labels as a dict keyed by `spans_key`. * Refactor for typing * Add test * Use assert instead of cast, removed unneeded arg * Mark test as slow
This commit is contained in:
parent
e9c6314539
commit
a55212fca0
|
@ -14,7 +14,7 @@ from ..training.initialize import get_sourced_components
|
||||||
from ..schemas import ConfigSchemaTraining
|
from ..schemas import ConfigSchemaTraining
|
||||||
from ..pipeline._parser_internals import nonproj
|
from ..pipeline._parser_internals import nonproj
|
||||||
from ..pipeline._parser_internals.nonproj import DELIMITER
|
from ..pipeline._parser_internals.nonproj import DELIMITER
|
||||||
from ..pipeline import Morphologizer
|
from ..pipeline import Morphologizer, SpanCategorizer
|
||||||
from ..morphology import Morphology
|
from ..morphology import Morphology
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from ..util import registry, resolve_dot_names
|
from ..util import registry, resolve_dot_names
|
||||||
|
@ -699,8 +699,34 @@ def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
def _get_labels_from_model(nlp: Language, pipe_name: str) -> Set[str]:
|
def _get_labels_from_model(
|
||||||
if pipe_name not in nlp.pipe_names:
|
nlp: Language, factory_name: str
|
||||||
return set()
|
) -> Set[str]:
|
||||||
|
pipe_names = [
|
||||||
|
pipe_name
|
||||||
|
for pipe_name in nlp.pipe_names
|
||||||
|
if nlp.get_pipe_meta(pipe_name).factory == factory_name
|
||||||
|
]
|
||||||
|
labels: Set[str] = set()
|
||||||
|
for pipe_name in pipe_names:
|
||||||
pipe = nlp.get_pipe(pipe_name)
|
pipe = nlp.get_pipe(pipe_name)
|
||||||
return set(pipe.labels)
|
labels.update(pipe.labels)
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
def _get_labels_from_spancat(
|
||||||
|
nlp: Language
|
||||||
|
) -> Dict[str, Set[str]]:
|
||||||
|
pipe_names = [
|
||||||
|
pipe_name
|
||||||
|
for pipe_name in nlp.pipe_names
|
||||||
|
if nlp.get_pipe_meta(pipe_name).factory == "spancat"
|
||||||
|
]
|
||||||
|
labels: Dict[str, Set[str]] = {}
|
||||||
|
for pipe_name in pipe_names:
|
||||||
|
pipe = nlp.get_pipe(pipe_name)
|
||||||
|
assert isinstance(pipe, SpanCategorizer)
|
||||||
|
if pipe.key not in labels:
|
||||||
|
labels[pipe.key] = set()
|
||||||
|
labels[pipe.key].update(pipe.labels)
|
||||||
|
return labels
|
||||||
|
|
|
@ -12,6 +12,8 @@ from spacy.cli._util import is_subpath_of, load_project_config
|
||||||
from spacy.cli._util import parse_config_overrides, string_to_list
|
from spacy.cli._util import parse_config_overrides, string_to_list
|
||||||
from spacy.cli._util import substitute_project_variables
|
from spacy.cli._util import substitute_project_variables
|
||||||
from spacy.cli._util import validate_project_commands
|
from spacy.cli._util import validate_project_commands
|
||||||
|
from spacy.cli.debug_data import _get_labels_from_model
|
||||||
|
from spacy.cli.debug_data import _get_labels_from_spancat
|
||||||
from spacy.cli.download import get_compatibility, get_version
|
from spacy.cli.download import get_compatibility, get_version
|
||||||
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
|
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
|
||||||
from spacy.cli.package import get_third_party_dependencies
|
from spacy.cli.package import get_third_party_dependencies
|
||||||
|
@ -665,3 +667,28 @@ def test_get_third_party_dependencies():
|
||||||
)
|
)
|
||||||
def test_is_subpath_of(parent, child, expected):
|
def test_is_subpath_of(parent, child, expected):
|
||||||
assert is_subpath_of(parent, child) == expected
|
assert is_subpath_of(parent, child) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"factory_name,pipe_name",
|
||||||
|
[
|
||||||
|
("ner", "ner"),
|
||||||
|
("ner", "my_ner"),
|
||||||
|
("spancat", "spancat"),
|
||||||
|
("spancat", "my_spancat"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_labels_from_model(factory_name, pipe_name):
|
||||||
|
labels = ("A", "B")
|
||||||
|
|
||||||
|
nlp = English()
|
||||||
|
pipe = nlp.add_pipe(factory_name, name=pipe_name)
|
||||||
|
for label in labels:
|
||||||
|
pipe.add_label(label)
|
||||||
|
nlp.initialize()
|
||||||
|
assert nlp.get_pipe(pipe_name).labels == labels
|
||||||
|
if factory_name == "spancat":
|
||||||
|
assert _get_labels_from_spancat(nlp)[pipe.key] == set(labels)
|
||||||
|
else:
|
||||||
|
assert _get_labels_from_model(nlp, factory_name) == set(labels)
|
||||||
|
|
Loading…
Reference in New Issue