Exclude strings from v3.2+ source vector checks (#9697)

Exclude strings from `Vector.to_bytes()` comparions for v3.2+ `Vectors`
that now include the string store so that the source vector comparison
is only comparing the vectors and not the strings.
This commit is contained in:
Adriane Boyd 2021-11-19 08:51:19 +01:00 committed by GitHub
parent f3981bd0c8
commit ea450d652c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 3 deletions

View File

@ -701,7 +701,8 @@ class Language:
if ( if (
self.vocab.vectors.shape != source.vocab.vectors.shape self.vocab.vectors.shape != source.vocab.vectors.shape
or self.vocab.vectors.key2row != source.vocab.vectors.key2row or self.vocab.vectors.key2row != source.vocab.vectors.key2row
or self.vocab.vectors.to_bytes() != source.vocab.vectors.to_bytes() or self.vocab.vectors.to_bytes(exclude=["strings"])
!= source.vocab.vectors.to_bytes(exclude=["strings"])
): ):
warnings.warn(Warnings.W113.format(name=source_name)) warnings.warn(Warnings.W113.format(name=source_name))
if source_name not in source.component_names: if source_name not in source.component_names:
@ -1822,7 +1823,9 @@ class Language:
) )
if model not in source_nlp_vectors_hashes: if model not in source_nlp_vectors_hashes:
source_nlp_vectors_hashes[model] = hash( source_nlp_vectors_hashes[model] = hash(
source_nlps[model].vocab.vectors.to_bytes() source_nlps[model].vocab.vectors.to_bytes(
exclude=["strings"]
)
) )
if "_sourced_vectors_hashes" not in nlp.meta: if "_sourced_vectors_hashes" not in nlp.meta:
nlp.meta["_sourced_vectors_hashes"] = {} nlp.meta["_sourced_vectors_hashes"] = {}

View File

@ -132,7 +132,7 @@ def init_vocab(
logger.info(f"Added vectors: {vectors}") logger.info(f"Added vectors: {vectors}")
# warn if source model vectors are not identical # warn if source model vectors are not identical
sourced_vectors_hashes = nlp.meta.pop("_sourced_vectors_hashes", {}) sourced_vectors_hashes = nlp.meta.pop("_sourced_vectors_hashes", {})
vectors_hash = hash(nlp.vocab.vectors.to_bytes()) vectors_hash = hash(nlp.vocab.vectors.to_bytes(exclude=["strings"]))
for sourced_component, sourced_vectors_hash in sourced_vectors_hashes.items(): for sourced_component, sourced_vectors_hash in sourced_vectors_hashes.items():
if vectors_hash != sourced_vectors_hash: if vectors_hash != sourced_vectors_hash:
warnings.warn(Warnings.W113.format(name=sourced_component)) warnings.warn(Warnings.W113.format(name=sourced_component))