add test for minibatch util

This commit is contained in:
svlandeg 2020-06-02 18:26:21 +02:00
parent 5b350a6c99
commit 85b0597ed5
2 changed files with 30 additions and 0 deletions

23
spacy/tests/test_util.py Normal file
View File

@ -0,0 +1,23 @@
import pytest
from spacy.gold import Example
from .util import get_doc
from spacy.util import minibatch_by_words
@pytest.mark.parametrize(
    "doc_sizes, expected_batches",
    [
        ([400, 400, 199], [3]),
        ([400, 400, 199, 3], [4]),
        ([400, 400, 199, 3, 250], [3, 2]),
    ],
)
def test_util_minibatch(doc_sizes, expected_batches):
    """Check how minibatch_by_words groups examples with size=1000.

    doc_sizes: number of words in each generated doc, in input order.
    expected_batches: expected number of examples in each yielded batch.
    """
    # One Example per requested document length.
    examples = [Example(doc=get_doc(n_words)) for n_words in doc_sizes]
    # NOTE(review): the second case sums to 1002 words yet expects a single
    # batch -- presumably the batcher allows some tolerance above `size`;
    # confirm against minibatch_by_words.
    batch_lengths = [
        len(batch) for batch in minibatch_by_words(examples=examples, size=1000)
    ]
    assert batch_lengths == expected_batches

View File

@ -92,6 +92,13 @@ def get_batch(batch_size):
return docs return docs
def get_doc(n_words):
    """Build a Doc with `n_words` tokens over a fresh Vocab.

    The token texts are the stringified indices "0", "1", ..., which makes
    individual words easy to track.
    """
    words = list(map(str, range(n_words)))
    return Doc(Vocab(), words=words)
def apply_transition_sequence(parser, doc, sequence): def apply_transition_sequence(parser, doc, sequence):
"""Perform a series of pre-specified transitions, to put the parser in a """Perform a series of pre-specified transitions, to put the parser in a
desired state.""" desired state."""