Fix empty input into StaticVectors layer (#6211)

* Add test for empty doc(s)

* Fix empty check in staticvectors

* Remove xfail

* Update spacy/ml/staticvectors.py
This commit is contained in:
Matthew Honnibal 2020-10-06 14:15:41 +02:00 committed by GitHub
parent 2a17566da3
commit cfb9770a94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 1 deletions

View File

@ -34,7 +34,7 @@ def StaticVectors(
def forward(
model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
if not len(docs):
if not sum(len(doc) for doc in docs):
return _handle_empty(model.ops, model.get_dim("nO"))
key_attr = model.attrs["key_attr"]
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))

View File

@ -7,6 +7,7 @@ import numpy
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
from spacy.ml.staticvectors import StaticVectors
from spacy.lang.en import English
from spacy.lang.en.examples import sentences as EN_SENTENCES
@ -185,3 +186,22 @@ def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
model1 = get_updated_model()
model2 = get_updated_model()
assert_array_equal(get_all_params(model1), get_all_params(model2))
@pytest.mark.parametrize(
"model_func,kwargs",
[
(StaticVectors, {"nO": 128, "nM": 300}),
]
)
def test_empty_docs(model_func, kwargs):
nlp = English()
model = model_func(**kwargs).initialize()
# Test the layer can be called successfully with 0, 1 and 2 empty docs.
for n_docs in range(3):
docs = [nlp("") for _ in range(n_docs)]
# Test predict
_ = model.predict(docs)
# Test backprop
output, backprop = model.begin_update(docs)
_ = backprop(output)