genienlp/text/torchtext/data/dataset.py

import io
import os
import tarfile
import zipfile

import torch.utils.data

from .example import Example
from ..utils import download_from_url


class Dataset(torch.utils.data.Dataset):
"""Defines a dataset composed of Examples along with its Fields.
Attributes:
sort_key (callable): A key to use for sorting dataset examples for batching
together examples with similar lengths to minimize padding.
examples (list(Example)): The examples in this dataset.
fields: A dictionary containing the name of each column together with
its corresponding Field object. Two columns with the same Field
object will share a vocabulary.
fields (dict[str, Field]): Contains the name of each column or field, together
with the corresponding Field object. Two fields with the same Field object
will have a shared vocabulary.
"""
sort_key = None
    def __init__(self, examples, fields, filter_pred=None):
        """Create a dataset from a list of Examples and Fields.

        Arguments:
            examples: List of Examples.
            fields (list(tuple(str, Field))): The Fields to use in this dataset.
                The string is a field name, and the Field is the associated field.
            filter_pred (callable or None): Use only examples for which
                filter_pred(example) is True, or use all examples if None.
                Default is None.
        """
        if filter_pred is not None:
            make_list = isinstance(examples, list)
            examples = filter(filter_pred, examples)
            if make_list:
                examples = list(examples)
        self.examples = examples
        self.fields = dict(fields)
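
    # A minimal usage sketch (hypothetical data and field names, assuming an
    # Example.fromlist constructor as in upstream torchtext): build a dataset
    # from pre-constructed Examples and keep only short examples via filter_pred.
    #
    #     examples = [Example.fromlist([src, trg], fields) for src, trg in pairs]
    #     dataset = Dataset(examples, fields,
    #                       filter_pred=lambda ex: len(ex.src) <= 100)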

    @classmethod
    def splits(cls, path=None, root='.data', train=None, validation=None,
               test=None, **kwargs):
        """Create Dataset objects for multiple splits of a dataset.

        Arguments:
            path (str): Common prefix of the splits' file paths, or None to use
                the result of cls.download(root).
            root (str): Root dataset storage directory. Default is '.data'.
            train (str): Suffix to add to path for the train set, or None for no
                train set. Default is None.
            validation (str): Suffix to add to path for the validation set, or
                None for no validation set. Default is None.
            test (str): Suffix to add to path for the test set, or None for no
                test set. Default is None.
            Remaining keyword arguments: Passed to the constructor of the
                Dataset (sub)class being used.

        Returns:
            split_datasets (tuple(Dataset)): Datasets for train, validation, and
                test splits in that order, if provided.
        """
        if path is None:
            path = cls.download(root)
        train_data = None if train is None else cls(
            os.path.join(path, train), **kwargs)
        val_data = None if validation is None else cls(
            os.path.join(path, validation), **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, test), **kwargs)
        return tuple(d for d in (train_data, val_data, test_data)
                     if d is not None)
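
    # A minimal usage sketch (hypothetical file names and fields): create
    # train and validation datasets from files sharing a common path prefix,
    # with the remaining keyword arguments forwarded to the subclass
    # constructor.
    #
    #     train, val = TabularDataset.splits(
    #         path='data/', train='train.tsv', validation='valid.tsv',
    #         format='tsv', fields=fields)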

    def __getitem__(self, i):
        return self.examples[i]

    def __len__(self):
        try:
            return len(self.examples)
        except TypeError:
            # self.examples may be a generator with no len(); report an
            # effectively unbounded size instead.
            return 2**32

    def __iter__(self):
        for x in self.examples:
            yield x

    def __getattr__(self, attr):
        # Yields the value of the given field for each example in turn, so
        # attribute access returns a generator rather than a list.
        if attr in self.fields:
            for x in self.examples:
                yield getattr(x, attr)
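
    # Because __getattr__ is a generator, a whole column can be streamed by
    # attribute name (hypothetical 'text' field):
    #
    #     for text in dataset.text:
    #         process(text)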

    @classmethod
    def download(cls, root, check=None):
        """Download and extract an online archive (.zip, .gz, .tgz, .bz2, or .tar).

        Arguments:
            root (str): Folder to download data to.
            check (str or None): Folder whose existence indicates
                that the dataset has already been downloaded, or
                None to check the existence of root/{cls.name}.

        Returns:
            dataset_path (str): Path to extracted dataset.
        """
        path = os.path.join(root, cls.name)
        check = path if check is None else check
        if not os.path.isdir(check):
            for url in cls.urls:
                if isinstance(url, tuple):
                    url, filename = url
                else:
                    filename = os.path.basename(url)
                zpath = os.path.join(path, filename)
                if not os.path.isfile(zpath):
                    if not os.path.exists(os.path.dirname(zpath)):
                        os.makedirs(os.path.dirname(zpath))
                    print('downloading {}'.format(filename))
                    download_from_url(url, zpath)
                ext = os.path.splitext(filename)[-1]
                if ext == '.zip':
                    with zipfile.ZipFile(zpath, 'r') as zfile:
                        print('extracting')
                        zfile.extractall(path)
                elif ext in ['.gz', '.tgz']:
                    with tarfile.open(zpath, 'r:gz') as tar:
                        tar.extractall(path=path, members=tar.getmembers())
                elif ext in ['.bz2', '.tar']:
                    with tarfile.open(zpath) as tar:
                        tar.extractall(path=path, members=tar.getmembers())
        return os.path.join(path, cls.dirname)
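
    # download() relies on class attributes (name, urls, dirname) that
    # concrete subclasses are expected to define; a hypothetical sketch:
    #
    #     class MyDataset(Dataset):
    #         name = 'my_dataset'
    #         dirname = 'my_dataset_v1'
    #         urls = ['http://example.com/my_dataset.zip']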


class TabularDataset(Dataset):
    """Defines a Dataset of columns stored in CSV, TSV, or JSON format."""

    def __init__(self, path, format, fields, skip_header=False, subsample=False, **kwargs):
        """Create a TabularDataset given a path, file format, and field list.

        Arguments:
            path (str): Path to the data file.
            format (str): The format of the data file. One of "CSV", "TSV", or
                "JSON" (case-insensitive).
            fields (list(tuple(str, Field)) or dict[str, tuple(str, Field)]): For
                CSV and TSV formats, a list of (name, field) tuples. The list
                should be in the same order as the columns in the CSV or TSV
                file, while tuples of (name, None) represent columns that will
                be ignored. For JSON format, a dictionary whose keys are the
                JSON keys and whose values are (name, field) tuples. This allows
                the user to rename columns from their JSON key names and also
                enables selecting a subset of columns to load (since JSON keys
                not present in the input dictionary are ignored).
            skip_header (bool): Whether to skip the first line of the input
                file. Default is False.
            subsample (bool): Accepted for API compatibility; not used in this
                method.
        """
        make_example = {
            'json': Example.fromJSON, 'dict': Example.fromdict,
            'tsv': Example.fromTSV, 'csv': Example.fromCSV}[format.lower()]

        examples = []
        with io.open(os.path.expanduser(path), encoding="utf8") as f:
            if skip_header:
                next(f)
            for line in f:
                examples.append(make_example(line, fields))

        if make_example in (Example.fromdict, Example.fromJSON):
            # Flatten the dict-style fields mapping into the list of
            # (name, field) tuples that the Dataset constructor expects.
            fields, field_dict = [], fields
            for field in field_dict.values():
                if isinstance(field, list):
                    fields.extend(field)
                else:
                    fields.append(field)

        super(TabularDataset, self).__init__(examples, fields, **kwargs)
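
# A minimal usage sketch (hypothetical Field object and file): load a
# two-column TSV where the first column is ignored and the second is parsed
# with a text field.
#
#     fields = [('id', None), ('text', TEXT)]
#     dataset = TabularDataset('data/train.tsv', format='tsv',
#                              fields=fields, skip_header=True)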