mirror of https://github.com/explosion/spaCy.git
Fix spancat tests on GPU (#8872)
* Fix spancat tests on GPU * Fix more spancat tests
This commit is contained in:
parent
77d698dcae
commit
fa2e7a4bbf
|
@ -1,9 +1,11 @@
|
|||
import pytest
|
||||
from numpy.testing import assert_equal
|
||||
from numpy.testing import assert_equal, assert_array_equal
|
||||
from thinc.api import get_current_ops
|
||||
from spacy.language import Language
|
||||
from spacy.training import Example
|
||||
from spacy.util import fix_random_seed, registry
|
||||
|
||||
OPS = get_current_ops()
|
||||
|
||||
SPAN_KEY = "labeled_spans"
|
||||
|
||||
|
@ -116,12 +118,12 @@ def test_ngram_suggester(en_tokenizer):
|
|||
for span in spans:
|
||||
assert 0 <= span[0] < len(doc)
|
||||
assert 0 < span[1] <= len(doc)
|
||||
spans_set.add((span[0], span[1]))
|
||||
spans_set.add((int(span[0]), int(span[1])))
|
||||
# spans are unique
|
||||
assert spans.shape[0] == len(spans_set)
|
||||
offset += ngrams.lengths[i]
|
||||
# the number of spans is correct
|
||||
assert_equal(ngrams.lengths, [max(0, len(doc) - (size - 1)) for doc in docs])
|
||||
assert_array_equal(OPS.to_numpy(ngrams.lengths), [max(0, len(doc) - (size - 1)) for doc in docs])
|
||||
|
||||
# test 1-3-gram suggestions
|
||||
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2, 3])
|
||||
|
@ -129,9 +131,9 @@ def test_ngram_suggester(en_tokenizer):
|
|||
en_tokenizer(text) for text in ["a", "a b", "a b c", "a b c d", "a b c d e"]
|
||||
]
|
||||
ngrams = ngram_suggester(docs)
|
||||
assert_equal(ngrams.lengths, [1, 3, 6, 9, 12])
|
||||
assert_equal(
|
||||
ngrams.data,
|
||||
assert_array_equal(OPS.to_numpy(ngrams.lengths), [1, 3, 6, 9, 12])
|
||||
assert_array_equal(
|
||||
OPS.to_numpy(ngrams.data),
|
||||
[
|
||||
# doc 0
|
||||
[0, 1],
|
||||
|
@ -176,13 +178,13 @@ def test_ngram_suggester(en_tokenizer):
|
|||
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1])
|
||||
docs = [en_tokenizer(text) for text in ["", "a", ""]]
|
||||
ngrams = ngram_suggester(docs)
|
||||
assert_equal(ngrams.lengths, [len(doc) for doc in docs])
|
||||
assert_array_equal(OPS.to_numpy(ngrams.lengths), [len(doc) for doc in docs])
|
||||
|
||||
# test all empty docs
|
||||
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1])
|
||||
docs = [en_tokenizer(text) for text in ["", "", ""]]
|
||||
ngrams = ngram_suggester(docs)
|
||||
assert_equal(ngrams.lengths, [len(doc) for doc in docs])
|
||||
assert_array_equal(OPS.to_numpy(ngrams.lengths), [len(doc) for doc in docs])
|
||||
|
||||
|
||||
def test_ngram_sizes(en_tokenizer):
|
||||
|
@ -195,12 +197,12 @@ def test_ngram_sizes(en_tokenizer):
|
|||
]
|
||||
ngrams_1 = size_suggester(docs)
|
||||
ngrams_2 = range_suggester(docs)
|
||||
assert_equal(ngrams_1.lengths, [1, 3, 6, 9, 12])
|
||||
assert_equal(ngrams_1.lengths, ngrams_2.lengths)
|
||||
assert_equal(ngrams_1.data, ngrams_2.data)
|
||||
assert_array_equal(OPS.to_numpy(ngrams_1.lengths), [1, 3, 6, 9, 12])
|
||||
assert_array_equal(OPS.to_numpy(ngrams_1.lengths), OPS.to_numpy(ngrams_2.lengths))
|
||||
assert_array_equal(OPS.to_numpy(ngrams_1.data), OPS.to_numpy(ngrams_2.data))
|
||||
|
||||
# one more variation
|
||||
suggester_factory = registry.misc.get("spacy.ngram_range_suggester.v1")
|
||||
range_suggester = suggester_factory(min_size=2, max_size=4)
|
||||
ngrams_3 = range_suggester(docs)
|
||||
assert_equal(ngrams_3.lengths, [0, 1, 3, 6, 9])
|
||||
assert_array_equal(OPS.to_numpy(ngrams_3.lengths), [0, 1, 3, 6, 9])
|
||||
|
|
Loading…
Reference in New Issue