mirror of https://github.com/explosion/spaCy.git
Add extra batch util
This commit is contained in:
parent
eb0798c421
commit
3a7f275c02
|
@ -722,6 +722,50 @@ def minibatch(items, size=8):
|
|||
yield list(batch)
|
||||
|
||||
|
||||
def minibatch_by_padded_size(docs, size, buffer=256, discard_oversize=False):
    """Create minibatches of similar-length sequences under a padded-size budget.

    The input is first cut into outer batches with `minibatch(docs, buffer)`.
    Each outer batch is then partitioned by `_batch_by_length`, and every
    resulting sub-batch is measured by its padded size: the length of its
    longest sequence multiplied by the number of sequences it contains.

    docs (iterable): The sequences to batch. Each item must support `len()`.
    size (int or iterable): The target padded size. An int is reused for
        every outer batch; any other iterable (list, generator, ...) supplies
        one target per outer batch.
    buffer (int): Passed to `minibatch` as the outer batch size.
    discard_oversize (bool): If True, sub-batches whose padded size exceeds
        the target are dropped instead of yielded. This can only happen for a
        sub-batch holding a single sequence longer than the target.
    YIELDS (list): The next sub-batch of items from `docs`.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # Wrap in iter() so plain iterables such as lists work with next(),
        # not only objects that already are iterators.
        size_ = iter(size)
    for outer_batch in minibatch(docs, buffer):
        outer_batch = list(outer_batch)
        target_size = next(size_)
        for indices in _batch_by_length(outer_batch, target_size):
            subbatch = [outer_batch[i] for i in indices]
            padded_size = max(len(seq) for seq in subbatch) * len(subbatch)
            # _batch_by_length fills batches up to and including the target,
            # so only a strictly larger batch counts as oversize.
            if discard_oversize and padded_size > target_size:
                continue
            yield subbatch
|
||||
|
||||
|
||||
def _batch_by_length(seqs, max_words):
|
||||
"""Given a list of sequences, return a batched list of indices into the
|
||||
list, where the batches are grouped by length, in descending order.
|
||||
|
||||
Batches may be at most max_words in size, defined as max sequence length * size.
|
||||
"""
|
||||
# Use negative index so we can get sort by position ascending.
|
||||
lengths_indices = [(len(seq), i) for i, seq in enumerate(seqs)]
|
||||
lengths_indices.sort()
|
||||
batches = []
|
||||
batch = []
|
||||
for length, i in lengths_indices:
|
||||
if not batch:
|
||||
batch.append(i)
|
||||
elif length * (len(batch) + 1) <= max_words:
|
||||
batch.append(i)
|
||||
else:
|
||||
batches.append(batch)
|
||||
batch = [i]
|
||||
if batch:
|
||||
batches.append(batch)
|
||||
# Check lengths match
|
||||
assert sum(len(b) for b in batches) == len(seqs)
|
||||
batches = [list(sorted(batch)) for batch in batches]
|
||||
batches.reverse()
|
||||
return batches
|
||||
|
||||
def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
||||
"""Create minibatches of roughly a given number of words. If any examples
|
||||
are longer than the specified batch length, they will appear in a batch by
|
||||
|
@ -768,6 +812,7 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
|||
|
||||
# yield the previous batch and start a new one. The new one gets the overflow examples.
|
||||
else:
|
||||
if batch:
|
||||
yield batch
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
|
@ -788,15 +833,15 @@ def minibatch_by_words(docs, size, tolerance=0.2, discard_oversize=False):
|
|||
|
||||
# this example does not fit with the previous overflow: start another new batch
|
||||
else:
|
||||
if batch:
|
||||
yield batch
|
||||
target_size = next(size_)
|
||||
tol_size = target_size * tolerance
|
||||
batch = [doc]
|
||||
batch_size = n_words
|
||||
|
||||
# yield the final batch
|
||||
if batch:
|
||||
batch.extend(overflow)
|
||||
if batch:
|
||||
yield batch
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue