2020-08-04 13:09:37 +00:00
|
|
|
from typing import Union, Iterator, Iterable, Sequence, TypeVar, List, Callable
|
|
|
|
from typing import Optional, Any
|
|
|
|
from functools import partial
|
|
|
|
import itertools
|
|
|
|
|
|
|
|
from ..util import registry, minibatch
|
|
|
|
|
|
|
|
|
|
|
|
Sizing = Union[Iterable[int], int]
|
|
|
|
ItemT = TypeVar("ItemT")
|
|
|
|
BatcherT = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
|
|
|
|
|
|
|
|
|
|
|
|
@registry.batchers("batch_by_padded.v1")
def configure_minibatch_by_padded_size(
    *,
    size: Sizing,
    buffer: int,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher backed by `minibatch_by_padded_size`.

    All arguments are keyword-only and are bound onto the underlying function
    with `functools.partial`; the result is called with the items to batch.

    size: Target padded batch size, a single int or an iterable of ints.
    buffer: Bound as the `buffer` keyword of `minibatch_by_padded_size`.
    discard_oversize: Bound as the `discard_oversize` keyword.
    get_length: Optional function returning the length of one item. It is
        only bound when it is not None.
    """
    settings = {
        "size": size,
        "buffer": buffer,
        "discard_oversize": discard_oversize,
    }
    # Leave get_length out entirely when unset, so the default declared by
    # the underlying function is not displaced by an explicit None.
    if get_length is not None:
        settings["get_length"] = get_length
    return partial(minibatch_by_padded_size, **settings)
|
|
|
|
|
|
|
|
|
|
|
|
@registry.batchers("batch_by_words.v1")
def configure_minibatch_by_words(
    *,
    size: Sizing,
    tolerance: float,
    discard_oversize: bool,
    get_length: Optional[Callable[[ItemT], int]] = None
) -> BatcherT:
    """Build a batcher backed by `minibatch_by_words`.

    All arguments are keyword-only and are bound onto the underlying function
    with `functools.partial`; the result is called with the items to batch.

    size: Target number of words per batch, a single int or an iterable of
        ints.
    tolerance: Fraction of the target size by which a batch may overshoot
        (`minibatch_by_words` multiplies the target size by this value).
    discard_oversize: Bound as the `discard_oversize` keyword.
    get_length: Optional function returning the length of one item. It is
        only bound when it is not None.
    """
    # Avoid displacing optional values from the underlying function.
    optionals = {"get_length": get_length} if get_length is not None else {}
    return partial(
        minibatch_by_words,
        size=size,
        # tolerance must be forwarded explicitly; otherwise the configured
        # value is silently replaced by the underlying function's default.
        tolerance=tolerance,
        discard_oversize=discard_oversize,
        **optionals
    )
|
|
|
|
|
|
|
|
|
|
|
|
@registry.batchers("batch_by_sequence.v1")
def configure_minibatch(size: Sizing, get_length=None) -> BatcherT:
    """Build a batcher backed by `minibatch`.

    size: Batch size, a single int or an iterable of ints; bound as the
        `size` keyword.
    get_length: Optional length function. It is only bound as a keyword when
        it is not None.
    """
    if get_length is None:
        # Nothing optional to forward: bind the size only.
        return partial(minibatch, size=size)
    return partial(minibatch, size=size, get_length=get_length)
|
|
|
|
|
|
|
|
|
|
|
|
def minibatch_by_padded_size(
    docs: Iterator["Doc"],
    size: Sizing,
    buffer: int = 256,
    discard_oversize: bool = False,
    get_length: Callable = len,
) -> Iterator[Iterator["Doc"]]:
    """Create minibatches whose padded size stays within a target.

    The input is consumed in chunks obtained from `minibatch(docs,
    size=buffer)`. Within each chunk, items are grouped by length with
    `_batch_by_length`, and every group is yielded as a list. The padded size
    of a group is its longest item length multiplied by the number of items.

    docs: The items to batch.
    size: Target padded size. Either a single int, used for every chunk, or
        an iterable of ints from which one value is drawn per chunk.
    buffer: Passed to `minibatch` as the chunk size.
    discard_oversize: If True, groups whose padded size exceeds the target
        are dropped instead of yielded.
    get_length: Function returning the length of one item. Defaults to len.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() accepts lists, tuples and other iterables, and returns an
        # existing iterator unchanged, so next() below works for all of them.
        size_ = iter(size)
    for outer_batch in minibatch(docs, size=buffer):
        outer_batch = list(outer_batch)
        target_size = next(size_)
        for indices in _batch_by_length(outer_batch, target_size, get_length):
            subbatch = [outer_batch[i] for i in indices]
            # Measure with get_length (not len) so the padded size agrees
            # with the lengths that were used to form the group.
            padded_size = max(get_length(seq) for seq in subbatch) * len(subbatch)
            # A group that exactly fills the target is not oversize; only
            # groups strictly larger than the target are discarded.
            if discard_oversize and padded_size > target_size:
                continue
            yield subbatch
|
|
|
|
|
|
|
|
|
|
|
|
def minibatch_by_words(
    docs, size, tolerance=0.2, discard_oversize=False, get_length=len
):
    """Create minibatches of roughly a given number of words. If any examples
    are longer than the specified batch length, they will appear in a batch by
    themselves, or be discarded if discard_oversize=True.
    The argument 'docs' can be a list of strings, Doc's or Example's.

    docs: The items to batch.
    size: Target number of words per batch. Either a single int, or an
        iterable of ints; a new value is drawn each time a batch is closed.
    tolerance: Fraction of the target size by which a batch may overshoot.
    discard_oversize: If True, items longer than the target size plus the
        tolerance margin are dropped instead of yielded on their own.
    get_length: Function returning the number of words in one item.
        Defaults to len.
    """
    if isinstance(size, int):
        size_ = itertools.repeat(size)
    else:
        # iter() handles lists, tuples, ranges and other iterables alike, and
        # returns an existing iterator unchanged.
        size_ = iter(size)
    target_size = next(size_)
    tol_size = target_size * tolerance
    batch = []
    overflow = []
    batch_size = 0
    overflow_size = 0
    for doc in docs:
        n_words = get_length(doc)
        # if the current example exceeds the maximum batch size, it is returned separately
        # but only if discard_oversize=False.
        if n_words > target_size + tol_size:
            if not discard_oversize:
                yield [doc]
        # add the example to the current batch if there's no overflow yet and it still fits
        elif overflow_size == 0 and (batch_size + n_words) <= target_size:
            batch.append(doc)
            batch_size += n_words
        # add the example to the overflow buffer if it fits in the tolerance margin
        elif (batch_size + overflow_size + n_words) <= (target_size + tol_size):
            overflow.append(doc)
            overflow_size += n_words
        # yield the previous batch and start a new one. The new one gets the overflow examples.
        else:
            if batch:
                yield batch
            target_size = next(size_)
            tol_size = target_size * tolerance
            batch = overflow
            batch_size = overflow_size
            overflow = []
            overflow_size = 0
            # this example still fits
            if (batch_size + n_words) <= target_size:
                batch.append(doc)
                batch_size += n_words
            # this example fits in overflow
            elif (batch_size + n_words) <= (target_size + tol_size):
                overflow.append(doc)
                overflow_size += n_words
            # this example does not fit with the previous overflow: start another new batch
            else:
                if batch:
                    yield batch
                target_size = next(size_)
                tol_size = target_size * tolerance
                batch = [doc]
                batch_size = n_words
    # Whatever is still pending in the overflow buffer joins the final batch.
    batch.extend(overflow)
    if batch:
        yield batch
|
|
|
|
|
|
|
|
|
|
|
|
def _batch_by_length(
|
|
|
|
seqs: Sequence[Any], max_words: int, get_length=len
|
|
|
|
) -> List[List[Any]]:
|
|
|
|
"""Given a list of sequences, return a batched list of indices into the
|
|
|
|
list, where the batches are grouped by length, in descending order.
|
|
|
|
|
|
|
|
Batches may be at most max_words in size, defined as max sequence length * size.
|
|
|
|
"""
|
|
|
|
# Use negative index so we can get sort by position ascending.
|
|
|
|
lengths_indices = [(get_length(seq), i) for i, seq in enumerate(seqs)]
|
|
|
|
lengths_indices.sort()
|
|
|
|
batches = []
|
|
|
|
batch = []
|
|
|
|
for length, i in lengths_indices:
|
|
|
|
if not batch:
|
|
|
|
batch.append(i)
|
|
|
|
elif length * (len(batch) + 1) <= max_words:
|
|
|
|
batch.append(i)
|
|
|
|
else:
|
|
|
|
batches.append(batch)
|
|
|
|
batch = [i]
|
|
|
|
if batch:
|
|
|
|
batches.append(batch)
|
|
|
|
# Check lengths match
|
|
|
|
assert sum(len(b) for b in batches) == len(seqs)
|
|
|
|
batches = [list(sorted(batch)) for batch in batches]
|
|
|
|
batches.reverse()
|
|
|
|
return batches
|