From 30a76fcf6f662a3ef2d63648beba9f7a82e02150 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 31 Jul 2020 18:34:35 +0200 Subject: [PATCH] Integrate and simplify pipe analysis --- spacy/language.py | 28 ++++++++++++++++-------- spacy/pipe_analysis.py | 31 +++++++++++---------------- spacy/tests/pipeline/test_analysis.py | 22 +++++++++---------- 3 files changed, 42 insertions(+), 39 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 594a4b148..6230913b4 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -18,7 +18,7 @@ from timeit import default_timer as timer from .tokens.underscore import Underscore from .vocab import Vocab, create_vocab -from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs +from .pipe_analysis import validate_attrs, print_summary from .gold import Example from .scorer import Scorer from .util import create_default_optimizer, registry @@ -37,8 +37,6 @@ from . import util from . import about -# TODO: integrate pipeline analyis -ENABLE_PIPELINE_ANALYSIS = False # This is the base config will all settings (training etc.) DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg" DEFAULT_CONFIG = Config().from_disk(DEFAULT_CONFIG_PATH) @@ -522,6 +520,24 @@ class Language: return add_component(func) return add_component + def analyze_pipes( + self, + *, + keys: List[str] = ["assigns", "requires", "scores", "retokenizes"], + pretty: bool = True, + no_print: bool = False, + ) -> Optional[Dict[str, Any]]: + """Analyze the current pipeline components, print a summary of what + they assign or require and check that all requirements are met. + + keys (List[str]): The meta values to display in the table. Corresponds + to values in FactoryMeta, defined by @Language.factory decorator. + pretty (bool): Pretty-print the results with colors and icons. + no_print (bool): Don't print anything and return structured dict instead. + RETURNS (dict): The data, if no_print is set to True. + """ + return print_summary(self, keys=keys, pretty=pretty, no_print=no_print) + def get_pipe(self, name: str) -> Callable[[Doc], Doc]: """Get a pipeline component for a given component name. @@ -666,8 +682,6 @@ class Language: pipe_index = self._get_pipe_index(before, after, first, last) self._pipe_meta[name] = self.get_factory_meta(factory_name) self.pipeline.insert(pipe_index, (name, pipe_component)) - if ENABLE_PIPELINE_ANALYSIS: - analyze_pipes(self, name, pipe_index) return pipe_component def _get_pipe_index( @@ -758,8 +772,6 @@ class Language: self.add_pipe(factory_name, name=name) else: self.add_pipe(factory_name, name=name, before=pipe_index) - if ENABLE_PIPELINE_ANALYSIS: - analyze_all_pipes(self) def rename_pipe(self, old_name: str, new_name: str) -> None: """Rename a pipeline component. @@ -793,8 +805,6 @@ class Language: # because factory may be used for something else self._pipe_meta.pop(name) self._pipe_configs.pop(name) - if ENABLE_PIPELINE_ANALYSIS: - analyze_all_pipes(self) return removed def __call__( diff --git a/spacy/pipe_analysis.py b/spacy/pipe_analysis.py index b57f1524b..71f99daef 100644 --- a/spacy/pipe_analysis.py +++ b/spacy/pipe_analysis.py @@ -42,19 +42,6 @@ def analyze_pipes( return problems -def analyze_all_pipes(nlp: "Language", warn: bool = True) -> Dict[str, List[str]]: - """Analyze all pipes in the pipeline in order. - - nlp (Language): The current nlp object. - warn (bool): Show user warning if problem is found. - RETURNS (Dict[str, List[str]]): The problems found, keyed by component name. - """ - problems = {} - for i, name in enumerate(nlp.pipe_names): - problems[name] = analyze_pipes(nlp, name, i, warn=warn) - return problems - - def validate_attrs(values: Iterable[str]) -> Iterable[str]: """Validate component attributes provided to "assigns", "requires" etc. Raises error for invalid attributes and formatting. Doesn't check if @@ -133,27 +120,35 @@ def get_requires_for_attr(nlp: "Language", attr: str) -> List[str]: def print_summary( - nlp: "Language", pretty: bool = True, no_print: bool = False + nlp: "Language", + *, + keys: List[str] = ["requires", "assigns", "scores", "retokenizes"], + pretty: bool = True, + no_print: bool = False, ) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]: """Print a formatted summary for the current nlp object's pipeline. Shows a table with the pipeline components and why they assign and require, as well as any problems if available. nlp (Language): The nlp object. + keys (List[str]): The meta keys to show in the table. pretty (bool): Pretty-print the results (color etc). no_print (bool): Don't print anything, just return the data. RETURNS (dict): A dict with "overview" and "problems". """ msg = Printer(pretty=pretty, no_print=no_print) - overview = [] + overview = {} problems = {} for i, name in enumerate(nlp.pipe_names): meta = nlp.get_pipe_meta(name) - overview.append((i, name, meta.requires, meta.assigns, meta.retokenizes)) + overview[name] = {"i": i, "name": name} + for key in keys: + overview[name][key] = getattr(meta, key, None) problems[name] = analyze_pipes(nlp, name, i, warn=False) msg.divider("Pipeline Overview") - header = ("#", "Component", "Requires", "Assigns", "Retokenizes") - msg.table(overview, header=header, divider=True, multiline=True) + header = ["#", "Component", *[key.capitalize() for key in keys]] + body = [[info for info in entry.values()] for entry in overview.values()] + msg.table(body, header=header, divider=True, multiline=True) n_problems = sum(len(p) for p in problems.values()) if any(p for p in problems.values()): msg.divider(f"Problems ({n_problems})") diff --git a/spacy/tests/pipeline/test_analysis.py b/spacy/tests/pipeline/test_analysis.py index 4e1407707..7d22bb1a0 100644 --- a/spacy/tests/pipeline/test_analysis.py +++ b/spacy/tests/pipeline/test_analysis.py @@ -1,15 +1,12 @@ import spacy.language from spacy.language import Language -from spacy.pipe_analysis import print_summary, validate_attrs from spacy.pipe_analysis import get_assigns_for_attr, get_requires_for_attr -from spacy.pipe_analysis import count_pipeline_interdependencies +from spacy.pipe_analysis import validate_attrs, count_pipeline_interdependencies from mock import Mock import pytest def test_component_decorator_assigns(): - spacy.language.ENABLE_PIPELINE_ANALYSIS = True - @Language.component("c1", assigns=["token.tag", "doc.tensor"]) def test_component1(doc): return doc @@ -32,8 +29,9 @@ def test_component_decorator_assigns(): nlp = Language() nlp.add_pipe("c1") - with pytest.warns(UserWarning): - nlp.add_pipe("c2") + nlp.add_pipe("c2") + problems = nlp.analyze_pipes(no_print=True)["problems"] + assert problems["c2"] == ["token.pos"] nlp.add_pipe("c3") assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2"] nlp.add_pipe("c1", name="c4") @@ -45,7 +43,6 @@ def test_component_decorator_assigns(): assert nlp.pipe_factories["c4"] == "c1" assert get_assigns_for_attr(nlp, "doc.tensor") == ["c1", "c2", "c4"] assert get_requires_for_attr(nlp, "token.pos") == ["c2"] - assert print_summary(nlp, no_print=True) assert nlp("hello world") @@ -112,11 +109,12 @@ def test_analysis_validate_attrs_remove_pipe(): nlp = Language() nlp.add_pipe("pipe_analysis_c6") - with pytest.warns(UserWarning): - nlp.add_pipe("pipe_analysis_c7") - with pytest.warns(None) as record: - nlp.remove_pipe("pipe_analysis_c7") - assert not record.list + nlp.add_pipe("pipe_analysis_c7") + problems = nlp.analyze_pipes(no_print=True)["problems"] + assert problems["pipe_analysis_c7"] == ["token.pos"] + nlp.remove_pipe("pipe_analysis_c7") + problems = nlp.analyze_pipes(no_print=True)["problems"] + assert all(p == [] for p in problems.values()) def test_pipe_interdependencies():