Add specific error when StaticVectors can't read the vectors data (#6450)

This commit is contained in:
Sofie Van Landeghem 2020-12-08 23:16:07 +01:00 committed by GitHub
parent 8921364579
commit de108ed3e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 3 deletions

View File

@ -457,6 +457,9 @@ class Errors:
"issue tracker: http://github.com/explosion/spaCy/issues") "issue tracker: http://github.com/explosion/spaCy/issues")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E896 = ("There was an error using the static vectors. Ensure that the vectors "
"of the vocab are properly initialized, or set 'include_static_vectors' "
"to False.")
E897 = ("Field '{field}' should be a dot-notation string referring to the " E897 = ("Field '{field}' should be a dot-notation string referring to the "
"relevant section in the config, but found type {type} instead.") "relevant section in the config, but found type {type} instead.")
E898 = ("Can't serialize trainable pipe '{name}': the `model` attribute " E898 = ("Can't serialize trainable pipe '{name}': the `model` attribute "

View File

@ -42,9 +42,13 @@ def forward(
rows = model.ops.flatten( rows = model.ops.flatten(
[doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs] [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
) )
try:
vectors_data = model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True)
except ValueError:
raise RuntimeError(Errors.E896)
output = Ragged( output = Ragged(
model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True), vectors_data,
model.ops.asarray([len(doc) for doc in docs], dtype="i"), model.ops.asarray([len(doc) for doc in docs], dtype="i")
) )
mask = None mask = None
if is_train: if is_train:

View File

@ -103,7 +103,7 @@ def load_vectors_into_model(
"with the packaged vectors. Make sure that the vectors package you're " "with the packaged vectors. Make sure that the vectors package you're "
"loading is compatible with the current version of spaCy." "loading is compatible with the current version of spaCy."
) )
err = ConfigValidationError.from_error(e, config=None, title=title, desc=desc) err = ConfigValidationError.from_error(e, title=title, desc=desc)
raise err from None raise err from None
nlp.vocab.vectors = vectors_nlp.vocab.vectors nlp.vocab.vectors = vectors_nlp.vocab.vectors
if add_strings: if add_strings: