mirror of https://github.com/explosion/spaCy.git

Add option to replace listeners for sourced components

parent 78d6ff4dd4
commit 911dfcccfc
spacy/language.py

@@ -8,7 +8,7 @@ from contextlib import contextmanager
 from copy import deepcopy
 from pathlib import Path
 import warnings
-from thinc.api import Model, get_current_ops, Config, Optimizer
+from thinc.api import get_current_ops, Config, Optimizer
 import srsly
 import multiprocessing as mp
 from itertools import chain, cycle
@@ -670,6 +670,47 @@ class Language:
         self._pipe_configs[name] = filled
         return resolved[factory_name]
 
+    def replace_listeners(
+        self,
+        tok2vec_name: str,
+        pipe_name: str,
+        listeners: Iterable[str] = SimpleFrozenList(),
+    ):
+        if tok2vec_name not in self.pipe_names:
+            raise ValueError  # TODO:
+        if pipe_name not in self.pipe_names:
+            raise ValueError  # TODO:
+        tok2vec = self.get_pipe(tok2vec_name)
+        tok2vec_cfg = self.get_pipe_config(tok2vec_name)
+        if (
+            not hasattr(tok2vec, "model")
+            or not hasattr(tok2vec, "listener_map")
+            or "model" not in tok2vec_cfg
+        ):
+            raise ValueError  # TODO: likely bug in spaCy if this happens
+        pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
+        pipe_cfg = self._pipe_configs[pipe_name]
+        if listeners:
+            util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
+            if len(listeners) != len(pipe_listeners):
+                # The number of listeners defined in the component model doesn't
+                # match the listeners to replace, so we won't be able to update
+                # the nodes and generate a matching config
+                raise ValueError(f"{listeners}, {pipe_listeners}")  # TODO:
+            pipe = self.get_pipe(pipe_name)
+            # Go over the listener layers and replace them
+            for listener in pipe_listeners:
+                util.replace_model_node(pipe.model, listener, tok2vec.model.copy())
+            # Update the config accordingly by copying the tok2vec model to all
+            # sections defined in the listener paths
+            for listener_path in listeners:
+                # Check if the path actually exists in the config
+                try:
+                    util.dot_to_object(pipe_cfg, listener_path)
+                except KeyError:
+                    raise ValueError  # TODO:
+                util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"])
+
     def create_pipe_from_source(
         self, source_name: str, source: "Language", *, name: str
     ) -> Tuple[Callable[[Doc], Doc], str]:
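For context, a minimal sketch of how the new method is meant to be called (the pipeline package and component names here are illustrative, not part of the diff):

# Assumes a trained pipeline such as en_core_web_sm, with a shared "tok2vec"
# component and a "tagger" whose model holds one Tok2VecListener at the
# config path "model.tok2vec".
import spacy

nlp = spacy.load("en_core_web_sm")
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
# The tagger now owns a standalone copy of the tok2vec model, and its config
# embeds the full tok2vec settings instead of a listener.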
spacy/tests/pipeline/test_tok2vec.py

@@ -1,5 +1,4 @@
 import pytest
-
 from spacy.ml.models.tok2vec import build_Tok2Vec_model
 from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
 from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
@@ -9,12 +8,11 @@ from spacy.tokens import Doc
 from spacy.training import Example
 from spacy import util
 from spacy.lang.en import English
+from ..util import get_batch
 from thinc.api import Config
 
 from numpy.testing import assert_equal
 
-from ..util import get_batch
-
 
 def test_empty_doc():
     width = 128
@@ -187,3 +185,29 @@ def test_tok2vec_listener_callback():
     Y, get_dX = tagger.model.begin_update(docs)
     # assure that the backprop call works (and doesn't hit a 'None' callback)
     assert get_dX(Y) is not None
+
+
+def test_replace_listeners():
+    orig_config = Config().from_str(cfg_string)
+    nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
+    examples = [Example.from_dict(nlp.make_doc("x y"), {"tags": ["V", "Z"]})]
+    nlp.initialize(lambda: examples)
+    tok2vec = nlp.get_pipe("tok2vec")
+    tagger = nlp.get_pipe("tagger")
+    assert isinstance(tagger.model.layers[0], Tok2VecListener)
+    assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0]
+    assert nlp.config["components"]["tok2vec"]["model"]["@architectures"] == "spacy.Tok2Vec.v2"
+    assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"] == "spacy.Tok2VecListener.v1"
+    nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
+    assert not isinstance(tagger.model.layers[0], Tok2VecListener)
+    t2v_cfg = nlp.config["components"]["tok2vec"]["model"]
+    assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2"
+    assert nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg
+    with pytest.raises(ValueError):
+        nlp.replace_listeners("invalid", "tagger", ["model.tok2vec"])
+    with pytest.raises(ValueError):
+        nlp.replace_listeners("tok2vec", "parser", ["model.tok2vec"])
+    with pytest.raises(ValueError):
+        nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"])
+    with pytest.raises(ValueError):
+        nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"])
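The cfg_string used above is defined earlier in the test module and is not shown in this diff. A sketch of the kind of config it contains, wiring a tagger to a shared tok2vec via a listener; the architecture names match the test's assertions, while the embed/encode settings are assumptions:

cfg_string = """
[nlp]
lang = "en"
pipeline = ["tok2vec","tagger"]

[components.tok2vec]
factory = "tok2vec"

[components.tok2vec.model]
@architectures = "spacy.Tok2Vec.v2"

[components.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = ${components.tok2vec.model.encode.width}
attrs = ["NORM"]
rows = [5000]
include_static_vectors = false

[components.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = 96
depth = 2
window_size = 1
maxout_pieces = 3

[components.tagger]
factory = "tagger"

[components.tagger.model]
@architectures = "spacy.Tagger.v1"
nO = null

[components.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
"""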
spacy/tests/test_misc.py

@@ -205,6 +205,25 @@ def test_dot_to_dict(dot_notation, expected):
     assert util.dict_to_dot(result) == dot_notation
 
 
+def test_set_dot_to_object():
+    config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}}
+    with pytest.raises(KeyError):
+        util.set_dot_to_object(config, "foo.bar.baz", 100)
+    with pytest.raises(KeyError):
+        util.set_dot_to_object(config, "hello.world", 100)
+    with pytest.raises(KeyError):
+        util.set_dot_to_object(config, "test.a.b.c", 100)
+    util.set_dot_to_object(config, "foo.bar", 100)
+    assert config["foo"]["bar"] == 100
+    util.set_dot_to_object(config, "foo.baz.x", {"hello": "world"})
+    assert config["foo"]["baz"]["x"]["hello"] == "world"
+    assert config["test"]["a"]["b"] == "c"
+    util.set_dot_to_object(config, "foo", 123)
+    assert config["foo"] == 123
+    util.set_dot_to_object(config, "test", "hello")
+    assert dict(config) == {"foo": 123, "test": "hello"}
+
+
 @pytest.mark.parametrize(
     "doc_sizes, expected_batches",
     [
spacy/training/initialize.py

@@ -1,4 +1,4 @@
-from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING
+from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING
 from thinc.api import Config, fix_random_seed, set_gpu_allocator
 from thinc.api import ConfigValidationError
 from pathlib import Path
@@ -33,7 +33,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
     if use_gpu >= 0 and allocator:
         set_gpu_allocator(allocator)
     # Use original config here before it's resolved to functions
-    sourced_components = get_sourced_components(config)
+    sourced = get_sourced_components(config)
     nlp = load_model_from_config(raw_config, auto_fill=True)
     logger.info("Set up nlp object from config")
     config = nlp.config.interpolate()
@@ -57,7 +57,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
     # Components that shouldn't be updated during training
     frozen_components = T["frozen_components"]
     # Sourced components that require resume_training
-    resume_components = [p for p in sourced_components if p not in frozen_components]
+    resume_components = [p for p in sourced if p not in frozen_components]
     logger.info(f"Pipeline: {nlp.pipe_names}")
     if resume_components:
         with nlp.select_pipes(enable=resume_components):
@@ -68,10 +68,18 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
     logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
+    # Detect components with listeners that are not frozen consistently
     for name, proc in nlp.pipeline:
-        if getattr(proc, "listening_components", None):
+        if getattr(proc, "listening_components", None):  # e.g. tok2vec/transformer
             for listener in proc.listening_components:
-                if listener in frozen_components and name not in frozen_components:
+                # If it's a component sourced from another pipeline, we check if
+                # the tok2vec listeners should be replaced with standalone tok2vec
+                # models (e.g. so component can be frozen without its performance
+                # degrading when other components/tok2vec are updated)
+                paths = sourced.get(listener, {}).get("replace_listeners", [])
+                if paths:
+                    nlp.replace_listeners(name, listener, paths)
+                elif listener in frozen_components and name not in frozen_components:
                     logger.warning(Warnings.W087.format(name=name, listener=listener))
                 # We always check this regardless, in case user freezes tok2vec
                 if listener not in frozen_components and name in frozen_components:
                     logger.warning(Warnings.W086.format(name=name, listener=listener))
     return nlp
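The per-component "replace_listeners" setting read above comes from the training config of the sourced component. A sketch of the intended usage (package name and listener path assumed): a component sourced from another pipeline asks for its listeners to be replaced with a standalone copy of the tok2vec model, so it can be frozen without degrading as other components train.

config_excerpt = """
[components.tagger]
source = "en_core_web_sm"
replace_listeners = ["model.tok2vec"]

[training]
frozen_components = ["tagger"]
"""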
@@ -173,16 +181,18 @@ def init_tok2vec(
     return False
 
 
-def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
+def get_sourced_components(
+    config: Union[Dict[str, Any], Config]
+) -> Dict[str, Dict[str, Any]]:
     """RETURNS (List[str]): All sourced components in the original config,
     e.g. {"source": "en_core_web_sm"}. If the config contains a key
     "factory", we assume it refers to a component factory.
     """
-    return [
-        name
+    return {
+        name: cfg
         for name, cfg in config.get("components", {}).items()
         if "factory" not in cfg and "source" in cfg
-    ]
+    }
 
 
 def convert_vectors(
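A quick sketch of the new return shape (component names assumed), which is what lets init_nlp look up per-component settings such as "replace_listeners":

from spacy.training.initialize import get_sourced_components

config = {
    "components": {
        "tok2vec": {"factory": "tok2vec"},
        "tagger": {"source": "en_core_web_sm", "replace_listeners": ["model.tok2vec"]},
    }
}
# Previously this returned ["tagger"]; now the full config of each sourced
# component is preserved:
assert get_sourced_components(config) == {
    "tagger": {"source": "en_core_web_sm", "replace_listeners": ["model.tok2vec"]}
}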
spacy/util.py

@@ -8,7 +8,7 @@ import re
 from pathlib import Path
 import thinc
 from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
-from thinc.api import ConfigValidationError
+from thinc.api import ConfigValidationError, Model
 import functools
 import itertools
 import numpy.random
@@ -738,6 +738,24 @@ def get_package_path(name: str) -> Path:
     return Path(pkg.__file__).parent
 
 
+def replace_model_node(model: Model, target: Model, replacement: Model) -> None:
+    """Replace a node within a model with a new one, updating refs.
+
+    model (Model): The parent model.
+    target (Model): The target node.
+    replacement (Model): The node to replace the target with.
+    """
+    # Place the node into the sublayers
+    for node in model.walk():
+        if target in node.layers:
+            node.layers[node.layers.index(target)] = replacement
+    # Now fix any node references
+    for node in model.walk():
+        for ref_name in node.ref_names:
+            if node.maybe_get_ref(ref_name) is target:
+                node.set_ref(ref_name, replacement)
+
+
 def split_command(command: str) -> List[str]:
     """Split a string command using shlex. Handles platform compatibility.
 
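A minimal sketch of what replace_model_node does on a plain Thinc model (the layer choices here are illustrative, not from the diff):

from thinc.api import Relu, Softmax, chain
from spacy.util import replace_model_node

model = chain(Relu(nO=8), Softmax())
old_node = model.layers[0]
new_node = Relu(nO=8)
replace_model_node(model, old_node, new_node)
# The node is swapped in place, and any refs that pointed at the old node
# now resolve to the replacement.
assert model.layers[0] is new_node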
@@ -1279,6 +1297,25 @@ def dot_to_object(config: Config, section: str):
     return component
 
 
+def set_dot_to_object(config: Config, section: str, value: Any) -> None:
+    """Update a config at a given position from a dot notation.
+
+    config (Config): The config.
+    section (str): The dot notation of the section in the config.
+    value (Any): The value to set in the config.
+    """
+    component = config
+    parts = section.split(".")
+    for i, item in enumerate(parts):
+        try:
+            if i == len(parts) - 1:
+                component[item] = value
+            else:
+                component = component[item]
+        except (KeyError, TypeError):
+            raise KeyError(Errors.E952.format(name=section)) from None
+
+
 def walk_dict(
     node: Dict[str, Any], parent: List[str] = []
 ) -> Iterator[Tuple[List[str], Any]]:
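And a short usage sketch for set_dot_to_object (keys and values assumed):

from thinc.api import Config
from spacy.util import dot_to_object, set_dot_to_object

config = Config({"components": {"tagger": {"model": {"tok2vec": {}}}}})
set_dot_to_object(config, "components.tagger.model.tok2vec", {"width": 96})
assert dot_to_object(config, "components.tagger.model.tok2vec") == {"width": 96}
# Missing intermediate sections raise a KeyError rather than being created:
# set_dot_to_object(config, "components.parser.model", {})  -> KeyError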