Integrate and simplify pipe analysis

This commit is contained in:
Ines Montani 2020-07-31 18:34:35 +02:00
parent 2d955fbf98
commit 30a76fcf6f
3 changed files with 42 additions and 39 deletions

View File

@ -18,7 +18,7 @@ from timeit import default_timer as timer
from .tokens.underscore import Underscore
from .vocab import Vocab, create_vocab
from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs
from .pipe_analysis import validate_attrs, print_summary
from .gold import Example
from .scorer import Scorer
from .util import create_default_optimizer, registry
@ -37,8 +37,6 @@ from . import util
from . import about
# TODO: integrate pipeline analysis
ENABLE_PIPELINE_ANALYSIS = False
# This is the base config with all settings (training etc.)
DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
DEFAULT_CONFIG = Config().from_disk(DEFAULT_CONFIG_PATH)
@ -522,6 +520,24 @@ class Language:
return add_component(func)
return add_component
def analyze_pipes(
self,
*,
keys: List[str] = ["assigns", "requires", "scores", "retokenizes"],
pretty: bool = True,
no_print: bool = False,
) -> Optional[Dict[str, Any]]:
"""Analyze the current pipeline components, print a summary of what
they assign or require and check that all requirements are met.
keys (List[str]): The meta values to display in the table. Corresponds
to values in FactoryMeta, defined by @Language.factory decorator.
pretty (bool): Pretty-print the results with colors and icons.
no_print (bool): Don't print anything and return structured dict instead.
RETURNS (dict): The data, if no_print is set to True.
"""
return print_summary(self, keys=keys, pretty=pretty, no_print=no_print)
def get_pipe(self, name: str) -> Callable[[Doc], Doc]:
"""Get a pipeline component for a given component name.
@ -666,8 +682,6 @@ class Language:
pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name)
self.pipeline.insert(pipe_index, (name, pipe_component))
if ENABLE_PIPELINE_ANALYSIS:
analyze_pipes(self, name, pipe_index)
return pipe_component
def _get_pipe_index(
@ -758,8 +772,6 @@ class Language:
self.add_pipe(factory_name, name=name)
else:
self.add_pipe(factory_name, name=name, before=pipe_index)
if ENABLE_PIPELINE_ANALYSIS:
analyze_all_pipes(self)
def rename_pipe(self, old_name: str, new_name: str) -> None:
"""Rename a pipeline component.
@ -793,8 +805,6 @@ class Language:
# because factory may be used for something else
self._pipe_meta.pop(name)
self._pipe_configs.pop(name)
if ENABLE_PIPELINE_ANALYSIS:
analyze_all_pipes(self)
return removed
def __call__(

View File

@ -42,19 +42,6 @@ def analyze_pipes(
return problems
def analyze_all_pipes(nlp: "Language", warn: bool = True) -> Dict[str, List[str]]:
"""Analyze all pipes in the pipeline in order.
nlp (Language): The current nlp object.
warn (bool): Show user warning if problem is found.
RETURNS (Dict[str, List[str]]): The problems found, keyed by component name.
"""
problems = {}
for i, name in enumerate(nlp.pipe_names):
problems[name] = analyze_pipes(nlp, name, i, warn=warn)
return problems
def validate_attrs(values: Iterable[str]) -> Iterable[str]:
"""Validate component attributes provided to "assigns", "requires" etc.
Raises error for invalid attributes and formatting. Doesn't check if
@ -133,27 +120,35 @@ def get_requires_for_attr(nlp: "Language", attr: str) -> List[str]:
def print_summary(
nlp: "Language", pretty: bool = True, no_print: bool = False
nlp: "Language",
*,
keys: List[str] = ["requires", "assigns", "scores", "retokenizes"],
pretty: bool = True,
no_print: bool = False,
) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]:
"""Print a formatted summary for the current nlp object's pipeline. Shows
a table with the pipeline components and what they assign and require, as
well as any problems if available.
nlp (Language): The nlp object.
keys (List[str]): The meta keys to show in the table.
pretty (bool): Pretty-print the results (color etc).
no_print (bool): Don't print anything, just return the data.
RETURNS (dict): A dict with "overview" and "problems".
"""
msg = Printer(pretty=pretty, no_print=no_print)
overview = []
overview = {}
problems = {}
for i, name in enumerate(nlp.pipe_names):
meta = nlp.get_pipe_meta(name)
overview.append((i, name, meta.requires, meta.assigns, meta.retokenizes))
overview[name] = {"i": i, "name": name}
for key in keys:
overview[name][key] = getattr(meta, key, None)
problems[name] = analyze_pipes(nlp, name, i, warn=False)
msg.divider("Pipeline Overview")
header = ("#", "Component", "Requires", "Assigns", "Retokenizes")
msg.table(overview, header=header, divider=True, multiline=True)
header = ["#", "Component", *[key.capitalize() for key in keys]]
body = [[info for info in entry.values()] for entry in overview.values()]
msg.table(body, header=header, divider=True, multiline=True)
n_problems = sum(len(p) for p in problems.values())
if any(p for p in problems.values()):
msg.divider(f"Problems ({n_problems})")

View File

@ -1,15 +1,12 @@
import spacy.language
from spacy.language import Language
from spacy.pipe_analysis import print_summary, validate_attrs
from spacy.pipe_analysis import get_assigns_for_attr, get_requires_for_attr
from spacy.pipe_analysis import count_pipeline_interdependencies
from spacy.pipe_analysis import validate_attrs, count_pipeline_interdependencies
from mock import Mock
import pytest
def test_component_decorator_assigns():
spacy.language.ENABLE_PIPELINE_ANALYSIS = True
@Language.component("c1", assigns=["token.tag", "doc.tensor"])
def test_component1(doc):
return doc
@ -32,8 +29,9 @@ def test_component_decorator_assigns():
nlp = Language()
nlp.add_pipe("c1")
with pytest.warns(UserWarning):
nlp.add_pipe("c2")
nlp.add_pipe("c2")
problems = nlp.analyze_pipes(no_print=True)["problems"]
assert problems["c2"] == ["token.pos"]
nlp.add_pipe("c3")
assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2"]
nlp.add_pipe("c1", name="c4")
@ -45,7 +43,6 @@ def test_component_decorator_assigns():
assert nlp.pipe_factories["c4"] == "c1"
assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2", "c4"]
assert get_requires_for_attr(nlp, "token.pos") == ["c2"]
assert print_summary(nlp, no_print=True)
assert nlp("hello world")
@ -112,11 +109,12 @@ def test_analysis_validate_attrs_remove_pipe():
nlp = Language()
nlp.add_pipe("pipe_analysis_c6")
with pytest.warns(UserWarning):
nlp.add_pipe("pipe_analysis_c7")
with pytest.warns(None) as record:
nlp.remove_pipe("pipe_analysis_c7")
assert not record.list
nlp.add_pipe("pipe_analysis_c7")
problems = nlp.analyze_pipes(no_print=True)["problems"]
assert problems["pipe_analysis_c7"] == ["token.pos"]
nlp.remove_pipe("pipe_analysis_c7")
problems = nlp.analyze_pipes(no_print=True)["problems"]
assert all(p == [] for p in problems.values())
def test_pipe_interdependencies():