Switch to torch.utils.data.DataLoader to load the datasets

Introduce new classes for numericalization and padding, in preparation
for moving away from torchtext.Field and torchtext.Batch. The
new classes are based on namedtuples, which are more compatible with
PyTorch.
Move the useful parts of torchtext.Iterator (batching and bucketing
by example length) into a new Iterator class.
Giovanni Campagna 2019-12-21 06:52:14 -08:00
parent f6a2e4608d
commit 50b9c05248
10 changed files with 361 additions and 66 deletions
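In outline, the pieces introduced here compose as follows (a minimal sketch, not part of the diff; `field` stands for the existing torchtext field that still holds the vocabulary, and `train_set` for a dataset of Example tuples — make_data_loader in decanlp/util.py below wraps exactly this):

import torch
from decanlp.data.example import Batch
from decanlp.data.iterator import Iterator
from decanlp.data.numericalizer import SimpleNumericalizer

# `field` and `train_set` are assumed to come from the existing prepare_data
numericalizer = SimpleNumericalizer(field.vocab, init_token=field.init_token,
                                    eos_token=field.eos_token)
iterator = Iterator(train_set, batch_size=64, shuffle=True, repeat=True,
                    bucket_by_sort_key=True)
loader = torch.utils.data.DataLoader(
    iterator, batch_size=None,  # the Iterator already yields whole minibatches
    collate_fn=lambda minibatch: Batch.from_examples(
        minibatch, numericalizer, field.decoder_vocab))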


@@ -124,8 +124,6 @@ def parse(argv):
parser.add_argument('--no_commit', action='store_false', dest='commit', help='do not track the git commit associated with this training run')
parser.add_argument('--exist_ok', action='store_true', help='Ok if the save directory already exists, i.e. overwrite is ok')
parser.add_argument('--token_testing', action='store_true', help='if true, sorts all iterators')
parser.add_argument('--reverse', action='store_true', help='if token_testing and true, sorts all iterators in reverse')
parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether to use existing cached splits or generate new ones')
parser.add_argument('--lr_rate', default=0.001, type=float, help='initial learning rate')

decanlp/data/__init__.py Normal file (0 lines)

decanlp/data/example.py Normal file (76 lines)

@@ -0,0 +1,76 @@
#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import NamedTuple, List
from .numericalizer import DecoderVocabulary, SimpleNumericalizer, SequentialField
class Example(NamedTuple):
example_id : List[str]
context : List[str]
question : List[str]
answer : List[str]
@staticmethod
def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False):
args = [[example_id]]
for arg in (context, question, answer):
new_arg = tokenize(arg.rstrip('\n'))
if lower:
new_arg = [word.lower() for word in new_arg]
args.append(new_arg)
return Example(*args)
class Batch(NamedTuple):
example_id : List[str]
context : SequentialField
question : SequentialField
answer : SequentialField
decoder_vocab : DecoderVocabulary
@staticmethod
def from_examples(examples, numericalizer : SimpleNumericalizer, decoder_vocab, device=None):
example_ids = [ex.example_id for ex in examples]
context_input = [ex.context for ex in examples]
question_input = [ex.question for ex in examples]
answer_input = [ex.answer for ex in examples]
decoder_vocab = decoder_vocab.clone()
return Batch(example_ids,
numericalizer.encode(context_input, decoder_vocab, device=device),
numericalizer.encode(question_input, decoder_vocab, device=device),
numericalizer.encode(answer_input, decoder_vocab, device=device),
decoder_vocab)
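For illustration, building one Example and batching it might look like this (a sketch; `numericalizer` and `field.decoder_vocab` are assumed to be set up as in decanlp/util.py and train.py below):

ex = Example.from_raw('squad/1', 'The cat sat .', 'Who sat ?', 'the cat',
                      tokenize=str.split, lower=True)
# ex.context == ['the', 'cat', 'sat', '.']
batch = Batch.from_examples([ex], numericalizer, field.decoder_vocab)
# batch.context.value is a (1, sequence_length) int64 tensor, and
# batch.decoder_vocab is a per-batch clone, so OOV words encoded for this
# batch do not leak into other batches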

decanlp/data/iterator.py Normal file (113 lines)

@@ -0,0 +1,113 @@
#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import torch
import random
class Iterator(torch.utils.data.IterableDataset):
def __init__(self,
dataset : torch.utils.data.Dataset,
batch_size,
shuffle=False,
repeat=False,
batch_size_fn=None,
bucket_by_sort_key=False):
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
self.repeat = repeat
if batch_size_fn is None:
def batch_size_fn(new, count, sofar):
return count
self.batch_size_fn = batch_size_fn
self.bucket_by_sort_key = bucket_by_sort_key
def __len__(self):
if self.repeat:
raise NotImplementedError()
else:
return len(self.dataset)
    def __iter__(self):
        # make at least one pass over the data; loop forever when self.repeat
        while True:
            if self.shuffle:
                dataset = list(self.dataset)
                random.shuffle(dataset)
            else:
                dataset = self.dataset
            if self.bucket_by_sort_key:
                batches = self._pool(dataset)
            else:
                batches = self._batch(dataset, self.batch_size)
            for minibatch in batches:
                yield minibatch
            if not self.repeat:
                return
def _batch(self, data, batch_size):
"""Yield elements from data in chunks of batch_size."""
minibatch = []
size_so_far = 0
for ex in data:
minibatch.append(ex)
size_so_far = self.batch_size_fn(ex, len(minibatch), size_so_far)
if size_so_far == batch_size:
yield minibatch
minibatch, size_so_far = [], 0
elif size_so_far > batch_size:
if len(minibatch) == 1: # if we only have one really big example
yield minibatch
minibatch, size_so_far = [], 0
else:
yield minibatch[:-1]
minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, 1, 0)
                    if size_so_far > batch_size:  # the carried-over example is itself over the budget, so emit it on its own
yield minibatch
minibatch, size_so_far = [], 0
if minibatch:
yield minibatch
def _pool(self, data):
"""Sort within buckets, then batch, then shuffle batches.
        Partitions data into chunks of size 100*batch_size, sorts the examples
        within each chunk by sort_key, then batches them and shuffles the
        batches.
"""
for p in self._batch(data, self.batch_size * 100):
p_batch = self._batch(sorted(p, key=self.dataset.sort_key), self.batch_size)
if self.shuffle:
p_batch = list(p_batch)
random.shuffle(p_batch)
for b in p_batch:
yield b
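Since batch_size_fn sees the candidate example, the running count, and the accumulated size, batches can be sized by a token budget rather than by example count; a sketch in the spirit of batch_fn from decanlp/util.py (tokens_batch_size_fn is a hypothetical name):

def tokens_batch_size_fn(new, count, sofar):
    # "size" is the padded area of the would-be batch: the widest context
    # seen so far times the number of examples
    prev_max_len = sofar // max(count - 1, 1)
    return max(len(new.context), prev_max_len) * count

it = Iterator(train_set, batch_size=10000,  # a budget in tokens, not examples
              batch_size_fn=tokens_batch_size_fn,
              shuffle=True, repeat=True, bucket_by_sort_key=True)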

decanlp/data/numericalizer.py Normal file (132 lines)

@@ -0,0 +1,132 @@
#
# Copyright (c) 2018, The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import torch
from typing import NamedTuple, List
class SequentialField(NamedTuple):
    value : torch.Tensor
    length : torch.Tensor
    limited : torch.Tensor
    tokens : List[List[str]]
class DecoderVocabulary(object):
def __init__(self, words, full_vocab):
self.full_vocab = full_vocab
if words is not None:
self.itos = words
self.stoi = { word: idx for idx, word in enumerate(words) }
else:
self.itos = []
self.stoi = dict()
self.oov_itos = []
self.oov_stoi = dict()
@property
def max_generative_vocab(self):
return len(self.itos)
def clone(self):
new_subset = DecoderVocabulary(None, self.full_vocab)
new_subset.itos = self.itos
new_subset.stoi = self.stoi
return new_subset
def __len__(self):
return len(self.itos) + len(self.oov_itos)
def encode(self, word):
if word in self.stoi:
lim_idx = self.stoi[word]
elif word in self.oov_stoi:
lim_idx = self.oov_stoi[word]
else:
lim_idx = len(self)
self.oov_itos.append(word)
self.oov_stoi[word] = lim_idx
return lim_idx
    def decode(self, lim_idx):
        # map a limited (decoder) index back to its index in the full
        # vocabulary, replacing the old limited_idx_to_full_idx table
        if lim_idx < len(self.itos):
            word = self.itos[lim_idx]
        else:
            word = self.oov_itos[lim_idx - len(self.itos)]
        return self.full_vocab.stoi[word]
class SimpleNumericalizer(object):
def __init__(self, vocab, init_token=None, eos_token=None, pad_token="<pad>", unk_token="<unk>",
fix_length=None, pad_first=False):
self.vocab = vocab
self.init_token = init_token
self.eos_token = eos_token
self.unk_token = unk_token
self.pad_token = pad_token
self.fix_length = fix_length
self.pad_first = pad_first
def encode(self, minibatch, decoder_vocab : DecoderVocabulary, device=None):
if not isinstance(minibatch, list):
minibatch = list(minibatch)
if self.fix_length is None:
max_len = max(len(x) for x in minibatch)
else:
max_len = self.fix_length + (
self.init_token, self.eos_token).count(None) - 2
padded = []
lengths = []
numerical = []
decoder_numerical = []
for example in minibatch:
if self.pad_first:
padded_example = [self.pad_token] * max(0, max_len - len(example)) + \
([] if self.init_token is None else [self.init_token]) + \
list(example[:max_len]) + \
([] if self.eos_token is None else [self.eos_token])
else:
padded_example = ([] if self.init_token is None else [self.init_token]) + \
list(example[:max_len]) + \
([] if self.eos_token is None else [self.eos_token]) + \
[self.pad_token] * max(0, max_len - len(example))
            padded.append(padded_example)
            lengths.append(len(padded_example) - max(0, max_len - len(example)))
numerical.append([self.vocab.stoi[word] for word in padded_example])
decoder_numerical.append([decoder_vocab.encode(word) for word in padded_example])
length = torch.tensor(lengths, dtype=torch.int32, device=device)
numerical = torch.tensor(numerical, dtype=torch.int64, device=device)
decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device)
return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical)
def decode(self, tensor):
return [self.vocab.itos[idx] for idx in tensor]
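The per-batch OOV mechanism of DecoderVocabulary can be seen directly (hypothetical values; full_vocab would be the torchtext vocabulary, as in train.py below):

dv = DecoderVocabulary(['<pad>', '<unk>', 'the', 'cat'], full_vocab)
clone = dv.clone()           # shares the base tables, fresh OOV tables
clone.encode('the')          # -> 2, inside the generative vocabulary
clone.encode('zebra')        # -> 4, appended to this clone's OOV table
clone.encode('zebra')        # -> 4 again, stable within the batch
len(clone), len(dv)          # -> (5, 4): the parent vocabulary is untouched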


@@ -126,11 +126,9 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
context, context_lengths, context_limited, context_tokens = batch.context, batch.context_lengths, batch.context_limited, batch.context_tokens
question, question_lengths, question_limited, question_tokens = batch.question, batch.question_lengths, batch.question_limited, batch.question_tokens
answer, answer_lengths, answer_limited, answer_tokens = batch.answer, batch.answer_lengths, batch.answer_limited, batch.answer_tokens
oov_to_limited_idx, limited_idx_to_full_idx = batch.oov_to_limited_idx, batch.limited_idx_to_full_idx
decoder_vocab = batch.decoder_vocab
def map_to_full(x):
return limited_idx_to_full_idx[x]
self.map_to_full = map_to_full
self.map_to_full = decoder_vocab.decode
context_embedded = self.encoder_embeddings(context)
question_embedded = self.encoder_embeddings(question)
@@ -191,7 +189,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
context_attention, question_attention,
context_indices, question_indices,
oov_to_limited_idx)
decoder_vocab)
if self.args.use_maxmargin_loss:
@@ -206,7 +204,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
else:
return None, self.greedy(self_attended_context, final_context, final_question,
context_indices, question_indices,
oov_to_limited_idx, rnn_state=context_rnn_state).data
decoder_vocab, rnn_state=context_rnn_state).data
def reshape_rnn_state(self, h):
return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
@@ -216,7 +214,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
context_attention, question_attention,
context_indices, question_indices,
oov_to_limited_idx):
decoder_vocab):
size = list(outputs.size())
@@ -225,7 +223,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
p_vocab = F.softmax(scores, dim=scores.dim()-1)
scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab
effective_vocab_size = self.generative_vocab_size + len(oov_to_limited_idx)
effective_vocab_size = len(decoder_vocab)
if self.generative_vocab_size < effective_vocab_size:
size[-1] = effective_vocab_size - self.generative_vocab_size
buff = scaled_p_vocab.new_full(size, EPSILON)
@@ -242,7 +240,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
return scaled_p_vocab
def greedy(self, self_attended_context, context, question, context_indices, question_indices, oov_to_limited_idx, rnn_state=None):
def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab, rnn_state=None):
B, TC, C = context.size()
T = self.args.max_output_length
outs = context.new_full((B, T), self.field.decoder_stoi['<pad>'], dtype=torch.long)
@@ -300,7 +298,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
context_attention, question_attention,
context_indices, question_indices,
oov_to_limited_idx)
decoder_vocab)
pred_probs, preds = probs.max(-1)
preds = preds.squeeze(1)
eos_yet = eos_yet | (preds == self.field.decoder_stoi['<eos>'])
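The effect of the probs change is easiest to see with concrete (hypothetical) sizes: with a generative vocabulary of 50000 and a batch whose decoder_vocab clone picked up 12 OOV tokens, the output distribution grows by 12 slots:

generative_vocab_size = 50000   # == decoder_vocab.max_generative_vocab
effective_vocab_size = 50012    # == len(decoder_vocab) for this batch
extra = effective_vocab_size - generative_vocab_size
# probs() appends `extra` EPSILON-initialized slots to scaled_p_vocab before
# scattering the pointer-attention mass into them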


@@ -38,7 +38,7 @@ import sys
import logging
from pprint import pformat
from .util import set_seed, preprocess_examples, load_config_json
from .util import set_seed, preprocess_examples, load_config_json, make_data_loader
from .metrics import compute_metrics
from .utils.embeddings import load_embeddings
from .tasks.registry import get_tasks
@@ -85,21 +85,12 @@ def prepare_data(args, FIELD):
return FIELD, splits
def to_iter(data, bs, device):
it = Iterator(data, batch_size=bs,
device=device, batch_size_fn=None,
train=False, repeat=False, sort=False,
shuffle=False, reverse=False)
return it
def run(args, field, val_sets, model):
device = set_seed(args)
logger.info(f'Preparing iterators')
if len(args.val_batch_size) == 1 and len(val_sets) > 1:
args.val_batch_size *= len(val_sets)
iters = [(name, to_iter(x, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)]
iters = [(name, make_data_loader(x, field, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)]
def mult(ps):
r = 0


@@ -39,30 +39,13 @@ import hashlib
import unicodedata
import logging
import xml.etree.ElementTree as ET
from typing import NamedTuple, List
from ..text import data
from decanlp.data.example import Example
logger = logging.getLogger(__name__)
class Example(NamedTuple):
example_id : List[str]
context : List[str]
question : List[str]
answer : List[str]
@staticmethod
def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False):
args = [[example_id]]
for arg in (context, question, answer):
new_arg = tokenize(arg.rstrip('\n'))
if lower:
new_arg = [word.lower() for word in new_arg]
args.append(new_arg)
return Example(*args)
def make_example_id(dataset, example_id):
return dataset.name + '/' + str(example_id)


@@ -48,10 +48,11 @@ from tensorboardX import SummaryWriter
from . import arguments
from .validate import validate
from .multiprocess import Multiprocess
from .util import elapsed_time, batch_fn, set_seed, preprocess_examples, get_trainable_params
from .util import elapsed_time, set_seed, preprocess_examples, get_trainable_params, make_data_loader
from .utils.saver import Saver
from .utils.embeddings import load_embeddings
from .text.data import ReversibleField, BucketIterator, Iterator
from .text.data import ReversibleField
from .data.numericalizer import DecoderVocabulary
def initialize_logger(args, rank='main'):
@@ -130,10 +131,7 @@ def prepare_data(args, field, logger):
logger.info(f'Building vocabulary')
FIELD.build_vocab(*vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)
FIELD.decoder_itos = FIELD.vocab.itos[:args.max_generative_vocab]
FIELD.decoder_stoi = {word: idx for idx, word in enumerate(FIELD.decoder_itos)}
FIELD.decoder_to_vocab = {idx: FIELD.vocab.stoi[word] for idx, word in enumerate(FIELD.decoder_itos)}
FIELD.vocab_to_decoder = {idx: FIELD.decoder_stoi[word] for idx, word in enumerate(FIELD.vocab.itos) if word in FIELD.decoder_stoi}
FIELD.decoder_vocab = DecoderVocabulary(FIELD.vocab.itos[:args.max_generative_vocab], FIELD.vocab)
logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens')
logger.debug(f'The first 200 tokens:')
@@ -150,18 +148,6 @@ def prepare_data(args, field, logger):
return FIELD, train_sets, val_sets, aux_sets
def to_iter(args, world_size, val_batch_size, data, device, train=True, token_testing=False, sort=None):
sort = sort if not token_testing else True
shuffle = None if not token_testing else False
reverse = args.reverse
iteratorcls = BucketIterator if train else Iterator
it = iteratorcls(data, batch_size=val_batch_size,
device=device, batch_size_fn=batch_fn if train else None,
distributed=world_size>1, train=train, repeat=train, sort=sort,
shuffle=shuffle, reverse=reverse)
return it
def get_learning_rate(i, args):
transformer_lr = 1. / math.sqrt(args.dimension) * min(
1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
@@ -231,17 +217,16 @@ def train(args, model, opt, train_sets, train_iterations, field, rank=0, world_s
epoch = 0
logger.info(f'Preparing iterators')
train_iters = [(task, to_iter(args, world_size, tok, x, device, token_testing=args.token_testing))
for task, x, tok in zip(args.train_tasks, train_sets, args.train_batch_tokens)]
train_iters = [(task, make_data_loader(x, field, tok, device, train=True))
for task, x, tok in zip(args.train_tasks, train_sets, args.train_batch_tokens)]
train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]
val_iters = [(task, to_iter(args, world_size, tok, x, device, train=False, token_testing=args.token_testing, sort=False if 'sql' in task.name else None))
for task, x, tok in zip(args.val_tasks, val_sets, args.val_batch_size)]
val_iters = [(task, make_data_loader(x, field, bs, device, train=False))
for task, x, bs in zip(args.val_tasks, val_sets, args.val_batch_size)]
if args.use_curriculum:
aux_iters = [(name, to_iter(args, world_size, tok, x, device, token_testing=args.token_testing))
for name, x, tok in zip(args.train_tasks, aux_sets, args.train_batch_tokens)]
aux_iters = [(name, make_data_loader(x, field, tok, device, train=True))
for name, x, tok in zip(args.train_tasks, aux_sets, args.train_batch_tokens)]
aux_iters = [(task, iter(aux_iter)) for task, aux_iter in aux_iters]
zero_loss = 0

decanlp/util.py

@@ -36,10 +36,14 @@ import random
import numpy as np
import ujson as json
import logging
from .data.numericalizer import SimpleNumericalizer
from .data.example import Batch
from .data.iterator import Iterator
logger = logging.getLogger(__name__)
def tokenizer(s):
return s.split()
@@ -153,6 +157,21 @@ def batch_fn(new, i, sofar):
return max(len(new.context), 5*len(new.answer), prev_max_len) * i
def make_data_loader(dataset, field, batch_size, device=None, train=False):
numericalizer = SimpleNumericalizer(field.vocab, init_token=field.init_token, eos_token=field.eos_token,
pad_token=field.pad_token, unk_token=field.unk_token,
fix_length=field.fix_length, pad_first=field.pad_first)
iterator = Iterator(dataset, batch_size,
batch_size_fn=batch_fn if train else None,
shuffle=train,
repeat=train,
bucket_by_sort_key=train)
return torch.utils.data.DataLoader(iterator, batch_size=None,
collate_fn=lambda minibatch: Batch.from_examples(minibatch, numericalizer,
field.decoder_vocab,
device=device))
def pad(x, new_channel, dim, val=None):
if x.size(dim) > new_channel:
x = x.narrow(dim, 0, new_channel)
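With make_data_loader in place, training and validation code reads batches straight from the returned DataLoader; a minimal sketch (FIELD, train_set, batch_size and device as set up in train.py above):

train_loader = make_data_loader(train_set, FIELD, batch_size, device=device,
                                train=True)
train_iter = iter(train_loader)   # repeat=True, so this iterator never ends
batch = next(train_iter)
# batch.context / batch.question / batch.answer are SequentialFields, and
# batch.decoder_vocab maps this batch's `limited` indices back to the full
# vocabulary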