mirror of https://github.com/explosion/spaCy.git
Fix test for spancat (#9446)
* fix test for spancat * increase tolerance for almost equal checks * Update spacy/tests/test_models.py * Update spacy/tests/test_models.py
This commit is contained in:
parent
5e8e8525f0
commit
2e3d6b8b5a
|
@ -114,7 +114,7 @@ def test_make_spangroup(max_positive, nr_results):
|
||||||
doc = nlp.make_doc("Greater London")
|
doc = nlp.make_doc("Greater London")
|
||||||
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
|
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2])
|
||||||
indices = ngram_suggester([doc])[0].dataXd
|
indices = ngram_suggester([doc])[0].dataXd
|
||||||
assert_array_equal(indices, numpy.asarray([[0, 1], [1, 2], [0, 2]]))
|
assert_array_equal(OPS.to_numpy(indices), numpy.asarray([[0, 1], [1, 2], [0, 2]]))
|
||||||
labels = ["Thing", "City", "Person", "GreatCity"]
|
labels = ["Thing", "City", "Person", "GreatCity"]
|
||||||
scores = numpy.asarray(
|
scores = numpy.asarray(
|
||||||
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
|
[[0.2, 0.4, 0.3, 0.1], [0.1, 0.6, 0.2, 0.4], [0.8, 0.7, 0.3, 0.9]], dtype="f"
|
||||||
|
|
|
@ -49,8 +49,8 @@ def test_issue5551(textcat_config):
|
||||||
# All results should be the same because of the fixed seed
|
# All results should be the same because of the fixed seed
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
ops = get_current_ops()
|
ops = get_current_ops()
|
||||||
assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[1]))
|
assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[1]), decimal=5)
|
||||||
assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[2]))
|
assert_almost_equal(ops.to_numpy(results[0]), ops.to_numpy(results[2]), decimal=5)
|
||||||
|
|
||||||
|
|
||||||
def test_issue5838():
|
def test_issue5838():
|
||||||
|
|
|
@ -193,6 +193,7 @@ def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
|
||||||
assert_array_almost_equal(
|
assert_array_almost_equal(
|
||||||
model1.ops.to_numpy(get_all_params(model1)),
|
model1.ops.to_numpy(get_all_params(model1)),
|
||||||
model2.ops.to_numpy(get_all_params(model2)),
|
model2.ops.to_numpy(get_all_params(model2)),
|
||||||
|
decimal=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue