spaCy/spacy/tests/test_util.py

66 lines
2.1 KiB
Python
Raw Normal View History

2020-06-02 16:26:21 +00:00
import pytest
from spacy.gold import Example
2020-06-02 20:24:57 +00:00
from .util import get_random_doc
2020-06-02 16:26:21 +00:00
from spacy.util import minibatch_by_words
@pytest.mark.parametrize(
    "doc_sizes, expected_batches",
    [
        ([400, 400, 199], [3]),
        ([400, 400, 199, 3], [4]),
        ([400, 400, 199, 3, 200], [3, 2]),
        ([400, 400, 199, 3, 1], [5]),
        ([400, 400, 199, 3, 1, 1500], [5]),  # 1500 will be discarded
        ([400, 400, 199, 3, 1, 200], [3, 3]),
        ([400, 400, 199, 3, 1, 999], [3, 3]),
        ([400, 400, 199, 3, 1, 999, 999], [3, 2, 1, 1]),
        ([1, 2, 999], [3]),
        ([1, 2, 999, 1], [4]),
        ([1, 200, 999, 1], [2, 2]),
        ([1, 999, 200, 1], [2, 2]),
    ],
)
def test_util_minibatch(doc_sizes, expected_batches):
    """Test word-count batching with oversized documents discarded.

    One doc is built per entry of ``doc_sizes`` via ``get_random_doc``
    (presumably ``doc_size`` is its token count -- the size check below
    relies on ``len(example.doc)``). The examples are batched with a target
    of 1000 words and a 20% tolerance, with ``discard_oversize=True``.

    Asserts that the number of examples per batch matches
    ``expected_batches``, and that every batch's total word count stays
    strictly below ``batch_size * (1 + tol)``.
    """
    docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
    examples = [Example(doc=doc) for doc in docs]
    tol = 0.2
    batch_size = 1000
    batches = list(
        minibatch_by_words(
            examples=examples, size=batch_size, tolerance=tol, discard_oversize=True
        )
    )
    assert [len(batch) for batch in batches] == expected_batches
    # Upper bound on a batch's total word count once the tolerance is applied.
    max_size = batch_size + batch_size * tol
    for batch in batches:
        # Generator instead of a throwaway list (ruff C419).
        assert sum(len(example.doc) for example in batch) < max_size


@pytest.mark.parametrize(
    "doc_sizes, expected_batches",
    [
        ([400, 4000, 199], [1, 2]),
        ([400, 400, 199, 3000, 200], [1, 4]),
        ([400, 400, 199, 3, 1, 1500], [1, 5]),
        ([400, 400, 199, 3000, 2000, 200, 200], [1, 1, 3, 2]),
        ([1, 2, 9999], [1, 2]),
        ([2000, 1, 2000, 1, 1, 1, 2000], [1, 1, 1, 4]),
    ],
)
def test_util_minibatch_oversize(doc_sizes, expected_batches):
    """Test that oversized documents are returned in their own batch.

    Docs are batched with a target of 1000 words and a 20% tolerance, with
    ``discard_oversize=False``, so no document is dropped.

    Asserts that the number of examples per batch matches
    ``expected_batches``. It also asserts that any batch containing a doc
    above ``batch_size * (1 + tol)`` words holds only that doc -- the
    property this test is named for. The batch-length check alone does not
    verify it.
    """
    docs = [get_random_doc(doc_size) for doc_size in doc_sizes]
    examples = [Example(doc=doc) for doc in docs]
    tol = 0.2
    batch_size = 1000
    batches = list(
        minibatch_by_words(
            examples=examples, size=batch_size, tolerance=tol, discard_oversize=False
        )
    )
    assert [len(batch) for batch in batches] == expected_batches
    # Any doc longer than this cannot fit in a regular batch.
    max_size = batch_size + batch_size * tol
    for batch in batches:
        # An oversized doc must never share its batch with other docs.
        if any(len(example.doc) > max_size for example in batch):
            assert len(batch) == 1