extending algorithm to deal better with edge cases

svlandeg 2020-06-02 22:05:08 +02:00
parent f2e162fc60
commit aa6271b16c
2 changed files with 42 additions and 11 deletions

spacy/tests/test_util.py

@@ -11,13 +11,29 @@ from spacy.util import minibatch_by_words
     [
         ([400, 400, 199], [3]),
         ([400, 400, 199, 3], [4]),
+        ([400, 400, 199, 3, 1], [5]),
         ([400, 400, 199, 3, 200], [3, 2]),
-        ([400, 400, 199, 3, 1], [5]),
+        ([400, 400, 199, 3, 1, 1500], [5]),  # 1500 will be discarded
         ([400, 400, 199, 3, 1, 200], [3, 3]),
+        ([400, 400, 199, 3, 1, 999], [3, 3]),
+        ([400, 400, 199, 3, 1, 999, 999], [3, 2, 1, 1]),
+        ([1, 2, 999], [3]),
+        ([1, 2, 999, 1], [4]),
+        ([1, 200, 999, 1], [2, 2]),
+        ([1, 999, 200, 1], [2, 2]),
     ],
 )
 def test_util_minibatch(doc_sizes, expected_batches):
     docs = [get_doc(doc_size) for doc_size in doc_sizes]
     examples = [Example(doc=doc) for doc in docs]
-    batches = list(minibatch_by_words(examples=examples, size=1000))
+    tol = 0.2
+    batch_size = 1000
+    batches = list(minibatch_by_words(examples=examples, size=batch_size, tolerance=tol, discard_oversize=True))
     assert [len(batch) for batch in batches] == expected_batches
+    max_size = batch_size + batch_size * tol
+    for batch in batches:
+        assert sum([len(example.doc) for example in batch]) < max_size

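To make the new expectations concrete, here is a hand-trace (not part of the commit) of the added case ([400, 400, 199, 3, 1, 999, 999], [3, 2, 1, 1]), assuming the test's settings of size=1000 and tolerance=0.2, i.e. a hard cap of 1200 words per batch:

    # batch target and tolerance margin used by the test
    target, tol_size = 1000, 200
    # 400 + 400 + 199 = 999 stays under the target: the first three docs batch together
    assert 400 + 400 + 199 <= target
    # the 3- and 1-word docs push past the target but stay within the cap,
    # so they wait in the overflow buffer
    assert 999 + 3 + 1 <= target + tol_size
    # the first 999-word doc cannot join batch + overflow under the cap,
    # so [400, 400, 199] is yielded (batch of 3), the overflow [3, 1] seeds
    # the next batch, and 999 goes into that batch's overflow (4 + 999 <= 1200)
    assert 999 + 3 + 1 + 999 > target + tol_size
    assert 3 + 1 + 999 <= target + tol_size
    # the second 999 fits nowhere alongside that pending overflow: [3, 1] is
    # yielded (batch of 2) and each 999-word doc ends up in its own batch of 1,
    # giving the expected batch sizes [3, 2, 1, 1]
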
spacy/util.py

@@ -671,24 +671,24 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
     tol_size = target_size * tolerance
     batch = []
     overflow = []
-    current_size = 0
+    batch_size = 0
     overflow_size = 0
     for example in examples:
         n_words = count_words(example.doc)
-        # if the current example exceeds the batch size, it is returned separately
+        # if the current example exceeds the maximum batch size, it is returned separately
         # but only if discard_oversize=False.
         if n_words > target_size + tol_size:
             if not discard_oversize:
                 yield [example]
         # add the example to the current batch if there's no overflow yet and it still fits
-        elif overflow_size == 0 and (current_size + n_words) < target_size:
+        elif overflow_size == 0 and (batch_size + n_words) <= target_size:
             batch.append(example)
-            current_size += n_words
+            batch_size += n_words
         # add the example to the overflow buffer if it fits in the tolerance margin
-        elif (current_size + overflow_size + n_words) < (target_size + tol_size):
+        elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
             overflow.append(example)
             overflow_size += n_words
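Besides renaming current_size to batch_size, note that both comparisons above changed from strict < to <=: a document that lands the running total exactly on target_size (or exactly on the tolerance cap) now counts as fitting. A minimal check of that boundary, assuming size=1000 and tolerance=0.2 as in the test:

    target_size = 1000
    tol_size = target_size * 0.2
    batch_size, n_words = 800, 200
    # old condition: a doc making the batch exactly 1000 words was pushed to overflow
    assert not (batch_size + n_words) < target_size
    # new condition: an exact fit stays in the current batch
    assert (batch_size + n_words) <= target_size
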
@@ -697,14 +697,29 @@ def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
             yield batch
             target_size = next(size_)
             tol_size = target_size * tolerance
-            # In theory it may happen that the current example + overflow examples now exceed the new
-            # target_size, but that seems like an unimportant edge case if batch sizes are variable?
             batch = overflow
-            batch.append(example)
-            current_size = overflow_size + n_words
+            batch_size = overflow_size
             overflow = []
             overflow_size = 0
+            # this example still fits
+            if (batch_size + n_words) <= target_size:
+                batch.append(example)
+                batch_size += n_words
+            # this example fits in overflow
+            elif (batch_size + n_words) <= (target_size + tol_size):
+                overflow.append(example)
+                overflow_size += n_words
+            # this example does not fit with the previous overflow: start another new batch
+            else:
+                yield batch
+                target_size = next(size_)
+                tol_size = target_size * tolerance
+                batch = [example]
+                batch_size = n_words
     # yield the final batch
     if batch:
         batch.extend(overflow)
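
For readers outside the diff context, the sketch below reimplements the patched loop as a dependency-free generator (hypothetical name; word counts passed as plain ints instead of Example objects, and a fixed target size instead of the size_ schedule the real function draws from), and replays two of the test's parametrized cases:

    def minibatch_by_words_sketch(sizes, size=1000, tolerance=0.2, discard_oversize=False):
        target_size = size
        tol_size = target_size * tolerance
        batch, overflow = [], []
        batch_size = overflow_size = 0
        for n_words in sizes:
            # oversized docs are yielded alone, or dropped if discard_oversize=True
            if n_words > target_size + tol_size:
                if not discard_oversize:
                    yield [n_words]
            # fits in the current batch while no overflow is pending
            elif overflow_size == 0 and (batch_size + n_words) <= target_size:
                batch.append(n_words)
                batch_size += n_words
            # fits in the tolerance margin together with any pending overflow
            elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
                overflow.append(n_words)
                overflow_size += n_words
            # batch is full: yield it and seed the next batch from the overflow
            else:
                yield batch
                batch, batch_size = overflow, overflow_size
                overflow, overflow_size = [], 0
                if (batch_size + n_words) <= target_size:
                    batch.append(n_words)
                    batch_size += n_words
                elif (batch_size + n_words) <= (target_size + tol_size):
                    overflow.append(n_words)
                    overflow_size += n_words
                # the new batch plus this doc already exceeds the cap: yield again
                else:
                    yield batch
                    batch, batch_size = [n_words], n_words
        # yield the final batch together with any pending overflow
        if batch:
            batch.extend(overflow)
            yield batch

    # the sketch reproduces the expectations added in the test diff above
    assert [len(b) for b in minibatch_by_words_sketch([400, 400, 199, 3, 1, 999, 999])] == [3, 2, 1, 1]
    assert [len(b) for b in minibatch_by_words_sketch([1, 999, 200, 1])] == [2, 2]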