Switch to torch.utils.data.DataLoader to load the datasets
Introduce new classes for numericalization and padding, in preparation for moving away from torchtext.Field and torchtext.Batch. The new classes are based on namedtuples, which are more compatible with pytorch. Move the useful parts of torchtext.Iterator (batching and bucketing by example length) into a new Iterator class.
This commit is contained in:
parent
f6a2e4608d
commit
50b9c05248
|
@ -124,8 +124,6 @@ def parse(argv):
|
|||
|
||||
parser.add_argument('--no_commit', action='store_false', dest='commit', help='do not track the git commit associated with this training run')
|
||||
parser.add_argument('--exist_ok', action='store_true', help='Ok if the save directory already exists, i.e. overwrite is ok')
|
||||
parser.add_argument('--token_testing', action='store_true', help='if true, sorts all iterators')
|
||||
parser.add_argument('--reverse', action='store_true', help='if token_testing and true, sorts all iterators in reverse')
|
||||
|
||||
parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether to use exisiting cached splits or generate new ones')
|
||||
parser.add_argument('--lr_rate', default=0.001, type=float, help='initial_learning_rate')
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
#
|
||||
# Copyright (c) 2018, Salesforce, Inc.
|
||||
# The Board of Trustees of the Leland Stanford Junior University
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# * Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import torch
|
||||
import random
|
||||
from typing import NamedTuple, List
|
||||
|
||||
from .numericalizer import DecoderVocabulary, SimpleNumericalizer, SequentialField
|
||||
|
||||
|
||||
class Example(NamedTuple):
|
||||
example_id : List[str]
|
||||
context : List[str]
|
||||
question : List[str]
|
||||
answer : List[str]
|
||||
|
||||
@staticmethod
|
||||
def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False):
|
||||
args = [[example_id]]
|
||||
for arg in (context, question, answer):
|
||||
new_arg = tokenize(arg.rstrip('\n'))
|
||||
if lower:
|
||||
new_arg = [word.lower() for word in new_arg]
|
||||
args.append(new_arg)
|
||||
return Example(*args)
|
||||
|
||||
|
||||
class Batch(NamedTuple):
|
||||
example_id : List[str]
|
||||
context : SequentialField
|
||||
question : SequentialField
|
||||
answer : SequentialField
|
||||
decoder_vocab : DecoderVocabulary
|
||||
|
||||
@staticmethod
|
||||
def from_examples(examples, numericalizer : SimpleNumericalizer, decoder_vocab, device=None):
|
||||
example_ids = [ex.example_id for ex in examples]
|
||||
context_input = [ex.context for ex in examples]
|
||||
question_input = [ex.question for ex in examples]
|
||||
answer_input = [ex.answer for ex in examples]
|
||||
decoder_vocab = decoder_vocab.clone()
|
||||
|
||||
return Batch(example_ids,
|
||||
numericalizer.encode(context_input, decoder_vocab, device=device),
|
||||
numericalizer.encode(question_input, decoder_vocab, device=device),
|
||||
numericalizer.encode(answer_input, decoder_vocab, device=device),
|
||||
decoder_vocab)
|
||||
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
#
|
||||
# Copyright (c) 2018, Salesforce, Inc.
|
||||
# The Board of Trustees of the Leland Stanford Junior University
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# * Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import torch
|
||||
import random
|
||||
|
||||
|
||||
class Iterator(torch.utils.data.IterableDataset):
|
||||
def __init__(self,
|
||||
dataset : torch.utils.data.Dataset,
|
||||
batch_size,
|
||||
shuffle=False,
|
||||
repeat=False,
|
||||
batch_size_fn=None,
|
||||
bucket_by_sort_key=False):
|
||||
self.dataset = dataset
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.repeat = repeat
|
||||
|
||||
if batch_size_fn is None:
|
||||
def batch_size_fn(new, count, sofar):
|
||||
return count
|
||||
self.batch_size_fn = batch_size_fn
|
||||
self.bucket_by_sort_key = bucket_by_sort_key
|
||||
|
||||
def __len__(self):
|
||||
if self.repeat:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
return len(self.dataset)
|
||||
|
||||
def __iter__(self):
|
||||
while self.repeat:
|
||||
if self.shuffle:
|
||||
dataset = list(self.dataset)
|
||||
random.shuffle(dataset)
|
||||
else:
|
||||
dataset = self.dataset
|
||||
|
||||
if self.bucket_by_sort_key:
|
||||
batches = self._pool(dataset)
|
||||
else:
|
||||
batches = self._batch(dataset, self.batch_size)
|
||||
|
||||
for minibatch in batches:
|
||||
yield minibatch
|
||||
|
||||
def _batch(self, data, batch_size):
|
||||
"""Yield elements from data in chunks of batch_size."""
|
||||
minibatch = []
|
||||
size_so_far = 0
|
||||
for ex in data:
|
||||
minibatch.append(ex)
|
||||
size_so_far = self.batch_size_fn(ex, len(minibatch), size_so_far)
|
||||
if size_so_far == batch_size:
|
||||
yield minibatch
|
||||
minibatch, size_so_far = [], 0
|
||||
elif size_so_far > batch_size:
|
||||
if len(minibatch) == 1: # if we only have one really big example
|
||||
yield minibatch
|
||||
minibatch, size_so_far = [], 0
|
||||
else:
|
||||
yield minibatch[:-1]
|
||||
minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, 1, 0)
|
||||
if size_so_far > batch_size: # if we add a really big example that needs to be on its own to a batch
|
||||
yield minibatch
|
||||
minibatch, size_so_far = [], 0
|
||||
if minibatch:
|
||||
yield minibatch
|
||||
|
||||
def _pool(self, data):
|
||||
"""Sort within buckets, then batch, then shuffle batches.
|
||||
|
||||
Partitions data into chunks of size 100*batch_size, sorts examples within
|
||||
each chunk using sort_key, then batch these examples and shuffle the
|
||||
batches.
|
||||
"""
|
||||
for p in self._batch(data, self.batch_size * 100):
|
||||
p_batch = self._batch(sorted(p, key=self.dataset.sort_key), self.batch_size)
|
||||
if self.shuffle:
|
||||
p_batch = list(p_batch)
|
||||
random.shuffle(p_batch)
|
||||
for b in p_batch:
|
||||
yield b
|
|
@ -0,0 +1,132 @@
|
|||
#
|
||||
# Copyright (c) 2018, The Board of Trustees of the Leland Stanford Junior University
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
#
|
||||
# * Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
#
|
||||
# * Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
import torch
|
||||
from typing import NamedTuple, List
|
||||
|
||||
|
||||
class SequentialField(NamedTuple):
|
||||
value : torch.tensor
|
||||
length : torch.tensor
|
||||
limited : torch.tensor
|
||||
tokens : List[str]
|
||||
|
||||
|
||||
class DecoderVocabulary(object):
|
||||
def __init__(self, words, full_vocab):
|
||||
self.full_vocab = full_vocab
|
||||
if words is not None:
|
||||
self.itos = words
|
||||
self.stoi = { word: idx for idx, word in enumerate(words) }
|
||||
else:
|
||||
self.itos = []
|
||||
self.stoi = dict()
|
||||
self.oov_itos = []
|
||||
self.oov_stoi = dict()
|
||||
|
||||
@property
|
||||
def max_generative_vocab(self):
|
||||
return len(self.itos)
|
||||
|
||||
def clone(self):
|
||||
new_subset = DecoderVocabulary(None, self.full_vocab)
|
||||
new_subset.itos = self.itos
|
||||
new_subset.stoi = self.stoi
|
||||
return new_subset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.itos) + len(self.oov_itos)
|
||||
|
||||
def encode(self, word):
|
||||
if word in self.stoi:
|
||||
lim_idx = self.stoi[word]
|
||||
elif word in self.oov_stoi:
|
||||
lim_idx = self.oov_stoi[word]
|
||||
else:
|
||||
lim_idx = len(self)
|
||||
self.oov_itos.append(word)
|
||||
self.oov_stoi[word] = lim_idx
|
||||
return lim_idx
|
||||
|
||||
def decode(self, lim_idx):
|
||||
if lim_idx < len(self.itos):
|
||||
return self.itos[lim_idx]
|
||||
else:
|
||||
return self.oov_itos[lim_idx-len(self.itos)]
|
||||
|
||||
|
||||
class SimpleNumericalizer(object):
|
||||
def __init__(self, vocab, init_token=None, eos_token=None, pad_token="<pad>", unk_token="<unk>",
|
||||
fix_length=None, pad_first=False):
|
||||
self.vocab = vocab
|
||||
self.init_token = init_token
|
||||
self.eos_token = eos_token
|
||||
self.unk_token = unk_token
|
||||
self.pad_token = pad_token
|
||||
self.fix_length = fix_length
|
||||
self.pad_first = pad_first
|
||||
|
||||
def encode(self, minibatch, decoder_vocab : DecoderVocabulary, device=None):
|
||||
if not isinstance(minibatch, list):
|
||||
minibatch = list(minibatch)
|
||||
if self.fix_length is None:
|
||||
max_len = max(len(x) for x in minibatch)
|
||||
else:
|
||||
max_len = self.fix_length + (
|
||||
self.init_token, self.eos_token).count(None) - 2
|
||||
padded = []
|
||||
lengths = []
|
||||
numerical = []
|
||||
decoder_numerical = []
|
||||
for example in minibatch:
|
||||
if self.pad_first:
|
||||
padded_example = [self.pad_token] * max(0, max_len - len(example)) + \
|
||||
([] if self.init_token is None else [self.init_token]) + \
|
||||
list(example[:max_len]) + \
|
||||
([] if self.eos_token is None else [self.eos_token])
|
||||
else:
|
||||
padded_example = ([] if self.init_token is None else [self.init_token]) + \
|
||||
list(example[:max_len]) + \
|
||||
([] if self.eos_token is None else [self.eos_token]) + \
|
||||
[self.pad_token] * max(0, max_len - len(example))
|
||||
|
||||
lengths.append(len(padded_example) - max(0, max_len - len(example)))
|
||||
|
||||
numerical.append([self.vocab.stoi[word] for word in padded_example])
|
||||
decoder_numerical.append([decoder_vocab.encode(word) for word in padded_example])
|
||||
|
||||
length = torch.tensor(lengths, dtype=torch.int32, device=device)
|
||||
numerical = torch.tensor(numerical, dtype=torch.int64, device=device)
|
||||
decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device)
|
||||
|
||||
return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical)
|
||||
|
||||
def decode(self, tensor):
|
||||
return [self.vocab.itos[idx] for idx in tensor]
|
|
@ -126,11 +126,9 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
context, context_lengths, context_limited, context_tokens = batch.context, batch.context_lengths, batch.context_limited, batch.context_tokens
|
||||
question, question_lengths, question_limited, question_tokens = batch.question, batch.question_lengths, batch.question_limited, batch.question_tokens
|
||||
answer, answer_lengths, answer_limited, answer_tokens = batch.answer, batch.answer_lengths, batch.answer_limited, batch.answer_tokens
|
||||
oov_to_limited_idx, limited_idx_to_full_idx = batch.oov_to_limited_idx, batch.limited_idx_to_full_idx
|
||||
decoder_vocab = batch.decoder_vocab
|
||||
|
||||
def map_to_full(x):
|
||||
return limited_idx_to_full_idx[x]
|
||||
self.map_to_full = map_to_full
|
||||
self.map_to_full = decoder_vocab.decode
|
||||
|
||||
context_embedded = self.encoder_embeddings(context)
|
||||
question_embedded = self.encoder_embeddings(question)
|
||||
|
@ -191,7 +189,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
|
||||
context_attention, question_attention,
|
||||
context_indices, question_indices,
|
||||
oov_to_limited_idx)
|
||||
decoder_vocab)
|
||||
|
||||
|
||||
if self.args.use_maxmargin_loss:
|
||||
|
@ -206,7 +204,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
else:
|
||||
return None, self.greedy(self_attended_context, final_context, final_question,
|
||||
context_indices, question_indices,
|
||||
oov_to_limited_idx, rnn_state=context_rnn_state).data
|
||||
decoder_vocab, rnn_state=context_rnn_state).data
|
||||
|
||||
def reshape_rnn_state(self, h):
|
||||
return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
|
||||
|
@ -216,7 +214,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
|
||||
context_attention, question_attention,
|
||||
context_indices, question_indices,
|
||||
oov_to_limited_idx):
|
||||
decoder_vocab):
|
||||
|
||||
size = list(outputs.size())
|
||||
|
||||
|
@ -225,7 +223,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
p_vocab = F.softmax(scores, dim=scores.dim()-1)
|
||||
scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab
|
||||
|
||||
effective_vocab_size = self.generative_vocab_size + len(oov_to_limited_idx)
|
||||
effective_vocab_size = len(decoder_vocab)
|
||||
if self.generative_vocab_size < effective_vocab_size:
|
||||
size[-1] = effective_vocab_size - self.generative_vocab_size
|
||||
buff = scaled_p_vocab.new_full(size, EPSILON)
|
||||
|
@ -242,7 +240,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
return scaled_p_vocab
|
||||
|
||||
|
||||
def greedy(self, self_attended_context, context, question, context_indices, question_indices, oov_to_limited_idx, rnn_state=None):
|
||||
def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab, rnn_state=None):
|
||||
B, TC, C = context.size()
|
||||
T = self.args.max_output_length
|
||||
outs = context.new_full((B, T), self.field.decoder_stoi['<pad>'], dtype=torch.long)
|
||||
|
@ -300,7 +298,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
|
||||
context_attention, question_attention,
|
||||
context_indices, question_indices,
|
||||
oov_to_limited_idx)
|
||||
decoder_vocab)
|
||||
pred_probs, preds = probs.max(-1)
|
||||
preds = preds.squeeze(1)
|
||||
eos_yet = eos_yet | (preds == self.field.decoder_stoi['<eos>'])
|
||||
|
|
|
@ -38,7 +38,7 @@ import sys
|
|||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
from .util import set_seed, preprocess_examples, load_config_json
|
||||
from .util import set_seed, preprocess_examples, load_config_json, make_data_loader
|
||||
from .metrics import compute_metrics
|
||||
from .utils.embeddings import load_embeddings
|
||||
from .tasks.registry import get_tasks
|
||||
|
@ -85,21 +85,12 @@ def prepare_data(args, FIELD):
|
|||
return FIELD, splits
|
||||
|
||||
|
||||
def to_iter(data, bs, device):
|
||||
it = Iterator(data, batch_size=bs,
|
||||
device=device, batch_size_fn=None,
|
||||
train=False, repeat=False, sort=False,
|
||||
shuffle=False, reverse=False)
|
||||
|
||||
return it
|
||||
|
||||
|
||||
def run(args, field, val_sets, model):
|
||||
device = set_seed(args)
|
||||
logger.info(f'Preparing iterators')
|
||||
if len(args.val_batch_size) == 1 and len(val_sets) > 1:
|
||||
args.val_batch_size *= len(val_sets)
|
||||
iters = [(name, to_iter(x, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)]
|
||||
iters = [(name, make_data_loader(x, field, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)]
|
||||
|
||||
def mult(ps):
|
||||
r = 0
|
||||
|
|
|
@ -39,30 +39,13 @@ import hashlib
|
|||
import unicodedata
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import NamedTuple, List
|
||||
|
||||
from ..text import data
|
||||
from decanlp.data.example import Example
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Example(NamedTuple):
|
||||
example_id : List[str]
|
||||
context : List[str]
|
||||
question : List[str]
|
||||
answer : List[str]
|
||||
|
||||
@staticmethod
|
||||
def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False):
|
||||
args = [[example_id]]
|
||||
for arg in (context, question, answer):
|
||||
new_arg = tokenize(arg.rstrip('\n'))
|
||||
if lower:
|
||||
new_arg = [word.lower() for word in new_arg]
|
||||
args.append(new_arg)
|
||||
return Example(*args)
|
||||
|
||||
|
||||
def make_example_id(dataset, example_id):
|
||||
return dataset.name + '/' + str(example_id)
|
||||
|
||||
|
|
|
@ -48,10 +48,11 @@ from tensorboardX import SummaryWriter
|
|||
from . import arguments
|
||||
from .validate import validate
|
||||
from .multiprocess import Multiprocess
|
||||
from .util import elapsed_time, batch_fn, set_seed, preprocess_examples, get_trainable_params
|
||||
from .util import elapsed_time, set_seed, preprocess_examples, get_trainable_params, make_data_loader
|
||||
from .utils.saver import Saver
|
||||
from .utils.embeddings import load_embeddings
|
||||
from .text.data import ReversibleField, BucketIterator, Iterator
|
||||
from .text.data import ReversibleField
|
||||
from .data.numericalizer import DecoderVocabulary
|
||||
|
||||
|
||||
def initialize_logger(args, rank='main'):
|
||||
|
@ -130,10 +131,7 @@ def prepare_data(args, field, logger):
|
|||
logger.info(f'Building vocabulary')
|
||||
FIELD.build_vocab(*vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)
|
||||
|
||||
FIELD.decoder_itos = FIELD.vocab.itos[:args.max_generative_vocab]
|
||||
FIELD.decoder_stoi = {word: idx for idx, word in enumerate(FIELD.decoder_itos)}
|
||||
FIELD.decoder_to_vocab = {idx: FIELD.vocab.stoi[word] for idx, word in enumerate(FIELD.decoder_itos)}
|
||||
FIELD.vocab_to_decoder = {idx: FIELD.decoder_stoi[word] for idx, word in enumerate(FIELD.vocab.itos) if word in FIELD.decoder_stoi}
|
||||
FIELD.decoder_vocab = DecoderVocabulary(FIELD.vocab.itos[:args.max_generative_vocab], FIELD.vocab)
|
||||
|
||||
logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens')
|
||||
logger.debug(f'The first 200 tokens:')
|
||||
|
@ -150,18 +148,6 @@ def prepare_data(args, field, logger):
|
|||
return FIELD, train_sets, val_sets, aux_sets
|
||||
|
||||
|
||||
def to_iter(args, world_size, val_batch_size, data, device, train=True, token_testing=False, sort=None):
|
||||
sort = sort if not token_testing else True
|
||||
shuffle = None if not token_testing else False
|
||||
reverse = args.reverse
|
||||
iteratorcls = BucketIterator if train else Iterator
|
||||
it = iteratorcls(data, batch_size=val_batch_size,
|
||||
device=device, batch_size_fn=batch_fn if train else None,
|
||||
distributed=world_size>1, train=train, repeat=train, sort=sort,
|
||||
shuffle=shuffle, reverse=reverse)
|
||||
return it
|
||||
|
||||
|
||||
def get_learning_rate(i, args):
|
||||
transformer_lr = 1. / math.sqrt(args.dimension) * min(
|
||||
1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
|
||||
|
@ -231,17 +217,16 @@ def train(args, model, opt, train_sets, train_iterations, field, rank=0, world_s
|
|||
epoch = 0
|
||||
|
||||
logger.info(f'Preparing iterators')
|
||||
train_iters = [(task, to_iter(args, world_size, tok, x, device, token_testing=args.token_testing))
|
||||
for task, x, tok in zip(args.train_tasks, train_sets, args.train_batch_tokens)]
|
||||
train_iters = [(task, make_data_loader(x, field, tok, device, train=True))
|
||||
for task, x, tok in zip(args.train_tasks, train_sets, args.train_batch_tokens)]
|
||||
train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]
|
||||
|
||||
val_iters = [(task, to_iter(args, world_size, tok, x, device, train=False, token_testing=args.token_testing, sort=False if 'sql' in task.name else None))
|
||||
for task, x, tok in zip(args.val_tasks, val_sets, args.val_batch_size)]
|
||||
|
||||
val_iters = [(task, make_data_loader(x, field, bs, device, train=False))
|
||||
for task, x, bs in zip(args.val_tasks, val_sets, args.val_batch_size)]
|
||||
|
||||
if args.use_curriculum:
|
||||
aux_iters = [(name, to_iter(args, world_size, tok, x, device, token_testing=args.token_testing))
|
||||
for name, x, tok in zip(args.train_tasks, aux_sets, args.train_batch_tokens)]
|
||||
aux_iters = [(name, make_data_loader(x, field, tok, device, train=True))
|
||||
for name, x, tok in zip(args.train_tasks, aux_sets, args.train_batch_tokens)]
|
||||
aux_iters = [(task, iter(aux_iter)) for task, aux_iter in aux_iters]
|
||||
|
||||
zero_loss = 0
|
||||
|
|
|
@ -36,10 +36,14 @@ import random
|
|||
import numpy as np
|
||||
import ujson as json
|
||||
import logging
|
||||
from .data.numericalizer import SimpleNumericalizer
|
||||
from .data.example import Batch
|
||||
from .data.iterator import Iterator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tokenizer(s):
|
||||
return s.split()
|
||||
|
||||
|
@ -153,6 +157,21 @@ def batch_fn(new, i, sofar):
|
|||
return max(len(new.context), 5*len(new.answer), prev_max_len) * i
|
||||
|
||||
|
||||
def make_data_loader(dataset, field, batch_size, device=None, train=False):
|
||||
numericalizer = SimpleNumericalizer(field.vocab, init_token=field.init_token, eos_token=field.eos_token,
|
||||
pad_token=field.pad_token, unk_token=field.unk_token,
|
||||
fix_length=field.fix_length, pad_first=field.pad_first)
|
||||
|
||||
iterator = Iterator(dataset, batch_size,
|
||||
batch_size_fn=batch_fn if train else None,
|
||||
shuffle=train,
|
||||
repeat=train,
|
||||
bucket_by_sort_key=train)
|
||||
return torch.utils.data.DataLoader(iterator, batch_size=None,
|
||||
collate_fn=lambda minibatch: Batch.from_examples(minibatch, numericalizer,
|
||||
field.decoder_vocab,
|
||||
device=device))
|
||||
|
||||
def pad(x, new_channel, dim, val=None):
|
||||
if x.size(dim) > new_channel:
|
||||
x = x.narrow(dim, 0, new_channel)
|
||||
|
|
Loading…
Reference in New Issue