mirror of https://github.com/explosion/spaCy.git
190 lines
7.9 KiB
Python
190 lines
7.9 KiB
Python
from typing import List, Dict, Iterable, Optional, Union, TYPE_CHECKING
|
|
from wasabi import Printer
|
|
import warnings
|
|
|
|
from .tokens import Doc, Token, Span
|
|
from .errors import Errors, Warnings
|
|
from .util import dot_to_dict
|
|
|
|
if TYPE_CHECKING:
|
|
# This lets us add type hints for mypy etc. without causing circular imports
|
|
from .language import Language # noqa: F401
|
|
|
|
|
|
def analyze_pipes(
    nlp: "Language", name: str, index: int, warn: bool = True
) -> List[str]:
    """Check one pipeline component against the components that run before
    it and report every annotation it requires that none of them assigns.

    nlp (Language): The current nlp object.
    name (str): The name of the pipeline component to analyze.
    index (int): The position of that component in nlp.pipeline.
    warn (bool): Emit a user warning (W025) for each unmet requirement.
    RETURNS (List[str]): The required annotations that no earlier component
        assigns, in the order the component declares them.
    """
    # Sanity check: the caller's index must really point at this component.
    assert nlp.pipeline[index][0] == name
    earlier = nlp.pipeline[:index]
    meta = nlp.get_pipe_meta(name)
    # De-duplicate the declared requirements while keeping their order.
    required = list(dict.fromkeys(meta.requires))
    problems = []
    if not required:
        # Nothing is required, so there is nothing to look up or report.
        return problems
    # Everything any earlier component declares that it assigns.
    provided = set()
    for prev_name, _ in earlier:
        provided.update(nlp.get_pipe_meta(prev_name).assigns)
    for annot in required:
        if annot in provided:
            continue
        problems.append(annot)
        if warn:
            warnings.warn(Warnings.W025.format(name=name, attr=annot))
    return problems
|
|
|
|
|
|
def analyze_all_pipes(nlp: "Language", warn: bool = True) -> Dict[str, List[str]]:
    """Run analyze_pipes on every component of the pipeline, in pipeline order.

    nlp (Language): The current nlp object.
    warn (bool): Passed through to analyze_pipes; show a user warning for
        each problem found.
    RETURNS (Dict[str, List[str]]): For each component name, the list of
        problems analyze_pipes reported for it.
    """
    return {
        name: analyze_pipes(nlp, name, position, warn=warn)
        for position, name in enumerate(nlp.pipe_names)
    }
|
|
|
|
|
|
def validate_attrs(values: Iterable[str]) -> Iterable[str]:
    """Validate component attributes provided to "assigns", "requires" etc.
    Raises error for invalid attributes and formatting. Doesn't check if
    custom extension attributes are registered, since this is something the
    user might want to do themselves later in the component.

    values (Iterable[str]): The string attributes to check, e.g. `["token.pos"]`.
        May be a one-shot iterator; it is consumed exactly once.
    RETURNS (Iterable[str]): The checked attributes. This is the object that
        was passed in, unless that object was a one-shot iterator, in which
        case a list holding the same attributes is returned instead.
    RAISES (ValueError): If an attribute does not start with doc/token/span,
        names only the object (e.g. "doc" or "doc._"), is nested too deeply,
        ends with an underscore, is a non-extension Span attribute, or does
        not exist on the corresponding class.
    """
    # Snapshot the input once. It is walked several times below, and a
    # one-shot iterator would be empty on every pass after the first, which
    # silently disabled the span check and emptied the E181 message.
    attr_names = list(values)
    # The code below relies on dot_to_dict nesting dotted names into dicts
    # keyed by path segment, with True at the leaves.
    data = dot_to_dict({value: True for value in attr_names})
    objs = {"doc": Doc, "token": Token, "span": Span}
    for obj_key, attrs in data.items():
        if obj_key == "span":
            # Support Span only for custom extension attributes
            span_attrs = [attr for attr in attr_names if attr.startswith("span.")]
            span_attrs = [attr for attr in span_attrs if not attr.startswith("span._.")]
            if span_attrs:
                raise ValueError(Errors.E180.format(attrs=", ".join(span_attrs)))
        if obj_key not in objs:  # first element is not doc/token/span
            # Compare the whole first segment, not a raw string prefix, so an
            # invalid "tok.pos" doesn't drag the valid "token.pos" into the
            # error message.
            invalid_attrs = ", ".join(
                a for a in attr_names if a.split(".", 1)[0] == obj_key
            )
            raise ValueError(Errors.E181.format(obj=obj_key, attrs=invalid_attrs))
        if not isinstance(attrs, dict):  # attr is something like "doc"
            raise ValueError(Errors.E182.format(attr=obj_key))
        for attr, value in attrs.items():
            if attr == "_":
                if value is True:  # attr is something like "doc._"
                    raise ValueError(Errors.E182.format(attr="{}._".format(obj_key)))
                for ext_attr, ext_value in value.items():
                    # We don't check whether the attribute actually exists
                    if ext_value is not True:  # attr is something like doc._.x.y
                        good = f"{obj_key}._.{ext_attr}"
                        bad = f"{good}.{'.'.join(ext_value)}"
                        raise ValueError(Errors.E183.format(attr=bad, solution=good))
                continue  # we can't validate those further
            if attr.endswith("_"):  # attr is something like "token.pos_"
                raise ValueError(Errors.E184.format(attr=attr, solution=attr[:-1]))
            if value is not True:  # attr is something like doc.x.y
                good = f"{obj_key}.{attr}"
                bad = f"{good}.{'.'.join(value)}"
                raise ValueError(Errors.E183.format(attr=bad, solution=good))
            obj = objs[obj_key]
            if not hasattr(obj, attr):
                raise ValueError(Errors.E185.format(obj=obj_key, attr=attr))
    # An iterator returns itself from iter(); such an input was drained by the
    # snapshot above, so hand back the snapshot. Re-iterable inputs are
    # returned untouched, exactly as before.
    return attr_names if iter(values) is values else values
