Add option to replace listeners for sourced components

This commit is contained in:
Ines Montani 2021-01-29 15:57:04 +11:00
parent 78d6ff4dd4
commit 911dfcccfc
5 changed files with 146 additions and 15 deletions

View File

@ -8,7 +8,7 @@ from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
import warnings
from thinc.api import Model, get_current_ops, Config, Optimizer
from thinc.api import get_current_ops, Config, Optimizer
import srsly
import multiprocessing as mp
from itertools import chain, cycle
@ -670,6 +670,47 @@ class Language:
self._pipe_configs[name] = filled
return resolved[factory_name]
def replace_listeners(
self,
tok2vec_name: str,
pipe_name: str,
listeners: Iterable[str] = SimpleFrozenList(),
):
if tok2vec_name not in self.pipe_names:
raise ValueError # TODO:
if pipe_name not in self.pipe_names:
raise ValueError # TODO:
tok2vec = self.get_pipe(tok2vec_name)
tok2vec_cfg = self.get_pipe_config(tok2vec_name)
if (
not hasattr(tok2vec, "model")
or not hasattr(tok2vec, "listener_map")
or "model" not in tok2vec_cfg
):
raise ValueError # TODO: likely bug in spaCy if this happens
pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
pipe_cfg = self._pipe_configs[pipe_name]
if listeners:
util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
if len(listeners) != len(pipe_listeners):
# The number of listeners defined in the component model doesn't
# match the listeners to replace, so we won't be able to update
# the nodes and generate a matching config
raise ValueError(f"{listeners}, {pipe_listeners}") # TODO:
pipe = self.get_pipe(pipe_name)
# Go over the listener layers and replace them
for listener in pipe_listeners:
util.replace_model_node(pipe.model, listener, tok2vec.model.copy())
# Update the config accordingly by coping the tok2vec model to all
# sections defined in the listener paths
for listener_path in listeners:
# Check if the path actually exists in the config
try:
util.dot_to_object(pipe_cfg, listener_path)
except KeyError:
raise ValueError # TODO:
util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"])
def create_pipe_from_source(
self, source_name: str, source: "Language", *, name: str
) -> Tuple[Callable[[Doc], Doc], str]:

View File

@ -1,5 +1,4 @@
import pytest
from spacy.ml.models.tok2vec import build_Tok2Vec_model
from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
@ -9,12 +8,11 @@ from spacy.tokens import Doc
from spacy.training import Example
from spacy import util
from spacy.lang.en import English
from ..util import get_batch
from thinc.api import Config
from numpy.testing import assert_equal
from ..util import get_batch
def test_empty_doc():
width = 128
@ -187,3 +185,29 @@ def test_tok2vec_listener_callback():
Y, get_dX = tagger.model.begin_update(docs)
# assure that the backprop call works (and doesn't hit a 'None' callback)
assert get_dX(Y) is not None
def test_replace_listeners():
orig_config = Config().from_str(cfg_string)
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
examples = [Example.from_dict(nlp.make_doc("x y"), {"tags": ["V", "Z"]})]
nlp.initialize(lambda: examples)
tok2vec = nlp.get_pipe("tok2vec")
tagger = nlp.get_pipe("tagger")
assert isinstance(tagger.model.layers[0], Tok2VecListener)
assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0]
assert nlp.config["components"]["tok2vec"]["model"]["@architectures"] == "spacy.Tok2Vec.v2"
assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"] == "spacy.Tok2VecListener.v1"
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
assert not isinstance(tagger.model.layers[0], Tok2VecListener)
t2v_cfg = nlp.config["components"]["tok2vec"]["model"]
assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2"
assert nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg
with pytest.raises(ValueError):
nlp.replace_listeners("invalid", "tagger", ["model.tok2vec"])
with pytest.raises(ValueError):
nlp.replace_listeners("tok2vec", "parser", ["model.tok2vec"])
with pytest.raises(ValueError):
nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"])
with pytest.raises(ValueError):
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"])

View File

