diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py index c77247d33..da731dadb 100644 --- a/spacy/ml/staticvectors.py +++ b/spacy/ml/staticvectors.py @@ -34,7 +34,7 @@ def StaticVectors( def forward( model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool ) -> Tuple[Ragged, Callable]: - if not len(docs): + if not sum(len(doc) for doc in docs): return _handle_empty(model.ops, model.get_dim("nO")) key_attr = model.attrs["key_attr"] W = cast(Floats2d, model.ops.as_contig(model.get_param("W"))) diff --git a/spacy/tests/test_models.py b/spacy/tests/test_models.py index 17408f7e8..8ca7f8b66 100644 --- a/spacy/tests/test_models.py +++ b/spacy/tests/test_models.py @@ -7,6 +7,7 @@ import numpy from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier +from spacy.ml.staticvectors import StaticVectors from spacy.lang.en import English from spacy.lang.en.examples import sentences as EN_SENTENCES @@ -185,3 +186,22 @@ def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X): model1 = get_updated_model() model2 = get_updated_model() assert_array_equal(get_all_params(model1), get_all_params(model2)) + + +@pytest.mark.parametrize( + "model_func,kwargs", + [ + (StaticVectors, {"nO": 128, "nM": 300}), + ] +) +def test_empty_docs(model_func, kwargs): + nlp = English() + model = model_func(**kwargs).initialize() + # Test the layer can be called successfully with 0, 1 and 2 empty docs. + for n_docs in range(3): + docs = [nlp("") for _ in range(n_docs)] + # Test predict + _ = model.predict(docs) + # Test backprop + output, backprop = model.begin_update(docs) + _ = backprop(output)