genienlp/text/torchtext/datasets/snli.py

from .. import data


class ShiftReduceField(data.Field):

    def __init__(self):
        super(ShiftReduceField, self).__init__(preprocessing=lambda parse: [
            'reduce' if t == ')' else 'shift' for t in parse if t != '('])

        self.build_vocab([['reduce'], ['shift']])
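
# Illustrative example (not part of the original file): given the token list of
# a binarized parse, ShiftReduceField keeps only word tokens and closing
# parentheses and maps them to parser transitions, e.g.
#   '( ( the cat ) sat )'.split()
#   -> ['shift', 'shift', 'reduce', 'shift', 'reduce']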


class ParsedTextField(data.Field):

    def __init__(self, eos_token='<pad>', lower=False):
        super(ParsedTextField, self).__init__(
            eos_token=eos_token, lower=lower, preprocessing=lambda parse: [
                t for t in parse if t not in ('(', ')')],
            postprocessing=lambda parse, _, __: [
                list(reversed(p)) for p in parse])
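
# Illustrative example (not part of the original file): ParsedTextField drops
# the parentheses from a binarized parse at preprocessing time, e.g.
#   '( ( the cat ) sat )'.split() -> ['the', 'cat', 'sat'],
# and its postprocessing hook reverses the order of each example in the batch.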


class SNLI(data.TabularDataset):

    urls = ['http://nlp.stanford.edu/projects/snli/snli_1.0.zip']
    dirname = 'snli_1.0'
    name = 'snli'

    @staticmethod
    def sort_key(ex):
        return data.interleave_keys(
            len(ex.premise), len(ex.hypothesis))

    @classmethod
    def splits(cls, text_field, label_field, parse_field=None, root='.data',
               train='snli_1.0_train.jsonl', validation='snli_1.0_dev.jsonl',
               test='snli_1.0_test.jsonl'):
        """Create dataset objects for splits of the SNLI dataset.

        This is the most flexible way to use the dataset.

        Arguments:
            text_field: The field that will be used for premise and hypothesis
                data.
            label_field: The field that will be used for label data.
            parse_field: The field that will be used for shift-reduce parser
                transitions, or None to not include them.
            root: The root directory that the dataset's zip archive will be
                expanded into; therefore the directory in whose snli_1.0
                subdirectory the data files will be stored.
            train: The filename of the train data. Default:
                'snli_1.0_train.jsonl'.
            validation: The filename of the validation data, or None to not
                load the validation set. Default: 'snli_1.0_dev.jsonl'.
            test: The filename of the test data, or None to not load the test
                set. Default: 'snli_1.0_test.jsonl'.
        """
        path = cls.download(root)

        if parse_field is None:
            return super(SNLI, cls).splits(
                path, root, train, validation, test,
                format='json', fields={'sentence1': ('premise', text_field),
                                       'sentence2': ('hypothesis', text_field),
                                       'gold_label': ('label', label_field)},
                filter_pred=lambda ex: ex.label != '-')
        return super(SNLI, cls).splits(
            path, root, train, validation, test,
            format='json', fields={'sentence1_binary_parse':
                                   [('premise', text_field),
                                    ('premise_transitions', parse_field)],
                                   'sentence2_binary_parse':
                                   [('hypothesis', text_field),
                                    ('hypothesis_transitions', parse_field)],
                                   'gold_label': ('label', label_field)},
            filter_pred=lambda ex: ex.label != '-')
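
    # Usage sketch (illustrative, not part of the original file); the field
    # configuration below is an assumption, any compatible Fields work:
    #
    #   TEXT = data.Field(tokenize='spacy')
    #   LABEL = data.Field(sequential=False)
    #   train, dev, test = SNLI.splits(TEXT, LABEL, root='.data')
    #   TEXT.build_vocab(train)
    #   LABEL.build_vocab(train)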

    @classmethod
    def iters(cls, batch_size=32, device=0, root='.data',
              vectors=None, trees=False, **kwargs):
        """Create iterator objects for splits of the SNLI dataset.

        This is the simplest way to use the dataset, and assumes common
        defaults for field, vocabulary, and iterator parameters.

        Arguments:
            batch_size: Batch size.
            device: Device to create batches on. Use -1 for CPU and None for
                the currently active GPU device.
            root: The root directory that the dataset's zip archive will be
                expanded into; therefore the directory in whose snli_1.0
                subdirectory the data files will be stored.
            vectors: one of the available pretrained vectors or a list with
                each element one of the available pretrained vectors (see
                Vocab.load_vectors).
            trees: Whether to include shift-reduce parser transitions.
                Default: False.
            Remaining keyword arguments: Passed to the splits method.
        """
        if trees:
            TEXT = ParsedTextField()
            TRANSITIONS = ShiftReduceField()
        else:
            TEXT = data.Field(tokenize='spacy')
            TRANSITIONS = None
        LABEL = data.Field(sequential=False)

        train, val, test = cls.splits(
            TEXT, LABEL, TRANSITIONS, root=root, **kwargs)

        TEXT.build_vocab(train, vectors=vectors)
        LABEL.build_vocab(train)

        return data.BucketIterator.splits(
            (train, val, test), batch_size=batch_size, device=device)
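
    # Usage sketch (illustrative, not part of the original file): the one-call
    # path, which builds default fields and vocabularies internally; the vector
    # name below is an assumption (any name accepted by Vocab.load_vectors).
    #
    #   train_iter, dev_iter, test_iter = SNLI.iters(
    #       batch_size=64, device=-1, vectors='glove.840B.300d')
    #   batch = next(iter(train_iter))
    #   # batch.premise, batch.hypothesis, batch.label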