@ -205,6 +205,25 @@ def test_dot_to_dict(dot_notation, expected):
assert util.dict_to_dot(result) == dot_notation
def test_set_dot_to_object():
config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}}
with pytest.raises(KeyError):
util.set_dot_to_object(config, "foo.bar.baz", 100)
with pytest.raises(KeyError):
util.set_dot_to_object(config, "hello.world", 100)
with pytest.raises(KeyError):
util.set_dot_to_object(config, "test.a.b.c", 100)
util.set_dot_to_object(config, "foo.bar", 100)
assert config["foo"]["bar"] == 100
util.set_dot_to_object(config, "foo.baz.x", {"hello": "world"})
assert config["foo"]["baz"]["x"]["hello"] == "world"
assert config["test"]["a"]["b"] == "c"
util.set_dot_to_object(config, "foo", 123)
assert config["foo"] == 123
util.set_dot_to_object(config, "test", "hello")
assert dict(config) == {"foo": 123, "test": "hello"}
@pytest.mark.parametrize(
"doc_sizes, expected_batches",
[

View File

@ -1,4 +1,4 @@
from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING
from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING
from thinc.api import Config, fix_random_seed, set_gpu_allocator
from thinc.api import ConfigValidationError
from pathlib import Path
@ -33,7 +33,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
# Use original config here before it's resolved to functions
sourced_components = get_sourced_components(config)
sourced = get_sourced_components(config)
nlp = load_model_from_config(raw_config, auto_fill=True)
logger.info("Set up nlp object from config")
config = nlp.config.interpolate()
@ -57,7 +57,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
# Components that shouldn't be updated during training
frozen_components = T["frozen_components"]
# Sourced components that require resume_training
resume_components = [p for p in sourced_components if p not in frozen_components]
resume_components = [p for p in sourced if p not in frozen_components]
logger.info(f"Pipeline: {nlp.pipe_names}")
if resume_components:
with nlp.select_pipes(enable=resume_components):
@ -68,10 +68,18 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
# Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline:
if getattr(proc, "listening_components", None):
if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
for listener in proc.listening_components:
if listener in frozen_components and name not in frozen_components:
# If it's a component sourced from another pipeline, we check if
# the tok2vec listeners should be replaced with standalone tok2vec
# models (e.g. so component can be frozen without its performance
# degrading when other components/tok2vec are updated)
paths = sourced.get(listener, {}).get("replace_listeners", [])
if paths:
nlp.replace_listeners(name, listener, paths)
elif listener in frozen_components and name not in frozen_components:
logger.warning(Warnings.W087.format(name=name, listener=listener))
# We always check this regardless, in case user freezes tok2vec
if listener not in frozen_components and name in frozen_components:
logger.warning(Warnings.W086.format(name=name, listener=listener))
return nlp
@ -173,16 +181,18 @@ def init_tok2vec(
return False
def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
def get_sourced_components(
config: Union[Dict[str, Any], Config]
) -> Dict[str, Dict[str, Any]]:
"""RETURNS (List[str]): All sourced components in the original config,
e.g. {"source": "en_core_web_sm"}. If the config contains a key
"factory", we assume it refers to a component factory.
"""
return [
name
return {
name: cfg
for name, cfg in config.get("components", {}).items()
if "factory" not in cfg and "source" in cfg
]
}
def convert_vectors(

View File

@ -8,7 +8,7 @@ import re
from pathlib import Path
import thinc
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
from thinc.api import ConfigValidationError
from thinc.api import ConfigValidationError, Model
import functools
import itertools
import numpy.random
@ -738,6 +738,24 @@ def get_package_path(name: str) -> Path:
return Path(pkg.__file__).parent
def replace_model_node(model: Model, target: Model, replacement: Model) -> None:
"""Replace a node within a model with a new one, updating refs.
model (Model): The parent model.
target (Model): The target node.
replacement (Model): The node to replace the target with.
"""
# Place the node into the sublayers
for node in model.walk():
if target in node.layers:
node.layers[node.layers.index(target)] = replacement
# Now fix any node references
for node in model.walk():
for ref_name in node.ref_names:
if node.maybe_get_ref(ref_name) is target:
node.set_ref(ref_name, replacement)
def split_command(command: str) -> List[str]:
"""Split a string command using shlex. Handles platform compatibility.
@ -1279,6 +1297,25 @@ def dot_to_object(config: Config, section: str):
return component
def set_dot_to_object(config: Config, section: str, value: Any) -> None:
"""Update a config at a given position from a dot notation.
config (Config): The config.
section (str): The dot notation of the section in the config.
value (Any): The value to set in the config.
"""
component = config
parts = section.split(".")
for i, item in enumerate(parts):
try:
if i == len(parts) - 1:
component[item] = value
else:
component = component[item]
except (KeyError, TypeError):
raise KeyError(Errors.E952.format(name=section)) from None
def walk_dict(
node: Dict[str, Any], parent: List[str] = []
) -> Iterator[Tuple[List[str], Any]]: