Fix vectors check for sourced components (#8559)

* Fix vectors check for sourced components

Since vectors are not loaded when components are sourced, store a hash
for the vectors of each sourced component and compare it to the loaded
vectors after the vectors are loaded from the `[initialize]` block.

* Pop temporary info

* Remove stored hash in remove_pipe

* Add default for pop

* Add additional convert/debug/assemble CLI tests
This commit is contained in:
Adriane Boyd 2021-07-06 12:43:17 +02:00 committed by GitHub
parent 29906884c5
commit 5fd0b5207e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 1 deletions

View File

@ -62,6 +62,30 @@ steps:
- script: |
python -m spacy download ca_core_news_sm
python -m spacy download ca_core_news_md
python -c "import spacy; nlp=spacy.load('ca_core_news_sm'); doc=nlp('test')"
displayName: 'Test download CLI'
condition: eq(variables['python_version'], '3.8')
- script: |
python -m spacy convert extra/example_data/ner_example_data/ner-token-per-line-conll2003.json .
displayName: 'Test convert CLI'
condition: eq(variables['python_version'], '3.8')
- script: |
python -m spacy init config -p ner -l ca ner.cfg
python -m spacy debug config ner.cfg --paths.train ner-token-per-line-conll2003.spacy --paths.dev ner-token-per-line-conll2003.spacy
displayName: 'Test debug config CLI'
condition: eq(variables['python_version'], '3.8')
- script: |
python -c "import spacy; config = spacy.util.load_config('ner.cfg'); config['components']['ner'] = {'source': 'ca_core_news_sm'}; config.to_disk('ner_source_sm.cfg')"
PYTHONWARNINGS="error,ignore::DeprecationWarning" python -m spacy assemble ner_source_sm.cfg output_dir
displayName: 'Test assemble CLI'
condition: eq(variables['python_version'], '3.8')
- script: |
python -c "import spacy; config = spacy.util.load_config('ner.cfg'); config['components']['ner'] = {'source': 'ca_core_news_md'}; config.to_disk('ner_source_md.cfg')"
python -m spacy assemble ner_source_md.cfg output_dir 2>&1 | grep -q W113
displayName: 'Test assemble CLI vectors warning'
condition: eq(variables['python_version'], '3.8')

View File

@ -934,6 +934,7 @@ class Language:
# because factory may be used for something else
self._pipe_meta.pop(name)
self._pipe_configs.pop(name)
self.meta.get("_sourced_vectors_hashes", {}).pop(name, None)
# Make sure name is removed from the [initialize] config
if name in self._config["initialize"]["components"]:
self._config["initialize"]["components"].pop(name)
@ -1680,6 +1681,8 @@ class Language:
# If components are loaded from a source (existing models), we cache
# them here so they're only loaded once
source_nlps = {}
source_nlp_vectors_hashes = {}
nlp.meta["_sourced_vectors_hashes"] = {}
for pipe_name in config["nlp"]["pipeline"]:
if pipe_name not in pipeline:
opts = ", ".join(pipeline.keys())
@ -1719,7 +1722,12 @@ class Language:
name, source_name, pipe_cfg["replace_listeners"]
)
listeners_replaced = True
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="\\[W113\\]")
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
if model not in source_nlp_vectors_hashes:
source_nlp_vectors_hashes[model] = hash(source_nlps[model].vocab.vectors.to_bytes())
nlp.meta["_sourced_vectors_hashes"][pipe_name] = source_nlp_vectors_hashes[model]
# Delete from cache if listeners were replaced
if listeners_replaced:
del source_nlps[model]

View File

@ -9,6 +9,7 @@ import gzip
import zipfile
import tqdm
from itertools import islice
import warnings
from .pretrain import get_tok2vec_ref
from ..lookups import Lookups
@ -124,6 +125,12 @@ def init_vocab(
if vectors is not None:
load_vectors_into_model(nlp, vectors)
logger.info(f"Added vectors: {vectors}")
# warn if source model vectors are not identical
sourced_vectors_hashes = nlp.meta.pop("_sourced_vectors_hashes", {})
vectors_hash = hash(nlp.vocab.vectors.to_bytes())
for sourced_component, sourced_vectors_hash in sourced_vectors_hashes.items():
if vectors_hash != sourced_vectors_hash:
warnings.warn(Warnings.W113.format(name=sourced_component))
logger.info("Finished initializing nlp object")