diff --git a/spacy/language.py b/spacy/language.py index 5a2a0cd65..12b319fd3 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -671,81 +671,6 @@ class Language: self._pipe_configs[name] = filled return resolved[factory_name] - def replace_listeners( - self, - tok2vec_name: str, - pipe_name: str, - listeners: Iterable[str] = SimpleFrozenList(), - ) -> None: - """Find listener layers (connecting to a token-to-vector embedding - component) of a given pipeline component model and replace - them with a standalone copy of the token-to-vector layer. This can be - useful when training a pipeline with components sourced from an existing - pipeline: if multiple components (e.g. tagger, parser, NER) listen to - the same tok2vec component, but some of them are frozen and not updated, - their performance may degrade significally as the tok2vec component is - updated with new data. To prevent this, listeners can be replaced with - a standalone tok2vec layer that is owned by the component and doesn't - change if the component isn't updated. - - tok2vec_name (str): Name of the token-to-vector component, typically - "tok2vec" or "transformer". - pipe_name (str): Name of pipeline component to replace listeners for. - listeners (Iterable[str]): The paths to the listeners, relative to the - component config, e.g. ["model.tok2vec"]. Typically, implementations - will only connect to one tok2vec component, [model.tok2vec], but in - theory, custom models can use multiple listeners. The value here can - either be an empty list to not replace any listeners, or a complete - (!) list of the paths to all listener layers used by the model. - - DOCS: https://nightly.spacy.io/api/language#replace_listeners - """ - if tok2vec_name not in self.pipe_names: - err = Errors.E889.format(name=tok2vec_name, opts=", ".join(self.pipe_names)) - raise ValueError(err) - if pipe_name not in self.pipe_names: - err = Errors.E889.format(name=pipe_name, opts=", ".join(self.pipe_names)) - raise ValueError(err) - tok2vec = self.get_pipe(tok2vec_name) - tok2vec_cfg = self.get_pipe_config(tok2vec_name) - if ( - not hasattr(tok2vec, "model") - or not hasattr(tok2vec, "listener_map") - or "model" not in tok2vec_cfg - ): - raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec))) - pipe_listeners = tok2vec.listener_map.get(pipe_name, []) - pipe_cfg = self._pipe_configs[pipe_name] - if listeners: - util.logger.debug(f"Replacing listeners of component '{pipe_name}'") - if len(listeners) != len(pipe_listeners): - # The number of listeners defined in the component model doesn't - # match the listeners to replace, so we won't be able to update - # the nodes and generate a matching config - err = Errors.E887.format( - name=pipe_name, - tok2vec=tok2vec_name, - paths=listeners, - n_listeners=len(pipe_listeners), - ) - raise ValueError(err) - pipe = self.get_pipe(pipe_name) - # Update the config accordingly by coping the tok2vec model to all - # sections defined in the listener paths - for listener_path in listeners: - # Check if the path actually exists in the config - try: - util.dot_to_object(pipe_cfg, listener_path) - except KeyError: - err = Errors.E886.format( - name=pipe_name, tok2vec=tok2vec_name, path=listener_path - ) - raise ValueError(err) - util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) - # Go over the listener layers and replace them - for listener in pipe_listeners: - util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) - def create_pipe_from_source( self, source_name: str, source: "Language", *, name: str ) -> Tuple[Callable[[Doc], Doc], str]: @@ -1748,6 +1673,81 @@ class Language: ) return nlp + def replace_listeners( + self, + tok2vec_name: str, + pipe_name: str, + listeners: Iterable[str] = SimpleFrozenList(), + ) -> None: + """Find listener layers (connecting to a token-to-vector embedding + component) of a given pipeline component model and replace + them with a standalone copy of the token-to-vector layer. This can be + useful when training a pipeline with components sourced from an existing + pipeline: if multiple components (e.g. tagger, parser, NER) listen to + the same tok2vec component, but some of them are frozen and not updated, + their performance may degrade significally as the tok2vec component is + updated with new data. To prevent this, listeners can be replaced with + a standalone tok2vec layer that is owned by the component and doesn't + change if the component isn't updated. + + tok2vec_name (str): Name of the token-to-vector component, typically + "tok2vec" or "transformer". + pipe_name (str): Name of pipeline component to replace listeners for. + listeners (Iterable[str]): The paths to the listeners, relative to the + component config, e.g. ["model.tok2vec"]. Typically, implementations + will only connect to one tok2vec component, [model.tok2vec], but in + theory, custom models can use multiple listeners. The value here can + either be an empty list to not replace any listeners, or a complete + (!) list of the paths to all listener layers used by the model. + + DOCS: https://nightly.spacy.io/api/language#replace_listeners + """ + if tok2vec_name not in self.pipe_names: + err = Errors.E889.format(name=tok2vec_name, opts=", ".join(self.pipe_names)) + raise ValueError(err) + if pipe_name not in self.pipe_names: + err = Errors.E889.format(name=pipe_name, opts=", ".join(self.pipe_names)) + raise ValueError(err) + tok2vec = self.get_pipe(tok2vec_name) + tok2vec_cfg = self.get_pipe_config(tok2vec_name) + if ( + not hasattr(tok2vec, "model") + or not hasattr(tok2vec, "listener_map") + or "model" not in tok2vec_cfg + ): + raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec))) + pipe_listeners = tok2vec.listener_map.get(pipe_name, []) + pipe_cfg = self._pipe_configs[pipe_name] + if listeners: + util.logger.debug(f"Replacing listeners of component '{pipe_name}'") + if len(listeners) != len(pipe_listeners): + # The number of listeners defined in the component model doesn't + # match the listeners to replace, so we won't be able to update + # the nodes and generate a matching config + err = Errors.E887.format( + name=pipe_name, + tok2vec=tok2vec_name, + paths=listeners, + n_listeners=len(pipe_listeners), + ) + raise ValueError(err) + pipe = self.get_pipe(pipe_name) + # Update the config accordingly by coping the tok2vec model to all + # sections defined in the listener paths + for listener_path in listeners: + # Check if the path actually exists in the config + try: + util.dot_to_object(pipe_cfg, listener_path) + except KeyError: + err = Errors.E886.format( + name=pipe_name, tok2vec=tok2vec_name, path=listener_path + ) + raise ValueError(err) + util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) + # Go over the listener layers and replace them + for listener in pipe_listeners: + util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) + def to_disk( self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() ) -> None: