From 911dfcccfcdfe1e9effc38b013f0266dbbd79afb Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 15:57:04 +1100 Subject: [PATCH] Add option to replace listeners for sourced components --- spacy/language.py | 43 +++++++++++++++++++++++++++- spacy/tests/pipeline/test_tok2vec.py | 32 ++++++++++++++++++--- spacy/tests/test_misc.py | 19 ++++++++++++ spacy/training/initialize.py | 28 ++++++++++++------ spacy/util.py | 39 ++++++++++++++++++++++++- 5 files changed, 146 insertions(+), 15 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 6e617e31c..7749ba360 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -8,7 +8,7 @@ from contextlib import contextmanager from copy import deepcopy from pathlib import Path import warnings -from thinc.api import Model, get_current_ops, Config, Optimizer +from thinc.api import get_current_ops, Config, Optimizer import srsly import multiprocessing as mp from itertools import chain, cycle @@ -670,6 +670,47 @@ class Language: self._pipe_configs[name] = filled return resolved[factory_name] + def replace_listeners( + self, + tok2vec_name: str, + pipe_name: str, + listeners: Iterable[str] = SimpleFrozenList(), + ): + if tok2vec_name not in self.pipe_names: + raise ValueError # TODO: + if pipe_name not in self.pipe_names: + raise ValueError # TODO: + tok2vec = self.get_pipe(tok2vec_name) + tok2vec_cfg = self.get_pipe_config(tok2vec_name) + if ( + not hasattr(tok2vec, "model") + or not hasattr(tok2vec, "listener_map") + or "model" not in tok2vec_cfg + ): + raise ValueError # TODO: likely bug in spaCy if this happens + pipe_listeners = tok2vec.listener_map.get(pipe_name, []) + pipe_cfg = self._pipe_configs[pipe_name] + if listeners: + util.logger.debug(f"Replacing listeners of component '{pipe_name}'") + if len(listeners) != len(pipe_listeners): + # The number of listeners defined in the component model doesn't + # match the listeners to replace, so we won't be able to update + # the nodes and generate a matching config + raise ValueError(f"{listeners}, {pipe_listeners}") # TODO: + pipe = self.get_pipe(pipe_name) + # Go over the listener layers and replace them + for listener in pipe_listeners: + util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) + # Update the config accordingly by coping the tok2vec model to all + # sections defined in the listener paths + for listener_path in listeners: + # Check if the path actually exists in the config + try: + util.dot_to_object(pipe_cfg, listener_path) + except KeyError: + raise ValueError # TODO: + util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) + def create_pipe_from_source( self, source_name: str, source: "Language", *, name: str ) -> Tuple[Callable[[Doc], Doc], str]: diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py index 90052a9c8..56037e4b8 100644 --- a/spacy/tests/pipeline/test_tok2vec.py +++ b/spacy/tests/pipeline/test_tok2vec.py @@ -1,5 +1,4 @@ import pytest - from spacy.ml.models.tok2vec import build_Tok2Vec_model from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder @@ -9,12 +8,11 @@ from spacy.tokens import Doc from spacy.training import Example from spacy import util from spacy.lang.en import English -from ..util import get_batch - from thinc.api import Config - from numpy.testing import assert_equal +from ..util import get_batch + def test_empty_doc(): width = 128 @@ -187,3 +185,29 @@ def test_tok2vec_listener_callback(): Y, get_dX = tagger.model.begin_update(docs) # assure that the backprop call works (and doesn't hit a 'None' callback) assert get_dX(Y) is not None + + +def test_replace_listeners(): + orig_config = Config().from_str(cfg_string) + nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True) + examples = [Example.from_dict(nlp.make_doc("x y"), {"tags": ["V", "Z"]})] + nlp.initialize(lambda: examples) + tok2vec = nlp.get_pipe("tok2vec") + tagger = nlp.get_pipe("tagger") + assert isinstance(tagger.model.layers[0], Tok2VecListener) + assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0] + assert nlp.config["components"]["tok2vec"]["model"]["@architectures"] == "spacy.Tok2Vec.v2" + assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"] == "spacy.Tok2VecListener.v1" + nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"]) + assert not isinstance(tagger.model.layers[0], Tok2VecListener) + t2v_cfg = nlp.config["components"]["tok2vec"]["model"] + assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2" + assert nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg + with pytest.raises(ValueError): + nlp.replace_listeners("invalid", "tagger", ["model.tok2vec"]) + with pytest.raises(ValueError): + nlp.replace_listeners("tok2vec", "parser", ["model.tok2vec"]) + with pytest.raises(ValueError): + nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"]) + with pytest.raises(ValueError): + nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"]) diff --git a/spacy/tests/test_misc.py b/spacy/tests/test_misc.py index bdb2b9752..e694baa40 100644 --- a/spacy/tests/test_misc.py +++ b/spacy/tests/test_misc.py @@ -205,6 +205,25 @@ def test_dot_to_dict(dot_notation, expected): assert util.dict_to_dot(result) == dot_notation +def test_set_dot_to_object(): + config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}} + with pytest.raises(KeyError): + util.set_dot_to_object(config, "foo.bar.baz", 100) + with pytest.raises(KeyError): + util.set_dot_to_object(config, "hello.world", 100) + with pytest.raises(KeyError): + util.set_dot_to_object(config, "test.a.b.c", 100) + util.set_dot_to_object(config, "foo.bar", 100) + assert config["foo"]["bar"] == 100 + util.set_dot_to_object(config, "foo.baz.x", {"hello": "world"}) + assert config["foo"]["baz"]["x"]["hello"] == "world" + assert config["test"]["a"]["b"] == "c" + util.set_dot_to_object(config, "foo", 123) + assert config["foo"] == 123 + util.set_dot_to_object(config, "test", "hello") + assert dict(config) == {"foo": 123, "test": "hello"} + + @pytest.mark.parametrize( "doc_sizes, expected_batches", [ diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index 42bab6b4f..4cf8fa354 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -1,4 +1,4 @@ -from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING +from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING from thinc.api import Config, fix_random_seed, set_gpu_allocator from thinc.api import ConfigValidationError from pathlib import Path @@ -33,7 +33,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": if use_gpu >= 0 and allocator: set_gpu_allocator(allocator) # Use original config here before it's resolved to functions - sourced_components = get_sourced_components(config) + sourced = get_sourced_components(config) nlp = load_model_from_config(raw_config, auto_fill=True) logger.info("Set up nlp object from config") config = nlp.config.interpolate() @@ -57,7 +57,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": # Components that shouldn't be updated during training frozen_components = T["frozen_components"] # Sourced components that require resume_training - resume_components = [p for p in sourced_components if p not in frozen_components] + resume_components = [p for p in sourced if p not in frozen_components] logger.info(f"Pipeline: {nlp.pipe_names}") if resume_components: with nlp.select_pipes(enable=resume_components): @@ -68,10 +68,18 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": logger.info(f"Initialized pipeline components: {nlp.pipe_names}") # Detect components with listeners that are not frozen consistently for name, proc in nlp.pipeline: - if getattr(proc, "listening_components", None): + if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer for listener in proc.listening_components: - if listener in frozen_components and name not in frozen_components: + # If it's a component sourced from another pipeline, we check if + # the tok2vec listeners should be replaced with standalone tok2vec + # models (e.g. so component can be frozen without its performance + # degrading when other components/tok2vec are updated) + paths = sourced.get(listener, {}).get("replace_listeners", []) + if paths: + nlp.replace_listeners(name, listener, paths) + elif listener in frozen_components and name not in frozen_components: logger.warning(Warnings.W087.format(name=name, listener=listener)) + # We always check this regardless, in case user freezes tok2vec if listener not in frozen_components and name in frozen_components: logger.warning(Warnings.W086.format(name=name, listener=listener)) return nlp @@ -173,16 +181,18 @@ def init_tok2vec( return False -def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]: +def get_sourced_components( + config: Union[Dict[str, Any], Config] +) -> Dict[str, Dict[str, Any]]: """RETURNS (List[str]): All sourced components in the original config, e.g. {"source": "en_core_web_sm"}. If the config contains a key "factory", we assume it refers to a component factory. """ - return [ - name + return { + name: cfg for name, cfg in config.get("components", {}).items() if "factory" not in cfg and "source" in cfg - ] + } def convert_vectors( diff --git a/spacy/util.py b/spacy/util.py index 77aa712d1..dbd862687 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -8,7 +8,7 @@ import re from pathlib import Path import thinc from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer -from thinc.api import ConfigValidationError +from thinc.api import ConfigValidationError, Model import functools import itertools import numpy.random @@ -738,6 +738,24 @@ def get_package_path(name: str) -> Path: return Path(pkg.__file__).parent +def replace_model_node(model: Model, target: Model, replacement: Model) -> None: + """Replace a node within a model with a new one, updating refs. + + model (Model): The parent model. + target (Model): The target node. + replacement (Model): The node to replace the target with. + """ + # Place the node into the sublayers + for node in model.walk(): + if target in node.layers: + node.layers[node.layers.index(target)] = replacement + # Now fix any node references + for node in model.walk(): + for ref_name in node.ref_names: + if node.maybe_get_ref(ref_name) is target: + node.set_ref(ref_name, replacement) + + def split_command(command: str) -> List[str]: """Split a string command using shlex. Handles platform compatibility. @@ -1279,6 +1297,25 @@ def dot_to_object(config: Config, section: str): return component +def set_dot_to_object(config: Config, section: str, value: Any) -> None: + """Update a config at a given position from a dot notation. + + config (Config): The config. + section (str): The dot notation of the section in the config. + value (Any): The value to set in the config. + """ + component = config + parts = section.split(".") + for i, item in enumerate(parts): + try: + if i == len(parts) - 1: + component[item] = value + else: + component = component[item] + except (KeyError, TypeError): + raise KeyError(Errors.E952.format(name=section)) from None + + def walk_dict( node: Dict[str, Any], parent: List[str] = [] ) -> Iterator[Tuple[List[str], Any]]: