mirror of https://github.com/explosion/spaCy.git
Fix spancat tests on GPU (#8872)
* Fix spancat tests on GPU
* Fix more spancat tests
parent 77d698dcae
commit fa2e7a4bbf
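Every hunk below applies the same fix: the ngram suggester returns thinc arrays allocated on the current backend, so on GPU `ngrams.lengths` and `ngrams.data` are CuPy arrays that `numpy.testing.assert_equal` cannot reliably compare against Python lists. The tests therefore copy results to the host with `Ops.to_numpy`, compare with `assert_array_equal`, and cast span bounds to plain `int` before set operations. A minimal sketch of the pattern (the `assert_lengths` helper and its `ragged` argument are illustrative, not part of the test file):

from numpy.testing import assert_array_equal
from thinc.api import get_current_ops

OPS = get_current_ops()  # NumpyOps on CPU, CupyOps when a GPU is active


def assert_lengths(ragged, expected):
    # Hypothetical helper: `ragged.lengths` may live on the GPU, so copy
    # it to a host NumPy array before handing it to numpy.testing.
    assert_array_equal(OPS.to_numpy(ragged.lengths), expected)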
@@ -1,9 +1,11 @@
 import pytest
-from numpy.testing import assert_equal
+from numpy.testing import assert_equal, assert_array_equal
+from thinc.api import get_current_ops
 from spacy.language import Language
 from spacy.training import Example
 from spacy.util import fix_random_seed, registry
 
+OPS = get_current_ops()
 
 SPAN_KEY = "labeled_spans"
 
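The module-level `OPS = get_current_ops()` snapshot added above is what the later assertions use for host conversion. For context, a hedged sketch of how a test run ends up on the CuPy backend at all; `require_gpu` is public spaCy API, but whether it succeeds depends on having CuPy and a CUDA device available:

import spacy

spacy.require_gpu()  # raises if no GPU/CuPy is available in this environment

from thinc.api import get_current_ops

print(type(get_current_ops()).__name__)  # "CupyOps" once GPU ops are active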
@@ -116,12 +118,12 @@ def test_ngram_suggester(en_tokenizer):
             for span in spans:
                 assert 0 <= span[0] < len(doc)
                 assert 0 < span[1] <= len(doc)
-                spans_set.add((span[0], span[1]))
+                spans_set.add((int(span[0]), int(span[1])))
             # spans are unique
             assert spans.shape[0] == len(spans_set)
             offset += ngrams.lengths[i]
         # the number of spans is correct
-        assert_equal(ngrams.lengths, [max(0, len(doc) - (size - 1)) for doc in docs])
+        assert_array_equal(OPS.to_numpy(ngrams.lengths), [max(0, len(doc) - (size - 1)) for doc in docs])
 
     # test 1-3-gram suggestions
     ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2, 3])
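The `int(...)` casts in this hunk address a subtler difference: indexing a row of a device array yields a 0-d CuPy array rather than a hashable NumPy scalar, so the tuples could not serve as set members on GPU. Casting to host `int` makes the uniqueness check backend-independent. A runnable sketch, using a NumPy array as a stand-in for the device array:

import numpy as np

spans = np.array([[0, 1], [0, 2], [1, 2]])  # stand-in; on GPU this would be CuPy
spans_set = set()
for span in spans:
    # span[0] is a 0-d array element; convert to a plain Python int so the
    # tuple hashes the same way regardless of which backend produced it.
    spans_set.add((int(span[0]), int(span[1])))
assert spans.shape[0] == len(spans_set)  # every suggested span is unique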
@@ -129,9 +131,9 @@ def test_ngram_suggester(en_tokenizer):
         en_tokenizer(text) for text in ["a", "a b", "a b c", "a b c d", "a b c d e"]
     ]
     ngrams = ngram_suggester(docs)
-    assert_equal(ngrams.lengths, [1, 3, 6, 9, 12])
-    assert_equal(
-        ngrams.data,
+    assert_array_equal(OPS.to_numpy(ngrams.lengths), [1, 3, 6, 9, 12])
+    assert_array_equal(
+        OPS.to_numpy(ngrams.data),
         [
             # doc 0
             [0, 1],
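For readers checking the expected `[1, 3, 6, 9, 12]`: with sizes 1-3, a doc of n tokens yields `max(0, n - size + 1)` spans per size, summed over the sizes. A pure-Python self-check (no spaCy required; `expected_ngrams` is an illustrative name, not test code):

def expected_ngrams(n_tokens, sizes=(1, 2, 3)):
    # Count of n-grams of the given sizes that fit in a doc of n_tokens.
    return sum(max(0, n_tokens - size + 1) for size in sizes)


# The docs above have 1-5 tokens ("a" through "a b c d e").
assert [expected_ngrams(n) for n in (1, 2, 3, 4, 5)] == [1, 3, 6, 9, 12]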
@@ -176,13 +178,13 @@ def test_ngram_suggester(en_tokenizer):
     ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1])
     docs = [en_tokenizer(text) for text in ["", "a", ""]]
     ngrams = ngram_suggester(docs)
-    assert_equal(ngrams.lengths, [len(doc) for doc in docs])
+    assert_array_equal(OPS.to_numpy(ngrams.lengths), [len(doc) for doc in docs])
 
     # test all empty docs
     ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1])
     docs = [en_tokenizer(text) for text in ["", "", ""]]
     ngrams = ngram_suggester(docs)
-    assert_equal(ngrams.lengths, [len(doc) for doc in docs])
+    assert_array_equal(OPS.to_numpy(ngrams.lengths), [len(doc) for doc in docs])
 
 
 def test_ngram_sizes(en_tokenizer):
@@ -195,12 +197,12 @@ def test_ngram_sizes(en_tokenizer):
     ]
     ngrams_1 = size_suggester(docs)
     ngrams_2 = range_suggester(docs)
-    assert_equal(ngrams_1.lengths, [1, 3, 6, 9, 12])
-    assert_equal(ngrams_1.lengths, ngrams_2.lengths)
-    assert_equal(ngrams_1.data, ngrams_2.data)
+    assert_array_equal(OPS.to_numpy(ngrams_1.lengths), [1, 3, 6, 9, 12])
+    assert_array_equal(OPS.to_numpy(ngrams_1.lengths), OPS.to_numpy(ngrams_2.lengths))
+    assert_array_equal(OPS.to_numpy(ngrams_1.data), OPS.to_numpy(ngrams_2.data))
 
     # one more variation
     suggester_factory = registry.misc.get("spacy.ngram_range_suggester.v1")
     range_suggester = suggester_factory(min_size=2, max_size=4)
     ngrams_3 = range_suggester(docs)
-    assert_equal(ngrams_3.lengths, [0, 1, 3, 6, 9])
+    assert_array_equal(OPS.to_numpy(ngrams_3.lengths), [0, 1, 3, 6, 9])
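The same counting argument validates the range suggester's expected `[0, 1, 3, 6, 9]`: sizes 2-4 over the same 1-5 token docs. Another illustrative self-check:

def expected_ngrams(n_tokens, sizes):
    return sum(max(0, n_tokens - size + 1) for size in sizes)


# min_size=2, max_size=4 means sizes (2, 3, 4); a one-token doc fits none.
assert [expected_ngrams(n, (2, 3, 4)) for n in (1, 2, 3, 4, 5)] == [0, 1, 3, 6, 9]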