mirror of https://github.com/explosion/spaCy.git
Determine labels by factory name in debug data (#10079)
* Determine labels by factory name in debug data For all components, return labels for all components with the corresponding factory name rather than for only the default name. For `spancat`, return labels as a dict keyed by `spans_key`. * Refactor for typing * Add test * Use assert instead of cast, removed unneeded arg * Mark test as slow
This commit is contained in:
parent
e9c6314539
commit
a55212fca0
|
@ -14,7 +14,7 @@ from ..training.initialize import get_sourced_components
|
|||
from ..schemas import ConfigSchemaTraining
|
||||
from ..pipeline._parser_internals import nonproj
|
||||
from ..pipeline._parser_internals.nonproj import DELIMITER
|
||||
from ..pipeline import Morphologizer
|
||||
from ..pipeline import Morphologizer, SpanCategorizer
|
||||
from ..morphology import Morphology
|
||||
from ..language import Language
|
||||
from ..util import registry, resolve_dot_names
|
||||
|
@ -699,8 +699,34 @@ def _get_examples_without_label(data: Sequence[Example], label: str) -> int:
|
|||
return count
|
||||
|
||||
|
||||
def _get_labels_from_model(nlp: Language, pipe_name: str) -> Set[str]:
|
||||
if pipe_name not in nlp.pipe_names:
|
||||
return set()
|
||||
pipe = nlp.get_pipe(pipe_name)
|
||||
return set(pipe.labels)
|
||||
def _get_labels_from_model(
|
||||
nlp: Language, factory_name: str
|
||||
) -> Set[str]:
|
||||
pipe_names = [
|
||||
pipe_name
|
||||
for pipe_name in nlp.pipe_names
|
||||
if nlp.get_pipe_meta(pipe_name).factory == factory_name
|
||||
]
|
||||
labels: Set[str] = set()
|
||||
for pipe_name in pipe_names:
|
||||
pipe = nlp.get_pipe(pipe_name)
|
||||
labels.update(pipe.labels)
|
||||
return labels
|
||||
|
||||
|
||||
def _get_labels_from_spancat(
|
||||
nlp: Language
|
||||
) -> Dict[str, Set[str]]:
|
||||
pipe_names = [
|
||||
pipe_name
|
||||
for pipe_name in nlp.pipe_names
|
||||
if nlp.get_pipe_meta(pipe_name).factory == "spancat"
|
||||
]
|
||||
labels: Dict[str, Set[str]] = {}
|
||||
for pipe_name in pipe_names:
|
||||
pipe = nlp.get_pipe(pipe_name)
|
||||
assert isinstance(pipe, SpanCategorizer)
|
||||
if pipe.key not in labels:
|
||||
labels[pipe.key] = set()
|
||||
labels[pipe.key].update(pipe.labels)
|
||||
return labels
|
||||
|
|
|
@ -12,6 +12,8 @@ from spacy.cli._util import is_subpath_of, load_project_config
|
|||
from spacy.cli._util import parse_config_overrides, string_to_list
|
||||
from spacy.cli._util import substitute_project_variables
|
||||
from spacy.cli._util import validate_project_commands
|
||||
from spacy.cli.debug_data import _get_labels_from_model
|
||||
from spacy.cli.debug_data import _get_labels_from_spancat
|
||||
from spacy.cli.download import get_compatibility, get_version
|
||||
from spacy.cli.init_config import RECOMMENDATIONS, init_config, fill_config
|
||||
from spacy.cli.package import get_third_party_dependencies
|
||||
|
@ -665,3 +667,28 @@ def test_get_third_party_dependencies():
|
|||
)
|
||||
def test_is_subpath_of(parent, child, expected):
|
||||
assert is_subpath_of(parent, child) == expected
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize(
|
||||
"factory_name,pipe_name",
|
||||
[
|
||||
("ner", "ner"),
|
||||
("ner", "my_ner"),
|
||||
("spancat", "spancat"),
|
||||
("spancat", "my_spancat"),
|
||||
],
|
||||
)
|
||||
def test_get_labels_from_model(factory_name, pipe_name):
|
||||
labels = ("A", "B")
|
||||
|
||||
nlp = English()
|
||||
pipe = nlp.add_pipe(factory_name, name=pipe_name)
|
||||
for label in labels:
|
||||
pipe.add_label(label)
|
||||
nlp.initialize()
|
||||
assert nlp.get_pipe(pipe_name).labels == labels
|
||||
if factory_name == "spancat":
|
||||
assert _get_labels_from_spancat(nlp)[pipe.key] == set(labels)
|
||||
else:
|
||||
assert _get_labels_from_model(nlp, factory_name) == set(labels)
|
||||
|
|
Loading…
Reference in New Issue