Fix test for spancat (#9446)

* fix test for spancat

* increase tolerance for almost equal checks

* Update spacy/tests/test_models.py

* Update spacy/tests/test_models.py
This commit is contained in:
Sofie Van Landeghem 2021-10-13 10:47:56 +02:00 committed by GitHub
parent 5e8e8525f0
commit 2e3d6b8b5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 3 deletions

View File

@ -114,7 +114,7 @@ def test_make_spangroup(max_positive, nr_results):
doc = nlp.make_doc("Greater London") doc = nlp.make_doc("Greater London")
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2]) ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
indices = ngram_suggester([doc])[0].dataXd indices = ngram_suggester([doc])[0].dataXd
assert_array_equal(indices, numpy.asarray([[0, 1], [1, 2], [0, 2]])) assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
labels = ["Thing", "City", "Person", "GreatCity"] labels = ["Thing", "City", "Person", "GreatCity"]
scores = numpy.asarray( scores = numpy.asarray(
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f" [[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"

View File

@ -49,8 +49,8 @@ def test_issue5551(textcat_config):
# All results should be the same because of the fixed seed # All results should be the same because of the fixed seed
assert len(results) == 3 assert len(results) == 3
ops = get_current_ops() ops = get_current_ops()
assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[1])) assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[1]), decimal=5)
assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[2])) assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[2]), decimal=5)
def test_issue5838(): def test_issue5838():

View File

@ -193,6 +193,7 @@ def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
assert_array_almost_equal( assert_array_almost_equal(
model1.ops.to_numpy(get_all_params(model1)), model1.ops.to_numpy(get_all_params(model1)),
model2.ops.to_numpy(get_all_params(model2)), model2.ops.to_numpy(get_all_params(model2)),
decimal=5,
) )