def _count_pipeline_inter_dependencies(pipeline):
    """Count how many subsequent components require an annotation set by each
    component in the pipeline.

    pipeline (iterable): (name, component) pairs in pipeline order. Each
        component may expose `assigns` and `requires` attributes holding the
        annotation names it sets and needs. A missing attribute or None is
        treated as empty, and a bare string is treated as one annotation name.
    RETURNS (list): One int per component, in pipeline order: the number of
        later components whose `requires` shares at least one entry with this
        component's `assigns`.
    """

    def _annotation_names(pipe, attr):
        # Normalise the attribute to a set of annotation names.
        values = getattr(pipe, attr, None)
        if values is None:
            return frozenset()
        if isinstance(values, str):
            # set("doc.ents") would explode the name into single characters,
            # which can never match (or can falsely match) a real requirement.
            return frozenset([values])
        return frozenset(values)

    # Materialise once so that one-shot iterables are only consumed a single
    # time; the component names are not needed for the count.
    pipes = [pipe for _, pipe in pipeline]
    pipe_assigns = [_annotation_names(pipe, "assigns") for pipe in pipes]
    pipe_requires = [_annotation_names(pipe, "requires") for pipe in pipes]
    # Only components *after* position i can depend on what component i sets.
    return [
        sum(
            1
            for requires in pipe_requires[i + 1:]
            if not assigns.isdisjoint(requires)
        )
        for i, assigns in enumerate(pipe_assigns)
    ]
def test_pipe_inter_dependencies():
    """A component that assigns an annotation required by one later component
    gets a count of 1; the last component gets a count of 0."""

    class Fancifier:
        # Sets the custom extension attribute and requires nothing.
        name = "fancifier"
        assigns = ("doc._.fancy",)
        requires = ()

    class FancyNeeder:
        # Sets nothing and requires the attribute assigned by Fancifier.
        name = "needer"
        assigns = ()
        requires = ("doc._.fancy",)

    producer = Fancifier()
    consumer = FancyNeeder()
    pipeline = [("fancifier", producer), ("needer", consumer)]
    assert _count_pipeline_inter_dependencies(pipeline) == [1, 0]