slightly more challenging unit test

This commit is contained in:
svlandeg 2020-06-02 19:47:30 +02:00
parent 6651fafd5c
commit 6208d322d3
2 changed files with 5 additions and 5 deletions

View File

@ -12,8 +12,8 @@ from spacy.util import minibatch_by_words
([400, 400, 199], [3]), ([400, 400, 199], [3]),
([400, 400, 199, 3], [4]), ([400, 400, 199, 3], [4]),
([400, 400, 199, 3, 1], [5]), ([400, 400, 199, 3, 1], [5]),
([400, 400, 199, 3, 250], [3, 2]), ([400, 400, 199, 3, 200], [3, 2]),
([400, 400, 199, 3, 1, 250], [3, 3]), ([400, 400, 199, 3, 1, 200], [3, 3]),
], ],
) )
def test_util_minibatch(doc_sizes, expected_batches): def test_util_minibatch(doc_sizes, expected_batches):

View File

@ -682,13 +682,13 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o
if not discard_oversize: if not discard_oversize:
yield [example] yield [example]
# add the example to the current batch if it still fits # add the example to the current batch if it still fits and there's no overflow yet
elif (current_size + n_words) < target_size: elif overflow_size == 0 and (current_size + n_words) < target_size:
batch.append(example) batch.append(example)
current_size += n_words current_size += n_words
# add the example to the overflow buffer if it fits in the tolerance margins # add the example to the overflow buffer if it fits in the tolerance margins
elif (current_size + n_words) < (target_size + tol_size): elif (current_size + overflow_size + n_words) < (target_size + tol_size):
overflow.append(example) overflow.append(example)
overflow_size += n_words overflow_size += n_words