mirror of https://github.com/explosion/spaCy.git
Fix empty input into StaticVectors layer (#6211)
* Add test for empty doc(s) * Fix empty check in staticvectors * Remove xfail * Update spacy/ml/staticvectors.py
This commit is contained in:
parent
2a17566da3
commit
cfb9770a94
|
@ -34,7 +34,7 @@ def StaticVectors(
|
|||
def forward(
|
||||
model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
|
||||
) -> Tuple[Ragged, Callable]:
|
||||
if not len(docs):
|
||||
if not sum(len(doc) for doc in docs):
|
||||
return _handle_empty(model.ops, model.get_dim("nO"))
|
||||
key_attr = model.attrs["key_attr"]
|
||||
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
||||
|
|
|
@ -7,6 +7,7 @@ import numpy
|
|||
|
||||
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
|
||||
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
|
||||
from spacy.ml.staticvectors import StaticVectors
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.en.examples import sentences as EN_SENTENCES
|
||||
|
||||
|
@ -185,3 +186,22 @@ def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
|
|||
model1 = get_updated_model()
|
||||
model2 = get_updated_model()
|
||||
assert_array_equal(get_all_params(model1), get_all_params(model2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_func,kwargs",
|
||||
[
|
||||
(StaticVectors, {"nO": 128, "nM": 300}),
|
||||
]
|
||||
)
|
||||
def test_empty_docs(model_func, kwargs):
|
||||
nlp = English()
|
||||
model = model_func(**kwargs).initialize()
|
||||
# Test the layer can be called successfully with 0, 1 and 2 empty docs.
|
||||
for n_docs in range(3):
|
||||
docs = [nlp("") for _ in range(n_docs)]
|
||||
# Test predict
|
||||
_ = model.predict(docs)
|
||||
# Test backprop
|
||||
output, backprop = model.begin_update(docs)
|
||||
_ = backprop(output)
|
||||
|
|
Loading…
Reference in New Issue