|
|
|
|
|
|
def _get_feature_for_attr(nlp: "Language", attr: str, feature: str) -> List[str]:
    """Collect the names of all pipeline components whose meta lists an
    attribute under the given feature.

    nlp (Language): The current nlp object.
    attr (str): The attribute to look for, e.g. "doc.tensor".
    feature (str): The meta field to search, "assigns" or "requires".
    RETURNS (List[str]): Names of the matching components, in pipeline order.
    """
    assert feature in ["assigns", "requires"]
    # A meta without the requested field is treated as declaring nothing.
    return [
        pipe_name
        for pipe_name in nlp.pipe_names
        if attr in getattr(nlp.get_pipe_meta(pipe_name), feature, [])
    ]
|
|
|
|
|
|
def get_assigns_for_attr(nlp: "Language", attr: str) -> List[str]:
    """Get all pipeline components that assign an attr, e.g. "doc.tensor".

    nlp (Language): The current nlp object.
    attr (str): The attribute to check.
    RETURNS (List[str]): Names of components that assign the attr.
    """
    return _get_feature_for_attr(nlp, attr, "assigns")
|
|
|
|
|
|
def get_requires_for_attr(nlp: "Language", attr: str) -> List[str]:
    """Look up which pipeline components declare an attribute, such as
    "doc.tensor", among their requirements.

    nlp (Language): The current nlp object.
    attr (str): The attribute to check.
    RETURNS (List[str]): Names of components that require the attr.
    """
    return _get_feature_for_attr(nlp=nlp, attr=attr, feature="requires")
|
|
|
|
|
|
def print_summary(
    nlp: "Language", pretty: bool = True, no_print: bool = False
) -> Optional[Dict[str, Union[List[str], Dict[str, List[str]]]]]:
    """Print a formatted summary for the current nlp object's pipeline. Shows
    a table with the pipeline components and what they require and assign,
    followed by any unmet requirements.

    nlp (Language): The nlp object.
    pretty (bool): Pretty-print the results (color etc).
    no_print (bool): Don't print anything, just return the data.
    RETURNS (Optional[dict]): A dict with "overview" and "problems" if
        no_print is set, otherwise None.
    """
    msg = Printer(pretty=pretty, no_print=no_print)
    overview = []
    problems = {}
    for position, name in enumerate(nlp.pipe_names):
        meta = nlp.get_pipe_meta(name)
        row = (position, name, meta.requires, meta.assigns, meta.retokenizes)
        overview.append(row)
        # Warnings are suppressed here; problems are reported in the summary.
        problems[name] = analyze_pipes(nlp, name, position, warn=False)

    msg.divider("Pipeline Overview")
    msg.table(
        overview,
        header=("#", "Component", "Requires", "Assigns", "Retokenizes"),
        divider=True,
        multiline=True,
    )

    n_problems = sum(len(found) for found in problems.values())
    if n_problems:
        msg.divider(f"Problems ({n_problems})")
        for name, found in problems.items():
            if found:
                msg.warn(f"'{name}' requirements not met: {', '.join(found)}")
    else:
        msg.good("No problems found.")

    if not no_print:
        return None
    return {"overview": overview, "problems": problems}
|
|
|
|
|
|
def count_pipeline_interdependencies(nlp: "Language") -> List[int]:
    """For each component in the pipeline, count how many later components
    require at least one of the annotations it assigns.

    nlp (Language): The current nlp object.
    RETURNS (List[int]): One count per component, in pipeline order.
    """
    # One (assigned, required) pair of sets per component, in pipeline order.
    pairs = [
        (set(meta.assigns), set(meta.requires))
        for meta in map(nlp.get_pipe_meta, nlp.pipe_names)
    ]
    required_sets = [required for _, required in pairs]
    return [
        sum(
            1
            for later_required in required_sets[position + 1 :]
            if not assigned.isdisjoint(later_required)
        )
        for position, (assigned, _) in enumerate(pairs)
    ]
|