diff --git a/spacy/tests/test_util.py b/spacy/tests/test_util.py index a0c6ab6c0..207805c81 100644 --- a/spacy/tests/test_util.py +++ b/spacy/tests/test_util.py @@ -11,13 +11,29 @@ from spacy.util import minibatch_by_words [ ([400, 400, 199], [3]), ([400, 400, 199, 3], [4]), - ([400, 400, 199, 3, 1], [5]), ([400, 400, 199, 3, 200], [3, 2]), + + ([400, 400, 199, 3, 1], [5]), + ([400, 400, 199, 3, 1, 1500], [5]), # 1500 will be discarded ([400, 400, 199, 3, 1, 200], [3, 3]), + ([400, 400, 199, 3, 1, 999], [3, 3]), + ([400, 400, 199, 3, 1, 999, 999], [3, 2, 1, 1]), + + ([1, 2, 999], [3]), + ([1, 2, 999, 1], [4]), + ([1, 200, 999, 1], [2, 2]), + ([1, 999, 200, 1], [2, 2]), ], ) def test_util_minibatch(doc_sizes, expected_batches): docs = [get_doc(doc_size) for doc_size in doc_sizes] examples = [Example(doc=doc) for doc in docs] - batches = list(minibatch_by_words(examples=examples, size=1000)) + tol = 0.2 + batch_size = 1000 + batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=True)) assert [len(batch) for batch in batches] == expected_batches + + max_size = batch_size + batch_size * tol + for batch in batches: + assert sum([len(example.doc) for example in batch]) < max_size + diff --git a/spacy/util.py b/spacy/util.py index 598545b84..2d732e2b7 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -671,24 +671,24 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o tol_size = target_size * tolerance batch = [] overflow = [] - current_size = 0 + batch_size = 0 overflow_size = 0 for example in examples: n_words = count_words(example.doc) - # if the current example exceeds the batch size, it is returned separately + # if the current example exceeds the maximum batch size, it is returned separately # but only if discard_oversize=False. if n_words > target_size + tol_size: if not discard_oversize: yield [example] # add the example to the current batch if there's no overflow yet and it still fits - elif overflow_size == 0 and (current_size + n_words) < target_size: + elif overflow_size == 0 and (batch_size + n_words) <= target_size: batch.append(example) - current_size += n_words + batch_size += n_words # add the example to the overflow buffer if it fits in the tolerance margin - elif (current_size + overflow_size + n_words) < (target_size + tol_size): + elif (batch_size + overflow_size + n_words) <= (target_size + tol_size): overflow.append(example) overflow_size += n_words @@ -697,14 +697,29 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_o yield batch target_size = next(size_) tol_size = target_size * tolerance - # In theory it may happen that the current example + overflow examples now exceed the new - # target_size, but that seems like an unimportant edge case if batch sizes are variable? batch = overflow - batch.append(example) - current_size = overflow_size + n_words + batch_size = overflow_size overflow = [] overflow_size = 0 + # this example still fits + if (batch_size + n_words) <= target_size: + batch.append(example) + batch_size += n_words + + # this example fits in overflow + elif (batch_size + n_words) <= (target_size + tol_size): + overflow.append(example) + overflow_size += n_words + + # this example does not fit with the previous overflow: start another new batch + else: + yield batch + target_size = next(size_) + tol_size = target_size * tolerance + batch = [example] + batch_size = n_words + # yield the final batch if batch: batch.extend(overflow)