# genienlp/text/torchtext/data/batch.py
from copy import deepcopy
class Batch(object):
    """Defines a batch of examples along with its Fields.

    Attributes:
        batch_size: Number of examples in the batch.
        dataset: A reference to the dataset object the examples come from
            (which itself contains the dataset's Field objects).
        train: Whether the batch is from a training set.

    Also stores the Variable for each column in the batch as an attribute,
    plus the limited-vocabulary index mappings built while processing.
    """

    def __init__(self, data=None, dataset=None, device=None, train=True):
        """Create a Batch from a list of examples.

        Args:
            data: List of examples; each example exposes one attribute per
                field name. When None, an empty shell is created (used by
                ``fromvars``).
            dataset: Dataset the examples come from; must expose a
                ``fields`` mapping of name -> Field.
            device: Forwarded to ``Field.process``.
            train: Whether the batch is from a training set.
        """
        if data is not None:
            self.batch_size = len(data)
            self.dataset = dataset
            self.train = train
            # The fields share one decoder vocabulary, so take the mapping
            # from the first field. deepcopy so per-batch OOV additions do
            # not leak back into the Field.
            # NOTE(review): upstream comment says this copy could be avoided
            # with a conditional in map-to-full — left as-is.
            first_field = list(dataset.fields.values())[0]
            limited_idx_to_full_idx = deepcopy(first_field.decoder_to_vocab)
            oov_to_limited_idx = {}
            for name, field in dataset.fields.items():
                if field is not None:
                    # getattr is equivalent to the original x.__dict__[name]
                    # for plain instance attributes, and also supports
                    # properties/slots.
                    batch = [getattr(x, name) for x in data]
                    if not field.include_lengths:
                        setattr(self, name, field.process(batch, device=device, train=train))
                    else:
                        # Length-aware fields additionally return the
                        # limited-vocabulary entry and the raw token strings.
                        entry, lengths, limited_entry, raw = field.process(
                            batch, device=device, train=train,
                            limited=field.decoder_stoi,
                            l2f=limited_idx_to_full_idx,
                            oov2l=oov_to_limited_idx)
                        setattr(self, name, entry)
                        setattr(self, f'{name}_lengths', lengths)
                        setattr(self, f'{name}_limited', limited_entry)
                        setattr(self, f'{name}_elmo', [[s.strip() for s in l] for l in raw])
            # Constant attribute names: plain assignment instead of the
            # original's pointless f-string + setattr indirection.
            self.limited_idx_to_full_idx = limited_idx_to_full_idx
            self.oov_to_limited_idx = oov_to_limited_idx

    @classmethod
    def fromvars(cls, dataset, batch_size, train=True, **kwargs):
        """Create a Batch directly from a number of Variables.

        Args:
            dataset: Dataset to attach to the batch.
            batch_size: Number of examples the variables represent.
            train: Whether the batch is from a training set.
            **kwargs: Column name -> Variable pairs stored as attributes.

        Returns:
            A new Batch with the given attributes set.
        """
        batch = cls()
        batch.batch_size = batch_size
        batch.dataset = dataset
        batch.train = train
        for name, value in kwargs.items():
            setattr(batch, name, value)
        return batch