mirror of https://github.com/explosion/spaCy.git
Fix empty input into StaticVectors layer (#6211)
* Add test for empty doc(s) * Fix empty check in staticvectors * Remove xfail * Update spacy/ml/staticvectors.py
This commit is contained in:
parent
2a17566da3
commit
cfb9770a94
|
@ -34,7 +34,7 @@ def StaticVectors(
|
||||||
def forward(
|
def forward(
|
||||||
model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
|
model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
|
||||||
) -> Tuple[Ragged, Callable]:
|
) -> Tuple[Ragged, Callable]:
|
||||||
if not len(docs):
|
if not sum(len(doc) for doc in docs):
|
||||||
return _handle_empty(model.ops, model.get_dim("nO"))
|
return _handle_empty(model.ops, model.get_dim("nO"))
|
||||||
key_attr = model.attrs["key_attr"]
|
key_attr = model.attrs["key_attr"]
|
||||||
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
||||||
|
|
|
@ -7,6 +7,7 @@ import numpy
|
||||||
|
|
||||||
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
|
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
|
||||||
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
|
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
|
||||||
|
from spacy.ml.staticvectors import StaticVectors
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.lang.en.examples import sentences as EN_SENTENCES
|
from spacy.lang.en.examples import sentences as EN_SENTENCES
|
||||||
|
|
||||||
|
@ -185,3 +186,22 @@ def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
|
||||||
model1 = get_updated_model()
|
model1 = get_updated_model()
|
||||||
model2 = get_updated_model()
|
model2 = get_updated_model()
|
||||||
assert_array_equal(get_all_params(model1), get_all_params(model2))
|
assert_array_equal(get_all_params(model1), get_all_params(model2))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_func,kwargs",
|
||||||
|
[
|
||||||
|
(StaticVectors, {"nO": 128, "nM": 300}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_empty_docs(model_func, kwargs):
|
||||||
|
nlp = English()
|
||||||
|
model = model_func(**kwargs).initialize()
|
||||||
|
# Test the layer can be called successfully with 0, 1 and 2 empty docs.
|
||||||
|
for n_docs in range(3):
|
||||||
|
docs = [nlp("") for _ in range(n_docs)]
|
||||||
|
# Test predict
|
||||||
|
_ = model.predict(docs)
|
||||||
|
# Test backprop
|
||||||
|
output, backprop = model.begin_update(docs)
|
||||||
|
_ = backprop(output)
|
||||||
|
|
Loading…
Reference in New Issue