From ece8be4feca968cd294c00e0423d8daceafa639d Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 12 May 2021 11:32:22 +0200 Subject: [PATCH 1/3] extend test to training with replaced tok2vec layer --- spacy/tests/pipeline/test_tok2vec.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py index e3b71c502..7a9e96b14 100644 --- a/spacy/tests/pipeline/test_tok2vec.py +++ b/spacy/tests/pipeline/test_tok2vec.py @@ -218,6 +218,13 @@ def test_replace_listeners(): nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"]) with pytest.raises(ValueError): nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"]) + # attempt training with the new pipeline + optimizer = nlp.initialize(lambda: examples) + for i in range(2): + losses = {} + nlp.update(examples, sgd=optimizer, losses=losses) + assert losses["tok2vec"] == 0.0 + assert losses["tagger"] > 0.0 cfg_string_multi = """ From 44a3a585992bcdf7625aacbb3984796f489cb10e Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 12 May 2021 16:01:02 +0200 Subject: [PATCH 2/3] call replace_listener attr if it's available --- spacy/language.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spacy/language.py b/spacy/language.py index 95a902380..4959716e2 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1801,7 +1801,10 @@ class Language: util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) # Go over the listener layers and replace them for listener in pipe_listeners: - util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) + new_model = tok2vec.model.copy() + if "replace_listener" in new_model.attrs: + new_model = new_model.attrs["replace_listener"](new_model) + util.replace_model_node(pipe.model, listener, new_model) tok2vec.remove_listener(listener, pipe_name) def to_disk( From 235e9f548868510611541846d931f53df9c99c95 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 12 May 2021 17:19:38 +0200 Subject: [PATCH 3/3] call replace_listener_cfg attr if it's available --- spacy/language.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 4959716e2..c30333dc9 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1764,6 +1764,7 @@ class Language: raise ValueError(err) tok2vec = self.get_pipe(tok2vec_name) tok2vec_cfg = self.get_pipe_config(tok2vec_name) + tok2vec_model = tok2vec.model if ( not hasattr(tok2vec, "model") or not hasattr(tok2vec, "listener_map") @@ -1772,6 +1773,7 @@ class Language: ): raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec))) pipe_listeners = tok2vec.listener_map.get(pipe_name, []) + pipe = self.get_pipe(pipe_name) pipe_cfg = self._pipe_configs[pipe_name] if listeners: util.logger.debug(f"Replacing listeners of component '{pipe_name}'") @@ -1786,7 +1788,6 @@ class Language: n_listeners=len(pipe_listeners), ) raise ValueError(err) - pipe = self.get_pipe(pipe_name) # Update the config accordingly by copying the tok2vec model to all # sections defined in the listener paths for listener_path in listeners: @@ -1798,12 +1799,16 @@ class Language: name=pipe_name, tok2vec=tok2vec_name, path=listener_path ) raise ValueError(err) - util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) + new_config = tok2vec_cfg["model"] + if "replace_listener_cfg" in tok2vec_model.attrs: + replace_func = tok2vec_model.attrs["replace_listener_cfg"] + new_config = replace_func(tok2vec_cfg["model"], pipe_cfg["model"]["tok2vec"]) + util.set_dot_to_object(pipe_cfg, listener_path, new_config) # Go over the listener layers and replace them for listener in pipe_listeners: - new_model = tok2vec.model.copy() - if "replace_listener" in new_model.attrs: - new_model = new_model.attrs["replace_listener"](new_model) + new_model = tok2vec_model.copy() + if "replace_listener" in tok2vec_model.attrs: + new_model = tok2vec_model.attrs["replace_listener"](new_model) util.replace_model_node(pipe.model, listener, new_model) tok2vec.remove_listener(listener, pipe_name)