Integrate and simplify pipe analysis

This commit is contained in:
Ines Montani 2020-07-31 18:34:35 +02:00
parent 2d955fbf98
commit 30a76fcf6f
3 changed files with 42 additions and 39 deletions

View File

@ -18,7 +18,7 @@ from timeit import default_timer as timer
from .tokens.underscore import Underscore from .tokens.underscore import Underscore
from .vocab import Vocab, create_vocab from .vocab import Vocab, create_vocab
from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs from .pipe_analysis import validate_attrs, print_summary
from .gold import Example from .gold import Example
from .scorer import Scorer from .scorer import Scorer
from .util import create_default_optimizer, registry from .util import create_default_optimizer, registry
@ -37,8 +37,6 @@ from . import util
from . import about from . import about
# TODO: integrate pipeline analysis
ENABLE_PIPELINE_ANALYSIS = False
# This is the base config with all settings (training etc.) # This is the base config with all settings (training etc.)
DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg" DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
DEFAULT_CONFIG = Config().from_disk(DEFAULT_CONFIG_PATH) DEFAULT_CONFIG = Config().from_disk(DEFAULT_CONFIG_PATH)
@ -522,6 +520,24 @@ class Language:
return add_component(func) return add_component(func)
return add_component return add_component
def analyze_pipes(
    self,
    *,
    keys: Optional[List[str]] = None,
    pretty: bool = True,
    no_print: bool = False,
) -> Optional[Dict[str, Any]]:
    """Analyze the current pipeline components, print a summary of what
    they assign or require and check that all requirements are met.

    keys (Optional[List[str]]): The meta values to display in the table.
        Corresponds to values in FactoryMeta, defined by the
        @Language.factory decorator. Defaults to
        ["assigns", "requires", "scores", "retokenizes"].
    pretty (bool): Pretty-print the results with colors and icons.
    no_print (bool): Don't print anything and return structured dict instead.
    RETURNS (dict): The data, if no_print is set to True.
    """
    # Avoid a mutable (list) default argument: resolve the default inside
    # the body instead, so callers can never observe a shared list object.
    if keys is None:
        keys = ["assigns", "requires", "scores", "retokenizes"]
    # Delegates to pipe_analysis.print_summary, which builds the overview
    # table and the per-component problems dict.
    return print_summary(self, keys=keys, pretty=pretty, no_print=no_print)
def get_pipe(self, name: str) -> Callable[[Doc], Doc]: def get_pipe(self, name: str) -> Callable[[Doc], Doc]:
"""Get a pipeline component for a given component name. """Get a pipeline component for a given component name.
@ -666,8 +682,6 @@ class Language:
pipe_index = self._get_pipe_index(before, after, first, last) pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name) self._pipe_meta[name] = self.get_factory_meta(factory_name)
self.pipeline.insert(pipe_index, (name, pipe_component)) self.pipeline.insert(pipe_index, (name, pipe_component))
if ENABLE_PIPELINE_ANALYSIS:
analyze_pipes(self, name, pipe_index)
return pipe_component return pipe_component
def _get_pipe_index( def _get_pipe_index(
@ -758,8 +772,6 @@ class Language:
self.add_pipe(factory_name, name=name) self.add_pipe(factory_name, name=name)
else: else:
self.add_pipe(factory_name, name=name, before=pipe_index) self.add_pipe(factory_name, name=name, before=pipe_index)
if ENABLE_PIPELINE_ANALYSIS:
analyze_all_pipes(self)
def rename_pipe(self, old_name: str, new_name: str) -> None: def rename_pipe(self, old_name: str, new_name: str) -> None:
"""Rename a pipeline component. """Rename a pipeline component.
@ -793,8 +805,6 @@ class Language:
# because factory may be used for something else # because factory may be used for something else
self._pipe_meta.pop(name) self._pipe_meta.pop(name)
self._pipe_configs.pop(name) self._pipe_configs.pop(name)
if ENABLE_PIPELINE_ANALYSIS:
analyze_all_pipes(self)
return removed return removed
def __call__( def __call__(

View File

@ -42,19 +42,6 @@ def analyze_pipes(
return problems return problems
def analyze_all_pipes(nlp: "Language", warn: bool = True) -> Dict[str, List[str]]:
    """Analyze every component in the pipeline, in pipeline order.

    nlp (Language): The current nlp object.
    warn (bool): Show user warning if problem is found.
    RETURNS (Dict[str, List[str]]): The problems found, keyed by component name.
    """
    # One analysis result per component, keyed by its pipeline name.
    return {
        component_name: analyze_pipes(nlp, component_name, index, warn=warn)
        for index, component_name in enumerate(nlp.pipe_names)
    }
def validate_attrs(values: Iterable[str]) -> Iterable[str]: def validate_attrs(values: Iterable[str]) -> Iterable[str]:
"""Validate component attributes provided to "assigns", "requires" etc. """Validate component attributes provided to "assigns", "requires" etc.
Raises error for invalid attributes and formatting. Doesn't check if Raises error for invalid attributes and formatting. Doesn't check if
@ -133,27 +120,35 @@ def get_requires_for_attr(nlp: "Language", attr: str) -> List[str]:
def print_summary( def print_summary(
nlp: "Language", pretty: bool = True, no_print: bool = False nlp: "Language",
*,
keys: List[str] = ["requires", "assigns", "scores", "retokenizes"],
pretty: bool = True,
no_print: bool = False,
) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]: ) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]:
"""Print a formatted summary for the current nlp object's pipeline. Shows """Print a formatted summary for the current nlp object's pipeline. Shows
a table with the pipeline components and what they assign and require, as a table with the pipeline components and what they assign and require, as
well as any problems if available. well as any problems if available.
nlp (Language): The nlp object. nlp (Language): The nlp object.
keys (List[str]): The meta keys to show in the table.
pretty (bool): Pretty-print the results (color etc). pretty (bool): Pretty-print the results (color etc).
no_print (bool): Don't print anything, just return the data. no_print (bool): Don't print anything, just return the data.
RETURNS (dict): A dict with "overview" and "problems". RETURNS (dict): A dict with "overview" and "problems".
""" """
msg = Printer(pretty=pretty, no_print=no_print) msg = Printer(pretty=pretty, no_print=no_print)
overview = [] overview = {}
problems = {} problems = {}
for i, name in enumerate(nlp.pipe_names): for i, name in enumerate(nlp.pipe_names):
meta = nlp.get_pipe_meta(name) meta = nlp.get_pipe_meta(name)
overview.append((i, name, meta.requires, meta.assigns, meta.retokenizes)) overview[name] = {"i": i, "name": name}
for key in keys:
overview[name][key] = getattr(meta, key, None)
problems[name] = analyze_pipes(nlp, name, i, warn=False) problems[name] = analyze_pipes(nlp, name, i, warn=False)
msg.divider("Pipeline Overview") msg.divider("Pipeline Overview")
header = ("#", "Component", "Requires", "Assigns", "Retokenizes") header = ["#", "Component", *[key.capitalize() for key in keys]]
msg.table(overview, header=header, divider=True, multiline=True) body = [[info for info in entry.values()] for entry in overview.values()]
msg.table(body, header=header, divider=True, multiline=True)
n_problems = sum(len(p) for p in problems.values()) n_problems = sum(len(p) for p in problems.values())
if any(p for p in problems.values()): if any(p for p in problems.values()):
msg.divider(f"Problems ({n_problems})") msg.divider(f"Problems ({n_problems})")

View File

@ -1,15 +1,12 @@
import spacy.language import spacy.language
from spacy.language import Language from spacy.language import Language
from spacy.pipe_analysis import print_summary, validate_attrs
from spacy.pipe_analysis import get_assigns_for_attr, get_requires_for_attr from spacy.pipe_analysis import get_assigns_for_attr, get_requires_for_attr
from spacy.pipe_analysis import count_pipeline_interdependencies from spacy.pipe_analysis import validate_attrs, count_pipeline_interdependencies
from mock import Mock from mock import Mock
import pytest import pytest
def test_component_decorator_assigns(): def test_component_decorator_assigns():
spacy.language.ENABLE_PIPELINE_ANALYSIS = True
@Language.component("c1", assigns=["token.tag", "doc.tensor"]) @Language.component("c1", assigns=["token.tag", "doc.tensor"])
def test_component1(doc): def test_component1(doc):
return doc return doc
@ -32,8 +29,9 @@ def test_component_decorator_assigns():
nlp = Language() nlp = Language()
nlp.add_pipe("c1") nlp.add_pipe("c1")
with pytest.warns(UserWarning): nlp.add_pipe("c2")
nlp.add_pipe("c2") problems = nlp.analyze_pipes(no_print=True)["problems"]
assert problems["c2"] == ["token.pos"]
nlp.add_pipe("c3") nlp.add_pipe("c3")
assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2"] assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2"]
nlp.add_pipe("c1", name="c4") nlp.add_pipe("c1", name="c4")
@ -45,7 +43,6 @@ def test_component_decorator_assigns():
assert nlp.pipe_factories["c4"] == "c1" assert nlp.pipe_factories["c4"] == "c1"
assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2", "c4"] assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2", "c4"]
assert get_requires_for_attr(nlp, "token.pos") == ["c2"] assert get_requires_for_attr(nlp, "token.pos") == ["c2"]
assert print_summary(nlp, no_print=True)
assert nlp("hello world") assert nlp("hello world")
@ -112,11 +109,12 @@ def test_analysis_validate_attrs_remove_pipe():
nlp = Language() nlp = Language()
nlp.add_pipe("pipe_analysis_c6") nlp.add_pipe("pipe_analysis_c6")
with pytest.warns(UserWarning): nlp.add_pipe("pipe_analysis_c7")
nlp.add_pipe("pipe_analysis_c7") problems = nlp.analyze_pipes(no_print=True)["problems"]
with pytest.warns(None) as record: assert problems["pipe_analysis_c7"] == ["token.pos"]
nlp.remove_pipe("pipe_analysis_c7") nlp.remove_pipe("pipe_analysis_c7")
assert not record.list problems = nlp.analyze_pipes(no_print=True)["problems"]
assert all(p == [] for p in problems.values())
def test_pipe_interdependencies(): def test_pipe_interdependencies():