Don't use the same vocab for source models (#8388)

* Don't use the same vocab for source models

The source models should not be loaded with the vocab from the current
pipeline because this loads the vectors from the source model into the
current vocab.

The strings are all copied in `Language.create_pipe_from_source`, so if
the vectors are configured correctly in the current pipeline, the
sourced component will work as expected. If there is a vector mismatch,
a warning is shown. (It's not possible to inspect whether the vectors
are actually used by the component, so a warning is the best option.)

* Update comment on source model loading
This commit is contained in:
Adriane Boyd 2021-06-21 09:33:33 +02:00 committed by GitHub
parent 02d2fdb123
commit 7abfa25035
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 3 deletions

View File

@ -1696,9 +1696,12 @@ class Language:
else: else:
model = pipe_cfg["source"] model = pipe_cfg["source"]
if model not in source_nlps: if model not in source_nlps:
# We only need the components here and we need to init # We only need the components here and we intentionally
# model with the same vocab as the current nlp object # do not load the model with the same vocab because
source_nlps[model] = util.load_model(model, vocab=nlp.vocab) # this would cause the vectors to be copied into the
# current nlp object (all the strings will be added in
# create_pipe_from_source)
source_nlps[model] = util.load_model(model)
source_name = pipe_cfg.get("component", pipe_name) source_name = pipe_cfg.get("component", pipe_name)
listeners_replaced = False listeners_replaced = False
if "replace_listeners" in pipe_cfg: if "replace_listeners" in pipe_cfg:

View File

@ -475,3 +475,26 @@ def test_language_init_invalid_vocab(value):
with pytest.raises(ValueError) as e: with pytest.raises(ValueError) as e:
Language(value) Language(value)
assert err_fragment in str(e.value) assert err_fragment in str(e.value)
def test_language_source_and_vectors(nlp2):
    """Check that sourcing a component copies strings but not vectors.

    A fresh pipeline with a ``textcat`` component is built and then used as
    the ``source`` when adding a component to the ``nlp2`` fixture. The test
    expects a warning on the ``"spacy"`` logger, expects a string known only
    to the source pipeline to appear in ``nlp2``'s string store, and expects
    the source pipeline's serialized vectors to be unchanged afterwards.
    """
    source_nlp = Language(Vocab())
    source_textcat = source_nlp.add_pipe("textcat")
    source_textcat.add_label("POSITIVE")
    source_textcat.add_label("NEGATIVE")
    source_nlp.initialize()

    # A string that neither pipeline knows yet; it is registered only in the
    # source pipeline so that its later presence in nlp2 proves it was copied.
    marker = "thisisalongstring"
    for pipeline in (source_nlp, nlp2):
        assert marker not in pipeline.vocab.strings
    source_nlp.vocab.strings.add(marker)

    # Snapshot the source vectors; they must differ from nlp2's vectors for
    # the sourcing step below to run with mismatched vectors.
    vectors_before = source_nlp.vocab.vectors.to_bytes()
    assert vectors_before != nlp2.vocab.vectors.to_bytes()

    # TODO: convert to pytest.warns for v3.1
    spacy_logger = logging.getLogger("spacy")
    with mock.patch.object(spacy_logger, "warning") as warning_spy:
        nlp2.add_pipe("textcat", name="textcat2", source=source_nlp)
        warning_spy.assert_called()

    # The source pipeline's strings must now be known to the target pipeline.
    assert marker in nlp2.vocab.strings
    # The source pipeline's vectors must be byte-for-byte what they were.
    # NOTE(review): only the source vectors are compared here; nlp2's own
    # vectors are not checked - confirm whether that is intended.
    assert source_nlp.vocab.vectors.to_bytes() == vectors_before