182 lines
7.5 KiB
Python
182 lines
7.5 KiB
Python
import io
|
|
import os
|
|
import zipfile
|
|
import tarfile
|
|
|
|
import torch.utils.data
|
|
|
|
from .example import Example
|
|
from ..utils import download_from_url
|
|
|
|
|
|
class Dataset(torch.utils.data.Dataset):
    """Defines a dataset composed of Examples along with its Fields.

    Attributes:
        sort_key (callable): A key to use for sorting dataset examples for batching
            together examples with similar lengths to minimize padding.
        examples (list(Example)): The examples in this dataset.
        fields (dict[str, Field]): Contains the name of each column or field, together
            with the corresponding Field object. Two fields with the same Field object
            will have a shared vocabulary.
    """
    sort_key = None

    def __init__(self, examples, fields, filter_pred=None):
        """Create a dataset from a list of Examples and Fields.

        Arguments:
            examples: List of Examples.
            fields (List(tuple(str, Field))): The Fields to use in this tuple. The
                string is a field name, and the Field is the associated field.
            filter_pred (callable or None): Use only examples for which
                filter_pred(example) is True, or use all examples if None.
                Default is None.
        """
        if filter_pred is not None:
            # Only materialize the filtered result when the caller handed us a
            # list; any other iterable stays lazy.
            make_list = isinstance(examples, list)
            examples = filter(filter_pred, examples)
            if make_list:
                examples = list(examples)
        self.examples = examples
        self.fields = dict(fields)

    @classmethod
    def splits(cls, path=None, root='.data', train=None, validation=None,
               test=None, **kwargs):
        """Create Dataset objects for multiple splits of a dataset.

        Arguments:
            path (str): Common prefix of the splits' file paths, or None to use
                the result of cls.download(root).
            root (str): Root dataset storage directory. Default is '.data'.
            train (str): Suffix to add to path for the train set, or None for no
                train set. Default is None.
            validation (str): Suffix to add to path for the validation set, or None
                for no validation set. Default is None.
            test (str): Suffix to add to path for the test set, or None for no test
                set. Default is None.
            Remaining keyword arguments: Passed to the constructor of the
                Dataset (sub)class being used.

        Returns:
            split_datasets (tuple(Dataset)): Datasets for train, validation, and
                test splits in that order, if provided.
        """
        if path is None:
            path = cls.download(root)
        # Splits are built in train / validation / test order; a split whose
        # suffix is None is skipped and does not appear in the result.
        return tuple(cls(os.path.join(path, suffix), **kwargs)
                     for suffix in (train, validation, test)
                     if suffix is not None)

    def __getitem__(self, i):
        """Return the i-th example."""
        return self.examples[i]

    def __len__(self):
        """Return the number of examples.

        If the examples container does not support len() (for instance a lazy
        iterable), 2**32 is returned as a stand-in length.
        """
        try:
            return len(self.examples)
        except TypeError:
            return 2**32

    def __iter__(self):
        """Iterate over the examples in order."""
        for x in self.examples:
            yield x

    def __getattr__(self, attr):
        """Return a generator over one column of the dataset.

        Called only when normal attribute lookup fails. If ``attr`` names one
        of the dataset's fields, the result yields ``getattr(example, attr)``
        for every example.

        Raises:
            AttributeError: If ``attr`` is not one of the dataset's fields.
        """
        # Look in __dict__ directly: using self.fields here would re-enter
        # __getattr__ on an instance whose __init__ has not run (for example
        # the blank object created by copy or pickle).
        fields = self.__dict__.get('fields', ())
        if attr in fields:
            return (getattr(x, attr) for x in self.examples)
        # Unknown names must raise so that hasattr() and the protocol probes
        # done by copy / pickle (__setstate__, __getnewargs_ex__, ...) work.
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    @classmethod
    def download(cls, root, check=None):
        """Download and unzip an online archive (.zip, .gz, or .tgz).

        Reads ``cls.name``, ``cls.urls`` and ``cls.dirname``; this base class
        does not define them. Each entry of ``cls.urls`` is either a URL or a
        ``(url, filename)`` tuple.

        Arguments:
            root (str): Folder to download data to.
            check (str or None): Folder whose existence indicates
                that the dataset has already been downloaded, or
                None to check the existence of root/{cls.name}.

        Returns:
            dataset_path (str): Path to extracted dataset.
        """
        path = os.path.join(root, cls.name)
        check = path if check is None else check
        if not os.path.isdir(check):
            for url in cls.urls:
                if isinstance(url, tuple):
                    url, filename = url
                else:
                    filename = os.path.basename(url)
                zpath = os.path.join(path, filename)
                if not os.path.isfile(zpath):
                    target_dir = os.path.dirname(zpath)
                    if not os.path.exists(target_dir):
                        os.makedirs(target_dir)
                    print('downloading {}'.format(filename))
                    download_from_url(url, zpath)
                # Only the last extension is inspected, so 'x.tar.gz' is
                # handled by the '.gz' branch.
                ext = os.path.splitext(filename)[-1]
                # NOTE(review): archives are extracted without validating
                # member paths; confirm that every URL in cls.urls is trusted.
                if ext == '.zip':
                    with zipfile.ZipFile(zpath, 'r') as zfile:
                        print('extracting')
                        zfile.extractall(path)
                elif ext in ['.gz', '.tgz']:
                    with tarfile.open(zpath, 'r:gz') as tar:
                        tar.extractall(path=path)
                elif ext in ['.bz2', '.tar']:
                    with tarfile.open(zpath) as tar:
                        tar.extractall(path=path)

        return os.path.join(path, cls.dirname)
