52 lines
2.2 KiB
Python
52 lines
2.2 KiB
Python
from copy import deepcopy
|
|
|
|
|
|
|
|
class Batch:
    """Defines a batch of examples along with its Fields.

    Attributes:
        batch_size: Number of examples in the batch.
        dataset: A reference to the dataset object the examples come from
            (which itself contains the dataset's Field objects).
        train: Whether the batch is from a training set.

    Also stores the Variable for each column in the batch as an attribute.
    When the batch is built from ``data``, every entry ``name`` of
    ``dataset.fields`` whose field is not None produces an attribute
    ``name``. Fields with ``include_lengths`` set additionally produce
    ``<name>_lengths``, ``<name>_limited`` and ``<name>_elmo``. The two
    mappings handed to those fields are stored as
    ``limited_idx_to_full_idx`` and ``oov_to_limited_idx``.
    """

    def __init__(self, data=None, dataset=None, device=None, train=True):
        """Create a Batch from a list of examples.

        Args:
            data: Sequence of example objects; each field's column is read
                from ``example.__dict__[name]``. When None, no attribute is
                set at all (this is what ``fromvars`` relies on).
            dataset: Object exposing a ``fields`` mapping of name -> field.
                Must be provided, with at least one field, when ``data`` is
                given.
            device: Forwarded unchanged to each ``field.process`` call.
            train: Stored on the batch and forwarded to ``field.process``.
        """
        if data is None:
            return

        self.batch_size = len(data)
        self.dataset = dataset
        self.train = train

        # The limited -> full index mapping is seeded from the first field.
        # NOTE(review): presumably all fields share this mapping and
        # ``field.process`` extends the copy in place - confirm.
        first_field = list(dataset.fields.values())[0]
        # TODO: should avoid this copy with a conditional in map to full.
        limited_idx_to_full_idx = deepcopy(first_field.decoder_to_vocab)
        oov_to_limited_idx = {}

        for name, field in dataset.fields.items():
            if field is None:
                continue

            column = [example.__dict__[name] for example in data]

            if not field.include_lengths:
                setattr(self, name,
                        field.process(column, device=device, train=train))
                continue

            # The same two mapping objects are passed to every such field,
            # so whatever a field records in them is visible to later ones.
            entry, lengths, limited_entry, raw = field.process(
                column, device=device, train=train,
                limited=field.decoder_stoi,
                l2f=limited_idx_to_full_idx,
                oov2l=oov_to_limited_idx,
            )
            setattr(self, name, entry)
            setattr(self, f'{name}_lengths', lengths)
            setattr(self, f'{name}_limited', limited_entry)
            # Items of each raw sequence with ``.strip()`` applied.
            setattr(self, f'{name}_elmo',
                    [[token.strip() for token in tokens] for tokens in raw])

        self.limited_idx_to_full_idx = limited_idx_to_full_idx
        self.oov_to_limited_idx = oov_to_limited_idx

    @classmethod
    def fromvars(cls, dataset, batch_size, train=True, **kwargs):
        """Create a Batch directly from a number of Variables.

        Args:
            dataset: Stored as ``batch.dataset``.
            batch_size: Stored as ``batch.batch_size``.
            train: Stored as ``batch.train``.
            **kwargs: Each ``key=value`` pair becomes an attribute of the
                returned batch.

        Returns:
            A new instance of ``cls`` carrying the given attributes.
        """
        batch = cls()
        batch.batch_size = batch_size
        batch.dataset = dataset
        batch.train = train
        for key, value in kwargs.items():
            setattr(batch, key, value)
        return batch
|