From 5a38f79f1863d738dd9fb1a9febca94da489b297 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Thu, 21 Oct 2021 15:31:06 +0200 Subject: [PATCH] Custom component types in spacy.ty (#9469) * add custom protocols in spacy.ty * add a test for the new types in spacy.ty * import Example when type checking * some type fixes * put Protocol in compat * revert update check back to hasattr * runtime_checkable in compat as well --- spacy/compat.py | 4 +-- spacy/language.py | 64 ++++++++++++++++++--------------------- spacy/pipeline/spancat.py | 2 +- spacy/tests/test_ty.py | 18 +++++++++++ spacy/ty.py | 55 +++++++++++++++++++++++++++++++++ 5 files changed, 106 insertions(+), 37 deletions(-) create mode 100644 spacy/tests/test_ty.py create mode 100644 spacy/ty.py diff --git a/spacy/compat.py b/spacy/compat.py index a0066741a..89132735d 100644 --- a/spacy/compat.py +++ b/spacy/compat.py @@ -23,9 +23,9 @@ except ImportError: cupy = None if sys.version_info[:2] >= (3, 8): # Python 3.8+ - from typing import Literal + from typing import Literal, Protocol, runtime_checkable else: - from typing_extensions import Literal # noqa: F401 + from typing_extensions import Literal, Protocol, runtime_checkable # noqa: F401 # Important note: The importlib_metadata "backport" includes functionality # that's not part of the built-in importlib.metadata. We should treat this diff --git a/spacy/language.py b/spacy/language.py index 37fdf9e0d..2dfb43b73 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -17,6 +17,7 @@ from itertools import chain, cycle from timeit import default_timer as timer import traceback +from . import ty from .tokens.underscore import Underscore from .vocab import Vocab, create_vocab from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis @@ -1135,11 +1136,11 @@ class Language: if sgd not in (None, False): if ( name not in exclude - and hasattr(proc, "is_trainable") + and isinstance(proc, ty.TrainableComponent) and proc.is_trainable - and proc.model not in (True, False, None) # type: ignore + and proc.model not in (True, False, None) ): - proc.finish_update(sgd) # type: ignore + proc.finish_update(sgd) if name in annotates: for doc, eg in zip( _pipe( @@ -1278,12 +1279,12 @@ class Language: ) self.tokenizer.initialize(get_examples, nlp=self, **tok_settings) # type: ignore[union-attr] for name, proc in self.pipeline: - if hasattr(proc, "initialize"): + if isinstance(proc, ty.InitializableComponent): p_settings = I["components"].get(name, {}) p_settings = validate_init_settings( proc.initialize, p_settings, section="components", name=name ) - proc.initialize(get_examples, nlp=self, **p_settings) # type: ignore[call-arg] + proc.initialize(get_examples, nlp=self, **p_settings) pretrain_cfg = config.get("pretraining") if pretrain_cfg: P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain) @@ -1622,9 +1623,9 @@ class Language: # components don't receive the pipeline then. So this does have to be # here :( for i, (name1, proc1) in enumerate(self.pipeline): - if hasattr(proc1, "find_listeners"): + if isinstance(proc1, ty.ListenedToComponent): for name2, proc2 in self.pipeline[i + 1 :]: - proc1.find_listeners(proc2) # type: ignore[attr-defined] + proc1.find_listeners(proc2) @classmethod def from_config( @@ -1810,25 +1811,25 @@ class Language: ) # Detect components with listeners that are not frozen consistently for name, proc in nlp.pipeline: - # Remove listeners not in the pipeline - listener_names = getattr(proc, "listening_components", []) - unused_listener_names = [ - ll for ll in listener_names if ll not in nlp.pipe_names - ] - for listener_name in unused_listener_names: - for listener in proc.listener_map.get(listener_name, []): # type: ignore[attr-defined] - proc.remove_listener(listener, listener_name) # type: ignore[attr-defined] + if isinstance(proc, ty.ListenedToComponent): + # Remove listeners not in the pipeline + listener_names = proc.listening_components + unused_listener_names = [ + ll for ll in listener_names if ll not in nlp.pipe_names + ] + for listener_name in unused_listener_names: + for listener in proc.listener_map.get(listener_name, []): + proc.remove_listener(listener, listener_name) - for listener in getattr( - proc, "listening_components", [] - ): # e.g. tok2vec/transformer - # If it's a component sourced from another pipeline, we check if - # the tok2vec listeners should be replaced with standalone tok2vec - # models (e.g. so component can be frozen without its performance - # degrading when other components/tok2vec are updated) - paths = sourced.get(listener, {}).get("replace_listeners", []) - if paths: - nlp.replace_listeners(name, listener, paths) + for listener_name in proc.listening_components: + # e.g. tok2vec/transformer + # If it's a component sourced from another pipeline, we check if + # the tok2vec listeners should be replaced with standalone tok2vec + # models (e.g. so component can be frozen without its performance + # degrading when other components/tok2vec are updated) + paths = sourced.get(listener_name, {}).get("replace_listeners", []) + if paths: + nlp.replace_listeners(name, listener_name, paths) return nlp def replace_listeners( @@ -1878,15 +1879,10 @@ class Language: raise ValueError(err) tok2vec = self.get_pipe(tok2vec_name) tok2vec_cfg = self.get_pipe_config(tok2vec_name) - if ( - not hasattr(tok2vec, "model") - or not hasattr(tok2vec, "listener_map") - or not hasattr(tok2vec, "remove_listener") - or "model" not in tok2vec_cfg - ): + if not isinstance(tok2vec, ty.ListenedToComponent): raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec))) - tok2vec_model = tok2vec.model # type: ignore[attr-defined] - pipe_listeners = tok2vec.listener_map.get(pipe_name, []) # type: ignore[attr-defined] + tok2vec_model = tok2vec.model + pipe_listeners = tok2vec.listener_map.get(pipe_name, []) pipe = self.get_pipe(pipe_name) pipe_cfg = self._pipe_configs[pipe_name] if listeners: @@ -1926,7 +1922,7 @@ class Language: if "replace_listener" in tok2vec_model.attrs: new_model = tok2vec_model.attrs["replace_listener"](new_model) util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined] - tok2vec.remove_listener(listener, pipe_name) # type: ignore[attr-defined] + tok2vec.remove_listener(listener, pipe_name) def to_disk( self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 4e9a82423..84a9b69cc 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -1,10 +1,10 @@ import numpy from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast -from typing_extensions import Protocol, runtime_checkable from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Optimizer from thinc.types import Ragged, Ints2d, Floats2d, Ints1d +from ..compat import Protocol, runtime_checkable from ..scorer import Scorer from ..language import Language from .trainable_pipe import TrainablePipe diff --git a/spacy/tests/test_ty.py b/spacy/tests/test_ty.py new file mode 100644 index 000000000..2037520df --- /dev/null +++ b/spacy/tests/test_ty.py @@ -0,0 +1,18 @@ +import spacy +from spacy import ty + + +def test_component_types(): + nlp = spacy.blank("en") + tok2vec = nlp.create_pipe("tok2vec") + tagger = nlp.create_pipe("tagger") + entity_ruler = nlp.create_pipe("entity_ruler") + assert isinstance(tok2vec, ty.TrainableComponent) + assert isinstance(tagger, ty.TrainableComponent) + assert not isinstance(entity_ruler, ty.TrainableComponent) + assert isinstance(tok2vec, ty.InitializableComponent) + assert isinstance(tagger, ty.InitializableComponent) + assert isinstance(entity_ruler, ty.InitializableComponent) + assert isinstance(tok2vec, ty.ListenedToComponent) + assert not isinstance(tagger, ty.ListenedToComponent) + assert not isinstance(entity_ruler, ty.ListenedToComponent) diff --git a/spacy/ty.py b/spacy/ty.py new file mode 100644 index 000000000..8f2903d78 --- /dev/null +++ b/spacy/ty.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING +from typing import Optional, Any, Iterable, Dict, Callable, Sequence, List +from .compat import Protocol, runtime_checkable + +from thinc.api import Optimizer, Model + +if TYPE_CHECKING: + from .training import Example + + +@runtime_checkable +class TrainableComponent(Protocol): + model: Any + is_trainable: bool + + def update( + self, + examples: Iterable["Example"], + *, + drop: float = 0.0, + sgd: Optional[Optimizer] = None, + losses: Optional[Dict[str, float]] = None + ) -> Dict[str, float]: + ... + + def finish_update(self, sgd: Optimizer) -> None: + ... + + +@runtime_checkable +class InitializableComponent(Protocol): + def initialize( + self, + get_examples: Callable[[], Iterable["Example"]], + nlp: Iterable["Example"], + **kwargs: Any + ): + ... + + +@runtime_checkable +class ListenedToComponent(Protocol): + model: Any + listeners: Sequence[Model] + listener_map: Dict[str, Sequence[Model]] + listening_components: List[str] + + def add_listener(self, listener: Model, component_name: str) -> None: + ... + + def remove_listener(self, listener: Model, component_name: str) -> bool: + ... + + def find_listeners(self, component) -> None: + ...