mirror of https://github.com/explosion/spaCy.git
rewrite minibatch_by_words function
This commit is contained in:
parent
ec52e7f886
commit
fdfd822936
|
@ -656,45 +656,47 @@ def decaying(start, stop, decay):
|
||||||
curr -= decay
|
curr -= decay
|
||||||
|
|
||||||
|
|
||||||
def minibatch_by_words(examples, size, tuples=True, count_words=len, tolerance=0.2):
|
def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
|
||||||
"""Create minibatches of roughly a given number of words. If any examples
|
"""Create minibatches of roughly a given number of words. If any examples
|
||||||
are longer than the specified batch length, they will appear in a batch by
|
are longer than the specified batch length, they will appear in a batch by
|
||||||
themselves."""
|
themselves, or be discarded if discard_oversize=True."""
|
||||||
if isinstance(size, int):
|
if isinstance(size, int):
|
||||||
size_ = itertools.repeat(size)
|
size_ = itertools.repeat(size)
|
||||||
elif isinstance(size, List):
|
elif isinstance(size, List):
|
||||||
size_ = iter(size)
|
size_ = iter(size)
|
||||||
else:
|
else:
|
||||||
size_ = size
|
size_ = size
|
||||||
examples = iter(examples)
|
|
||||||
oversize = []
|
target_size = next(size_)
|
||||||
while True:
|
tol_size = target_size * tolerance
|
||||||
batch_size = next(size_)
|
batch = []
|
||||||
tol_size = batch_size * 0.2
|
current_size = 0
|
||||||
batch = []
|
|
||||||
if oversize:
|
for example in examples:
|
||||||
example = oversize.pop(0)
|
n_words = count_words(example.doc)
|
||||||
n_words = count_words(example.doc)
|
# add the example to the current batch if it still fits
|
||||||
|
if (current_size + n_words) < (target_size + tol_size):
|
||||||
batch.append(example)
|
batch.append(example)
|
||||||
batch_size -= n_words
|
current_size += n_words
|
||||||
while batch_size >= 1:
|
else:
|
||||||
try:
|
# if the current example exceeds the batch size, it is returned separately
|
||||||
example = next(examples)
|
# but only if discard_oversize=False.
|
||||||
except StopIteration:
|
if current_size > target_size:
|
||||||
if oversize:
|
if not discard_oversize:
|
||||||
example = oversize.pop(0)
|
yield [example]
|
||||||
batch.append(example)
|
# yield the previous batch and start a new one
|
||||||
if batch:
|
|
||||||
yield batch
|
|
||||||
return
|
|
||||||
n_words = count_words(example.doc)
|
|
||||||
if n_words < (batch_size + tol_size):
|
|
||||||
batch_size -= n_words
|
|
||||||
batch.append(example)
|
|
||||||
else:
|
else:
|
||||||
oversize.append(example)
|
yield batch
|
||||||
if batch:
|
target_size = next(size_)
|
||||||
yield batch
|
tol_size = target_size * tolerance
|
||||||
|
# In theory it may happen that the current example now exceeds the new target_size,
|
||||||
|
# but that seems like an unimportant edge case if batch sizes are variable anyway?
|
||||||
|
batch = [example]
|
||||||
|
current_size = n_words
|
||||||
|
|
||||||
|
# yield the final batch
|
||||||
|
if batch:
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
def itershuffle(iterable, bufsize=1000):
|
def itershuffle(iterable, bufsize=1000):
|
||||||
|
|
Loading…
Reference in New Issue