from .. import data


class LanguageModelingDataset(data.Dataset):
    """Defines a dataset for language modeling."""

    def __init__(self, path, text_field, newline_eos=True, **kwargs):
        """Create a LanguageModelingDataset given a path and a field.

        Arguments:
            path: Path to the data file.
            text_field: The field that will be used for text data.
            newline_eos: Whether to add an <eos> token for every newline in the
                data file. Default: True.
            Remaining keyword arguments: Passed to the constructor of
                data.Dataset.
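
        Example (a minimal usage sketch; ``corpus.txt`` is a hypothetical
        plain-text file, the ``Field`` settings are illustrative, and the
        import path assumes the class is exposed as
        ``torchtext.datasets.LanguageModelingDataset``)::

            >>> from torchtext import data
            >>> from torchtext.datasets import LanguageModelingDataset
            >>> TEXT = data.Field(lower=True)
            >>> corpus = LanguageModelingDataset('corpus.txt', TEXT)
            >>> TEXT.build_vocab(corpus)  # the file is stored as one token stream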
        """
        fields = [('text', text_field)]
        text = []
        # Concatenate the whole file into a single token stream, optionally
        # marking each line break with an <eos> token.
        with open(path) as f:
            for line in f:
                text += text_field.preprocess(line)
                if newline_eos:
                    text.append('<eos>')

        # The entire corpus becomes a single Example; iterators such as
        # BPTTIterator later slice it into fixed-length sequences.
        examples = [data.Example.fromlist([text], fields)]
        super(LanguageModelingDataset, self).__init__(
            examples, fields, **kwargs)


class WikiText2(LanguageModelingDataset):

    urls = ['https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip']
    name = 'wikitext-2'
    dirname = 'wikitext-2'

    @classmethod
    def splits(cls, text_field, root='.data', train='wiki.train.tokens',
               validation='wiki.valid.tokens', test='wiki.test.tokens',
               **kwargs):
        """Create dataset objects for splits of the WikiText-2 dataset.

        This is the most flexible way to use the dataset.

        Arguments:
            text_field: The field that will be used for text data.
            root: The root directory that the dataset's zip archive will be
                expanded into; therefore the directory in whose wikitext-2
                subdirectory the data files will be stored.
            train: The filename of the train data. Default: 'wiki.train.tokens'.
            validation: The filename of the validation data, or None to not
                load the validation set. Default: 'wiki.valid.tokens'.
            test: The filename of the test data, or None to not load the test
                set. Default: 'wiki.test.tokens'.
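
        Example (a usage sketch; the ``Field`` configuration is illustrative,
        and the first call downloads and expands the archive into ``root``)::

            >>> from torchtext import data
            >>> from torchtext.datasets import WikiText2
            >>> TEXT = data.Field(lower=True)
            >>> train, valid, test = WikiText2.splits(TEXT)  # root defaults to '.data'
            >>> TEXT.build_vocab(train)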
        """
        return super(WikiText2, cls).splits(
            root=root, train=train, validation=validation, test=test,
            text_field=text_field, **kwargs)

    @classmethod
    def iters(cls, batch_size=32, bptt_len=35, device=0, root='.data',
              vectors=None, **kwargs):
        """Create iterator objects for splits of the WikiText-2 dataset.

        This is the simplest way to use the dataset, and assumes common
        defaults for field, vocabulary, and iterator parameters.

        Arguments:
            batch_size: Batch size.
            bptt_len: Length of sequences for backpropagation through time.
            device: Device to create batches on. Use -1 for CPU and None for
                the currently active GPU device.
            root: The root directory that the dataset's zip archive will be
                expanded into; therefore the directory in whose wikitext-2
                subdirectory the data files will be stored.
            vectors: Pretrained word vectors, passed to the vocabulary built
                for the text field. The loaded vectors are accessible as
                train.dataset.fields['text'].vocab.vectors.
            Remaining keyword arguments: Passed to the splits method.
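
        Example (a usage sketch; ``device=-1`` keeps batches on the CPU as
        noted above, and batch shapes follow the BPTTIterator convention of
        (bptt_len, batch_size))::

            >>> from torchtext.datasets import WikiText2
            >>> train_iter, valid_iter, test_iter = WikiText2.iters(
            ...     batch_size=32, bptt_len=35, device=-1)
            >>> batch = next(iter(train_iter))
            >>> text, target = batch.text, batch.target  # each (bptt_len, batch_size)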
        """
        # Use a Field with default settings.
        TEXT = data.Field()

        train, val, test = cls.splits(TEXT, root=root, **kwargs)

        # Build the vocabulary on the training split, optionally loading
        # pretrained vectors.
        TEXT.build_vocab(train, vectors=vectors)

        return data.BPTTIterator.splits(
            (train, val, test), batch_size=batch_size, bptt_len=bptt_len,
            device=device)