|
|
|
|
|
|
class TabularDataset(Dataset):
    """Defines a Dataset of columns stored in CSV, TSV, or JSON format."""

    def __init__(self, path, format, fields, skip_header=False, subsample=False, **kwargs):
        """Create a TabularDataset given a path, file format, and field list.

        Each line of the file is turned into one Example.

        Arguments:
            path (str): Path to the data file. A leading "~" is expanded.
            format (str): The format of the data file. One of "CSV", "TSV", or
                "JSON" (case-insensitive). The lookup table also accepts "dict".
            fields (list(tuple(str, Field)) or dict[str: tuple(str, Field)]: For CSV and
                TSV formats, list of tuples of (name, field). The list should be in
                the same order as the columns in the CSV or TSV file, while tuples of
                (name, None) represent columns that will be ignored. For JSON format,
                dictionary whose keys are the JSON keys and whose values are tuples of
                (name, field). This allows the user to rename columns from their JSON key
                names and also enables selecting a subset of columns to load
                (since JSON keys not present in the input dictionary are ignored).
                A dictionary value may also be a list of (name, field) tuples.
            skip_header (bool): Whether to skip the first line of the input file.
            subsample: Accepted but not used by this constructor.
            Remaining keyword arguments: Passed to the Dataset constructor.

        Raises:
            KeyError: If format is not one of the supported format names.
        """
        make_example = {
            'json': Example.fromJSON, 'dict': Example.fromdict,
            'tsv': Example.fromTSV, 'csv': Example.fromCSV}[format.lower()]

        with io.open(os.path.expanduser(path), encoding="utf8") as f:
            if skip_header:
                # The default keeps an empty file from raising StopIteration;
                # it simply produces a dataset with no examples.
                next(f, None)
            examples = [make_example(line, fields) for line in f]

        if make_example in (Example.fromdict, Example.fromJSON):
            # Dict-style fields are flattened into the list of (name, field)
            # tuples that the Dataset constructor expects.
            fields, field_dict = [], fields
            for field in field_dict.values():
                if isinstance(field, list):
                    fields.extend(field)
                else:
                    fields.append(field)

        super(TabularDataset, self).__init__(examples, fields, **kwargs)
|