Fix spancat tests on GPU (#8872)

* Fix spancat tests on GPU

* Fix more spancat tests
This commit is contained in:
Adriane Boyd 2021-08-04 14:29:43 +02:00 committed by GitHub
parent 77d698dcae
commit fa2e7a4bbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 14 additions and 12 deletions

View File

@@ -1,9 +1,11 @@
import pytest import pytest
from numpy.testing import assert_equal from numpy.testing import assert_equal, assert_array_equal
from thinc.api import get_current_ops
from spacy.language import Language from spacy.language import Language
from spacy.training import Example from spacy.training import Example
from spacy.util import fix_random_seed, registry from spacy.util import fix_random_seed, registry
OPS = get_current_ops()
SPAN_KEY = "labeled_spans" SPAN_KEY = "labeled_spans"
@@ -116,12 +118,12 @@ def test_ngram_suggester(en_tokenizer):
for span in spans: for span in spans:
assert 0 <= span[0] < len(doc) assert 0 <= span[0] < len(doc)
assert 0 < span[1] <= len(doc) assert 0 < span[1] <= len(doc)
spans_set.add((span[0], span[1])) spans_set.add((int(span[0]), int(span[1])))
# spans are unique # spans are unique
assert spans.shape[0] == len(spans_set) assert spans.shape[0] == len(spans_set)
offset += ngrams.lengths[i] offset += ngrams.lengths[i]
# the number of spans is correct # the number of spans is correct
assert_equal(ngrams.lengths, [max(0, len(doc) - (size - 1)) for doc in docs]) assert_array_equal(OPS.to_numpy(ngrams.lengths), [max(0, len(doc) - (size - 1)) for doc in docs])
# test 1-3-gram suggestions # test 1-3-gram suggestions
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2, 3]) ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1, 2, 3])
@@ -129,9 +131,9 @@ def test_ngram_suggester(en_tokenizer):
en_tokenizer(text) for text in ["a", "a b", "a b c", "a b c d", "a b c d e"] en_tokenizer(text) for text in ["a", "a b", "a b c", "a b c d", "a b c d e"]
] ]
ngrams = ngram_suggester(docs) ngrams = ngram_suggester(docs)
assert_equal(ngrams.lengths, [1, 3, 6, 9, 12]) assert_array_equal(OPS.to_numpy(ngrams.lengths), [1, 3, 6, 9, 12])
assert_equal( assert_array_equal(
ngrams.data, OPS.to_numpy(ngrams.data),
[ [
# doc 0 # doc 0
[0, 1], [0, 1],
@@ -176,13 +178,13 @@ def test_ngram_suggester(en_tokenizer):
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1]) ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1])
docs = [en_tokenizer(text) for text in ["", "a", ""]] docs = [en_tokenizer(text) for text in ["", "a", ""]]
ngrams = ngram_suggester(docs) ngrams = ngram_suggester(docs)
assert_equal(ngrams.lengths, [len(doc) for doc in docs]) assert_array_equal(OPS.to_numpy(ngrams.lengths), [len(doc) for doc in docs])
# test all empty docs # test all empty docs
ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1]) ngram_suggester = registry.misc.get("spacy.ngram_suggester.v1")(sizes=[1])
docs = [en_tokenizer(text) for text in ["", "", ""]] docs = [en_tokenizer(text) for text in ["", "", ""]]
ngrams = ngram_suggester(docs) ngrams = ngram_suggester(docs)
assert_equal(ngrams.lengths, [len(doc) for doc in docs]) assert_array_equal(OPS.to_numpy(ngrams.lengths), [len(doc) for doc in docs])
def test_ngram_sizes(en_tokenizer): def test_ngram_sizes(en_tokenizer):
@@ -195,12 +197,12 @@ def test_ngram_sizes(en_tokenizer):
] ]
ngrams_1 = size_suggester(docs) ngrams_1 = size_suggester(docs)
ngrams_2 = range_suggester(docs) ngrams_2 = range_suggester(docs)
assert_equal(ngrams_1.lengths, [1, 3, 6, 9, 12]) assert_array_equal(OPS.to_numpy(ngrams_1.lengths), [1, 3, 6, 9, 12])
assert_equal(ngrams_1.lengths, ngrams_2.lengths) assert_array_equal(OPS.to_numpy(ngrams_1.lengths), OPS.to_numpy(ngrams_2.lengths))
assert_equal(ngrams_1.data, ngrams_2.data) assert_array_equal(OPS.to_numpy(ngrams_1.data), OPS.to_numpy(ngrams_2.data))
# one more variation # one more variation
suggester_factory = registry.misc.get("spacy.ngram_range_suggester.v1") suggester_factory = registry.misc.get("spacy.ngram_range_suggester.v1")
range_suggester = suggester_factory(min_size=2, max_size=4) range_suggester = suggester_factory(min_size=2, max_size=4)
ngrams_3 = range_suggester(docs) ngrams_3 = range_suggester(docs)
assert_equal(ngrams_3.lengths, [0, 1, 3, 6, 9]) assert_array_equal(OPS.to_numpy(ngrams_3.lengths), [0, 1, 3, 6, 9])