mirror of https://github.com/explosion/spaCy.git
add test for minibatch util
This commit is contained in:
parent
5b350a6c99
commit
85b0597ed5
|
@ -0,0 +1,23 @@
|
||||||
|
import pytest
|
||||||
|
from spacy.gold import Example
|
||||||
|
|
||||||
|
from .util import get_doc
|
||||||
|
|
||||||
|
from spacy.util import minibatch_by_words
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"doc_sizes, expected_batches",
|
||||||
|
[
|
||||||
|
([400, 400, 199], [3]),
|
||||||
|
([400, 400, 199, 3], [4]),
|
||||||
|
([400, 400, 199, 3, 250], [3, 2]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_util_minibatch(doc_sizes, expected_batches):
|
||||||
|
docs = [get_doc(doc_size) for doc_size in doc_sizes]
|
||||||
|
|
||||||
|
examples = [Example(doc=doc) for doc in docs]
|
||||||
|
|
||||||
|
batches = list(minibatch_by_words(examples=examples, size=1000))
|
||||||
|
assert [len(batch) for batch in batches] == expected_batches
|
|
@ -92,6 +92,13 @@ def get_batch(batch_size):
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
|
||||||
|
def get_doc(n_words):
|
||||||
|
vocab = Vocab()
|
||||||
|
# Make the words numbers, so that they're easy to track.
|
||||||
|
numbers = [str(i) for i in range(0, n_words)]
|
||||||
|
return Doc(vocab, words=numbers)
|
||||||
|
|
||||||
|
|
||||||
def apply_transition_sequence(parser, doc, sequence):
|
def apply_transition_sequence(parser, doc, sequence):
|
||||||
"""Perform a series of pre-specified transitions, to put the parser in a
|
"""Perform a series of pre-specified transitions, to put the parser in a
|
||||||
desired state."""
|
desired state."""
|
||||||
|
|
Loading…
Reference in New Issue