diff --git a/spacy/errors.py b/spacy/errors.py index a50e986ac..4d66fd0ef 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -475,7 +475,18 @@ class Errors: "issue tracker: http://github.com/explosion/spaCy/issues") # TODO: fix numbering after merging develop into master - E890 = ("Can not add the alias '{alias}' to the Knowledge base. " + E886 = ("Can't replace {name} -> {tok2vec} listeners: path '{path}' not " + "found in config for component '{name}'.") + E887 = ("Can't replace {name} -> {tok2vec} listeners: the paths to replace " + "({paths}) don't match the available listeners in the model ({n_listeners}).") + E888 = ("Can't replace listeners for '{name}' ({pipe}): invalid upstream " + "component that doesn't seem to support listeners. Expected Tok2Vec " + "or Transformer component. If you didn't call nlp.replace_listeners " + "manually, this is likely a bug in spaCy.") + E889 = ("Can't replace listeners of component '{name}' because it's not " + "in the pipeline. Available components: {opts}. If you didn't call " + "nlp.replace_listeners manually, this is likely a bug in spaCy.") + E890 = ("Cannot add the alias '{alias}' to the Knowledge base. " "Each alias should be a meaningful string.") E891 = ("Alias '{alias}' could not be added to the Knowledge base. " "This is likely a bug in spaCy.") diff --git a/spacy/language.py b/spacy/language.py index cc079af62..f0d311e5d 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -676,11 +676,36 @@ class Language: tok2vec_name: str, pipe_name: str, listeners: Iterable[str] = SimpleFrozenList(), - ): + ) -> None: + """Find listener layers (connecting to a token-to-vector embedding + component) of a given pipeline component model and replace + them with a standalone copy of the token-to-vector layer. This can be + useful when training a pipeline with components sourced from an existing + pipeline: if multiple components (e.g. 
tagger, parser, NER) listen to + the same tok2vec component, but some of them are frozen and not updated, + their performance may degrade significantly as the tok2vec component is + updated with new data. To prevent this, listeners can be replaced with + a standalone tok2vec layer that is owned by the component and doesn't + change if the component isn't updated. + + tok2vec_name (str): Name of the token-to-vector component, typically + "tok2vec" or "transformer". + pipe_name (str): Name of pipeline component to replace listeners for. + listeners (Iterable[str]): The paths to the listeners, relative to the + component config, e.g. ["model.tok2vec"]. Typically, implementations + will only connect to one tok2vec component, [model.tok2vec], but in + theory, custom models can use multiple listeners. The value here can + either be an empty list to not replace any listeners, or a complete + (!) list of the paths to all listener layers used by the model. + + DOCS: https://nightly.spacy.io/api/language#replace_listeners + """ if tok2vec_name not in self.pipe_names: - raise ValueError # TODO: + err = Errors.E889.format(name=tok2vec_name, opts=", ".join(self.pipe_names)) + raise ValueError(err) if pipe_name not in self.pipe_names: - raise ValueError # TODO: + err = Errors.E889.format(name=pipe_name, opts=", ".join(self.pipe_names)) + raise ValueError(err) tok2vec = self.get_pipe(tok2vec_name) tok2vec_cfg = self.get_pipe_config(tok2vec_name) if ( @@ -688,7 +713,7 @@ class Language: or not hasattr(tok2vec, "listener_map") or "model" not in tok2vec_cfg ): - raise ValueError # TODO: likely bug in spaCy if this happens + raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec))) pipe_listeners = tok2vec.listener_map.get(pipe_name, []) pipe_cfg = self._pipe_configs[pipe_name] if listeners: @@ -697,7 +722,13 @@ class Language: # The number of listeners defined in the component model doesn't # match the listeners to replace, so we won't be able to update # the nodes 
and generate a matching config - raise ValueError(f"{listeners}, {pipe_listeners}") # TODO: + err = Errors.E887.format( + name=pipe_name, + tok2vec=tok2vec_name, + paths=listeners, + n_listeners=len(pipe_listeners), + ) + raise ValueError(err) pipe = self.get_pipe(pipe_name) # Go over the listener layers and replace them for listener in pipe_listeners: @@ -709,7 +740,10 @@ class Language: try: util.dot_to_object(pipe_cfg, listener_path) except KeyError: - raise ValueError # TODO: + err = Errors.E886.format( + name=pipe_name, tok2vec=tok2vec_name, path=listener_path + ) + raise ValueError(err) util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) def create_pipe_from_source(