From 7abfa250353f7fb2da1e41026993fa79393607f3 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Mon, 21 Jun 2021 09:33:33 +0200 Subject: [PATCH] Don't use the same vocab for source models (#8388) * Don't use the same vocab for source models The source models should not be loaded with the vocab from the current pipeline because this loads the vectors from the source model into the current vocab. The strings are all copied in `Language.create_pipe_from_source`, so if the vectors are configured correctly in the current pipeline, the sourced component will work as expected. If there is a vector mismatch, a warning is shown. (It's not possible to inspect whether the vectors are actually used by the component, so a warning is the best option.) * Update comment on source model loading --- spacy/language.py | 9 ++++++--- spacy/tests/test_language.py | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 643a0c0f8..c35a8c016 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1696,9 +1696,12 @@ class Language: else: model = pipe_cfg["source"] if model not in source_nlps: - # We only need the components here and we need to init - # model with the same vocab as the current nlp object - source_nlps[model] = util.load_model(model, vocab=nlp.vocab) + # We only need the components here and we intentionally + # do not load the model with the same vocab because + # this would cause the vectors to be copied into the + # current nlp object (all the strings will be added in + # create_pipe_from_source) + source_nlps[model] = util.load_model(model) source_name = pipe_cfg.get("component", pipe_name) listeners_replaced = False if "replace_listeners" in pipe_cfg: diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py index 86cce5f9e..137aea04f 100644 --- a/spacy/tests/test_language.py +++ b/spacy/tests/test_language.py @@ -475,3 +475,26 @@ def test_language_init_invalid_vocab(value): with pytest.raises(ValueError) as e: Language(value) assert err_fragment in str(e.value) + + +def test_language_source_and_vectors(nlp2): + nlp = Language(Vocab()) + textcat = nlp.add_pipe("textcat") + for label in ("POSITIVE", "NEGATIVE"): + textcat.add_label(label) + nlp.initialize() + long_string = "thisisalongstring" + assert long_string not in nlp.vocab.strings + assert long_string not in nlp2.vocab.strings + nlp.vocab.strings.add(long_string) + assert nlp.vocab.vectors.to_bytes() != nlp2.vocab.vectors.to_bytes() + vectors_bytes = nlp.vocab.vectors.to_bytes() + # TODO: convert to pytest.warns for v3.1 + logger = logging.getLogger("spacy") + with mock.patch.object(logger, "warning") as mock_warning: + nlp2.add_pipe("textcat", name="textcat2", source=nlp) + mock_warning.assert_called() + # strings should be added + assert long_string in nlp2.vocab.strings + # vectors should remain unmodified + assert nlp.vocab.vectors.to_bytes() == vectors_bytes