diff --git a/decanlp/arguments.py b/decanlp/arguments.py
index 5d20477b..d1781bc7 100644
--- a/decanlp/arguments.py
+++ b/decanlp/arguments.py
@@ -124,8 +124,6 @@ def parse(argv):
     parser.add_argument('--no_commit', action='store_false', dest='commit',
                         help='do not track the git commit associated with this training run')
     parser.add_argument('--exist_ok', action='store_true', help='Ok if the save directory already exists, i.e. overwrite is ok')
-    parser.add_argument('--token_testing', action='store_true', help='if true, sorts all iterators')
-    parser.add_argument('--reverse', action='store_true', help='if token_testing and true, sorts all iterators in reverse')
     parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether to use exisiting cached splits or generate new ones')
 
     parser.add_argument('--lr_rate', default=0.001, type=float, help='initial_learning_rate')
diff --git a/decanlp/data/__init__.py b/decanlp/data/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/decanlp/data/example.py b/decanlp/data/example.py
new file mode 100644
index 00000000..19bed54b
--- /dev/null
+++ b/decanlp/data/example.py
@@ -0,0 +1,76 @@
+#
+# Copyright (c) 2018, Salesforce, Inc.
+# The Board of Trustees of the Leland Stanford Junior University
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+import random
+from typing import NamedTuple, List
+
+from .numericalizer import DecoderVocabulary, SimpleNumericalizer, SequentialField
+
+
+class Example(NamedTuple):
+    example_id : List[str]
+    context : List[str]
+    question : List[str]
+    answer : List[str]
+
+    @staticmethod
+    def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False):
+        args = [[example_id]]
+        for arg in (context, question, answer):
+            new_arg = tokenize(arg.rstrip('\n'))
+            if lower:
+                new_arg = [word.lower() for word in new_arg]
+            args.append(new_arg)
+        return Example(*args)
+
+
+class Batch(NamedTuple):
+    example_id : List[str]
+    context : SequentialField
+    question : SequentialField
+    answer : SequentialField
+    decoder_vocab : DecoderVocabulary
+
+    @staticmethod
+    def from_examples(examples, numericalizer : SimpleNumericalizer, decoder_vocab, device=None):
+        example_ids = [ex.example_id for ex in examples]
+        context_input = [ex.context for ex in examples]
+        question_input = [ex.question for ex in examples]
+        answer_input = [ex.answer for ex in examples]
+        decoder_vocab = decoder_vocab.clone()
+
+        return Batch(example_ids,
+                     numericalizer.encode(context_input, decoder_vocab, device=device),
+                     numericalizer.encode(question_input, decoder_vocab, device=device),
+                     numericalizer.encode(answer_input, decoder_vocab, device=device),
+                     decoder_vocab)
+
+
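A minimal sketch of how the new Example/Batch types fit together (illustration only, not part of the diff; the toy vocabulary and tokens are assumptions, and any object with .itos/.stoi works in place of a real torchtext vocab):

    from types import SimpleNamespace
    from decanlp.data.example import Example, Batch
    from decanlp.data.numericalizer import SimpleNumericalizer, DecoderVocabulary

    # toy stand-in for FIELD.vocab
    itos = ['<pad>', '<init>', '<eos>', 'the', 'cat', 'sat', 'on', 'mat']
    vocab = SimpleNamespace(itos=itos, stoi={w: i for i, w in enumerate(itos)})

    numericalizer = SimpleNumericalizer(vocab, init_token='<init>', eos_token='<eos>')
    decoder_vocab = DecoderVocabulary(itos, vocab)

    ex = Example.from_raw('squad/1', 'the cat sat', 'on the mat', 'cat',
                          tokenize=str.split, lower=True)
    batch = Batch.from_examples([ex], numericalizer, decoder_vocab)
    # batch.context is a SequentialField: .value holds full-vocab indices,
    # .limited holds decoder-vocab indices, .length the unpadded lengths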
diff --git a/decanlp/data/iterator.py b/decanlp/data/iterator.py
new file mode 100644
index 00000000..29d617bd
--- /dev/null
+++ b/decanlp/data/iterator.py
@@ -0,0 +1,115 @@
+#
+# Copyright (c) 2018, Salesforce, Inc.
+# The Board of Trustees of the Leland Stanford Junior University
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+import random
+
+
+class Iterator(torch.utils.data.IterableDataset):
+    def __init__(self,
+                 dataset : torch.utils.data.Dataset,
+                 batch_size,
+                 shuffle=False,
+                 repeat=False,
+                 batch_size_fn=None,
+                 bucket_by_sort_key=False):
+        self.dataset = dataset
+
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.repeat = repeat
+
+        if batch_size_fn is None:
+            def batch_size_fn(new, count, sofar):
+                return count
+        self.batch_size_fn = batch_size_fn
+        self.bucket_by_sort_key = bucket_by_sort_key
+
+    def __len__(self):
+        if self.repeat:
+            raise NotImplementedError()
+        else:
+            return len(self.dataset)
+
+    def __iter__(self):
+        while True:
+            if self.shuffle:
+                dataset = list(self.dataset)
+                random.shuffle(dataset)
+            else:
+                dataset = self.dataset
+
+            if self.bucket_by_sort_key:
+                batches = self._pool(dataset)
+            else:
+                batches = self._batch(dataset, self.batch_size)
+
+            for minibatch in batches:
+                yield minibatch
+            if not self.repeat:  # single pass unless repeating (training)
+                break
+
+    def _batch(self, data, batch_size):
+        """Yield elements from data in chunks of batch_size."""
+        minibatch = []
+        size_so_far = 0
+        for ex in data:
+            minibatch.append(ex)
+            size_so_far = self.batch_size_fn(ex, len(minibatch), size_so_far)
+            if size_so_far == batch_size:
+                yield minibatch
+                minibatch, size_so_far = [], 0
+            elif size_so_far > batch_size:
+                if len(minibatch) == 1:  # we only have one really big example
+                    yield minibatch
+                    minibatch, size_so_far = [], 0
+                else:
+                    yield minibatch[:-1]
+                    minibatch, size_so_far = minibatch[-1:], self.batch_size_fn(ex, 1, 0)
+                    if size_so_far > batch_size:  # the example that overflowed the batch is itself really big
+                        yield minibatch
+                        minibatch, size_so_far = [], 0
+        if minibatch:
+            yield minibatch
+
+    def _pool(self, data):
+        """Sort within buckets, then batch, then shuffle batches.
+
+        Partitions data into chunks of size 100*batch_size, sorts examples within
+        each chunk using sort_key, then batches these examples and shuffles the
+        batches.
+        """
+        for p in self._batch(data, self.batch_size * 100):
+            p_batch = self._batch(sorted(p, key=self.dataset.sort_key), self.batch_size)
+            if self.shuffle:
+                p_batch = list(p_batch)
+                random.shuffle(p_batch)
+            for b in p_batch:
+                yield b
\ No newline at end of file
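The Iterator reimplements torchtext-style batching on top of torch.utils.data. A toy run of the default example-counting path (the dataset and sort_key below are hypothetical; sort_key is only consulted when bucket_by_sort_key=True):

    from decanlp.data.iterator import Iterator

    class ToyDataset(list):
        @staticmethod
        def sort_key(ex):           # used only by _pool for bucketing
            return len(ex)

    data = ToyDataset([['a', 'b', 'c'], ['d', 'e'], ['f']])
    it = Iterator(data, batch_size=2)  # default batch_size_fn counts examples
    print(list(it))                    # [[['a', 'b', 'c'], ['d', 'e']], [['f']]]

With a token-based batch_size_fn (like batch_fn in util.py below), batch_size becomes a token budget rather than an example count, which is how the training iterators are sized from --train_batch_tokens.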
diff --git a/decanlp/data/numericalizer.py b/decanlp/data/numericalizer.py
new file mode 100644
index 00000000..8bf89169
--- /dev/null
+++ b/decanlp/data/numericalizer.py
@@ -0,0 +1,133 @@
+#
+# Copyright (c) 2018, The Board of Trustees of the Leland Stanford Junior University
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+import torch
+from typing import NamedTuple, List
+
+
+class SequentialField(NamedTuple):
+    value : torch.Tensor
+    length : torch.Tensor
+    limited : torch.Tensor
+    tokens : List[str]
+
+
+class DecoderVocabulary(object):
+    def __init__(self, words, full_vocab):
+        self.full_vocab = full_vocab
+        if words is not None:
+            self.itos = words
+            self.stoi = { word: idx for idx, word in enumerate(words) }
+        else:
+            self.itos = []
+            self.stoi = dict()
+        self.oov_itos = []
+        self.oov_stoi = dict()
+
+    @property
+    def max_generative_vocab(self):
+        return len(self.itos)
+
+    def clone(self):
+        new_subset = DecoderVocabulary(None, self.full_vocab)
+        new_subset.itos = self.itos
+        new_subset.stoi = self.stoi
+        return new_subset
+
+    def __len__(self):
+        return len(self.itos) + len(self.oov_itos)
+
+    def encode(self, word):
+        if word in self.stoi:
+            lim_idx = self.stoi[word]
+        elif word in self.oov_stoi:
+            lim_idx = self.oov_stoi[word]
+        else:
+            lim_idx = len(self)
+            self.oov_itos.append(word)
+            self.oov_stoi[word] = lim_idx
+        return lim_idx
+
+    def decode(self, lim_idx):
+        if lim_idx < len(self.itos):
+            return self.itos[lim_idx]
+        else:
+            return self.oov_itos[lim_idx-len(self.itos)]
+
+
+class SimpleNumericalizer(object):
+    def __init__(self, vocab, init_token=None, eos_token=None, pad_token="<pad>", unk_token="<unk>",
+                 fix_length=None, pad_first=False):
+        self.vocab = vocab
+        self.init_token = init_token
+        self.eos_token = eos_token
+        self.unk_token = unk_token
+        self.pad_token = pad_token
+        self.fix_length = fix_length
+        self.pad_first = pad_first
+
+    def encode(self, minibatch, decoder_vocab : DecoderVocabulary, device=None):
+        if not isinstance(minibatch, list):
+            minibatch = list(minibatch)
+        if self.fix_length is None:
+            max_len = max(len(x) for x in minibatch)
+        else:
+            max_len = self.fix_length + (
+                self.init_token, self.eos_token).count(None) - 2
+        padded = []
+        lengths = []
+        numerical = []
+        decoder_numerical = []
+        for example in minibatch:
+            if self.pad_first:
+                padded_example = [self.pad_token] * max(0, max_len - len(example)) + \
+                                 ([] if self.init_token is None else [self.init_token]) + \
+                                 list(example[:max_len]) + \
+                                 ([] if self.eos_token is None else [self.eos_token])
+            else:
+                padded_example = ([] if self.init_token is None else [self.init_token]) + \
+                                 list(example[:max_len]) + \
+                                 ([] if self.eos_token is None else [self.eos_token]) + \
+                                 [self.pad_token] * max(0, max_len - len(example))
+
+            padded.append(padded_example)  # keep the raw tokens for the SequentialField
+            lengths.append(len(padded_example) - max(0, max_len - len(example)))
+
+            numerical.append([self.vocab.stoi[word] for word in padded_example])
+            decoder_numerical.append([decoder_vocab.encode(word) for word in padded_example])
+
+        length = torch.tensor(lengths, dtype=torch.int32, device=device)
+        numerical = torch.tensor(numerical, dtype=torch.int64, device=device)
+        decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device)
+
+        return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical)
+
+    def decode(self, tensor):
+        return [self.vocab.itos[idx] for idx in tensor]
\ No newline at end of file
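The decoder vocabulary is the pointer-generator bookkeeping: words outside the max_generative_vocab prefix receive temporary indices past the end of the generative vocabulary, so the model can still point at them. A toy roundtrip (hypothetical words; full_vocab is not exercised on this path):

    from decanlp.data.numericalizer import DecoderVocabulary

    dv = DecoderVocabulary(['<pad>', 'the', 'cat'], full_vocab=None).clone()
    assert dv.encode('cat') == 2         # in-vocabulary word keeps its index
    oov = dv.encode('zyzzyva')           # OOV word is appended past the end
    assert oov == 3 and len(dv) == 4
    assert dv.decode(oov) == 'zyzzyva'   # decode maps the index back

Because Batch.from_examples clones the vocabulary, each batch accumulates its own OOV extension while sharing the underlying generative prefix.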
diff --git a/decanlp/models/multitask_question_answering_network.py b/decanlp/models/multitask_question_answering_network.py
index c28b7201..558c1e93 100644
--- a/decanlp/models/multitask_question_answering_network.py
+++ b/decanlp/models/multitask_question_answering_network.py
@@ -126,11 +126,9 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         context, context_lengths, context_limited, context_tokens = batch.context, batch.context_lengths, batch.context_limited, batch.context_tokens
         question, question_lengths, question_limited, question_tokens = batch.question, batch.question_lengths, batch.question_limited, batch.question_tokens
         answer, answer_lengths, answer_limited, answer_tokens = batch.answer, batch.answer_lengths, batch.answer_limited, batch.answer_tokens
-        oov_to_limited_idx, limited_idx_to_full_idx = batch.oov_to_limited_idx, batch.limited_idx_to_full_idx
+        decoder_vocab = batch.decoder_vocab
 
-        def map_to_full(x):
-            return limited_idx_to_full_idx[x]
-        self.map_to_full = map_to_full
+        self.map_to_full = decoder_vocab.decode
 
         context_embedded = self.encoder_embeddings(context)
         question_embedded = self.encoder_embeddings(question)
@@ -191,7 +189,7 @@
 
             probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
                                context_attention, question_attention,
                                context_indices, question_indices,
-                               oov_to_limited_idx)
+                               decoder_vocab)
 
             if self.args.use_maxmargin_loss:
@@ -206,7 +204,7 @@
         else:
             return None, self.greedy(self_attended_context, final_context, final_question,
                                      context_indices, question_indices,
-                                     oov_to_limited_idx, rnn_state=context_rnn_state).data
+                                     decoder_vocab, rnn_state=context_rnn_state).data
 
     def reshape_rnn_state(self, h):
         return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
@@ -216,7 +214,7 @@
     def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
               context_attention, question_attention,
               context_indices, question_indices,
-              oov_to_limited_idx):
+              decoder_vocab):
 
         size = list(outputs.size())
 
@@ -225,7 +223,7 @@
         p_vocab = F.softmax(scores, dim=scores.dim()-1)
         scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab
 
-        effective_vocab_size = self.generative_vocab_size + len(oov_to_limited_idx)
+        effective_vocab_size = len(decoder_vocab)
         if self.generative_vocab_size < effective_vocab_size:
             size[-1] = effective_vocab_size - self.generative_vocab_size
             buff = scaled_p_vocab.new_full(size, EPSILON)
@@ -242,7 +240,7 @@
 
         return scaled_p_vocab
 
-    def greedy(self, self_attended_context, context, question, context_indices, question_indices, oov_to_limited_idx, rnn_state=None):
+    def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab, rnn_state=None):
         B, TC, C = context.size()
         T = self.args.max_output_length
         outs = context.new_full((B, T), self.field.decoder_stoi['<pad>'], dtype=torch.long)
@@ -300,7 +298,7 @@
             probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
                                context_attention, question_attention,
                                context_indices, question_indices,
-                               oov_to_limited_idx)
+                               decoder_vocab)
             pred_probs, preds = probs.max(-1)
             preds = preds.squeeze(1)
             eos_yet = eos_yet | (preds == self.field.decoder_stoi['<eos>'])
diff --git a/decanlp/predict.py b/decanlp/predict.py
index e46b6525..56a4388c 100644
--- a/decanlp/predict.py
+++ b/decanlp/predict.py
@@ -38,7 +38,7 @@ import sys
 import logging
 from pprint import pformat
 
-from .util import set_seed, preprocess_examples, load_config_json
+from .util import set_seed, preprocess_examples, load_config_json, make_data_loader
 from .metrics import compute_metrics
 from .utils.embeddings import load_embeddings
 from .tasks.registry import get_tasks
@@ -85,21 +85,12 @@ def prepare_data(args, FIELD):
 
     return FIELD, splits
 
-def to_iter(data, bs, device):
-    it = Iterator(data, batch_size=bs,
-                  device=device, batch_size_fn=None,
-                  train=False, repeat=False, sort=False,
-                  shuffle=False, reverse=False)
-
-    return it
-
-
 def run(args, field, val_sets, model):
     device = set_seed(args)
     logger.info(f'Preparing iterators')
     if len(args.val_batch_size) == 1 and len(val_sets) > 1:
         args.val_batch_size *= len(val_sets)
-    iters = [(name, to_iter(x, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)]
+    iters = [(name, make_data_loader(x, field, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)]
 
     def mult(ps):
         r = 0
diff --git a/decanlp/tasks/generic_dataset.py b/decanlp/tasks/generic_dataset.py
index 234ab3bf..e8b1a119 100644
--- a/decanlp/tasks/generic_dataset.py
+++ b/decanlp/tasks/generic_dataset.py
@@ -39,30 +39,13 @@ import hashlib
 import unicodedata
 import logging
 import xml.etree.ElementTree as ET
-from typing import NamedTuple, List
 
 from ..text import data
+from decanlp.data.example import Example
 
 logger = logging.getLogger(__name__)
 
 
-class Example(NamedTuple):
-    example_id : List[str]
-    context : List[str]
-    question : List[str]
-    answer : List[str]
-
-    @staticmethod
-    def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False):
-        args = [[example_id]]
-        for arg in (context, question, answer):
-            new_arg = tokenize(arg.rstrip('\n'))
-            if lower:
-                new_arg = [word.lower() for word in new_arg]
-            args.append(new_arg)
-        return Example(*args)
-
-
 def make_example_id(dataset, example_id):
     return dataset.name + '/' + str(example_id)
diff --git a/decanlp/train.py b/decanlp/train.py
index 09c0e772..d57f77d9 100644
--- a/decanlp/train.py
+++ b/decanlp/train.py
@@ -48,10 +48,11 @@ from tensorboardX import SummaryWriter
 from . import arguments
 from .validate import validate
 from .multiprocess import Multiprocess
-from .util import elapsed_time, batch_fn, set_seed, preprocess_examples, get_trainable_params
+from .util import elapsed_time, set_seed, preprocess_examples, get_trainable_params, make_data_loader
 from .utils.saver import Saver
 from .utils.embeddings import load_embeddings
-from .text.data import ReversibleField, BucketIterator, Iterator
+from .text.data import ReversibleField
+from .data.numericalizer import DecoderVocabulary
 
 
 def initialize_logger(args, rank='main'):
@@ -130,10 +131,7 @@ def prepare_data(args, field, logger):
         logger.info(f'Building vocabulary')
         FIELD.build_vocab(*vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)
 
-    FIELD.decoder_itos = FIELD.vocab.itos[:args.max_generative_vocab]
-    FIELD.decoder_stoi = {word: idx for idx, word in enumerate(FIELD.decoder_itos)}
-    FIELD.decoder_to_vocab = {idx: FIELD.vocab.stoi[word] for idx, word in enumerate(FIELD.decoder_itos)}
-    FIELD.vocab_to_decoder = {idx: FIELD.decoder_stoi[word] for idx, word in enumerate(FIELD.vocab.itos) if word in FIELD.decoder_stoi}
+    FIELD.decoder_vocab = DecoderVocabulary(FIELD.vocab.itos[:args.max_generative_vocab], FIELD.vocab)
 
     logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens')
     logger.debug(f'The first 200 tokens:')
@@ -150,18 +148,6 @@ def prepare_data(args, field, logger):
 
     return FIELD, train_sets, val_sets, aux_sets
 
-
-def to_iter(args, world_size, val_batch_size, data, device, train=True, token_testing=False, sort=None):
-    sort = sort if not token_testing else True
-    shuffle = None if not token_testing else False
-    reverse = args.reverse
-    iteratorcls = BucketIterator if train else Iterator
-    it = iteratorcls(data, batch_size=val_batch_size,
-                     device=device, batch_size_fn=batch_fn if train else None,
-                     distributed=world_size>1, train=train, repeat=train, sort=sort,
-                     shuffle=shuffle, reverse=reverse)
-    return it
-
 
 def get_learning_rate(i, args):
     transformer_lr = 1. / math.sqrt(args.dimension) * min(
         1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
@@ -231,17 +217,16 @@ def train(args, model, opt, train_sets, train_iterations, field, rank=0, world_s
     epoch = 0
 
     logger.info(f'Preparing iterators')
-    train_iters = [(task, to_iter(args, world_size, tok, x, device, token_testing=args.token_testing))
-                   for task, x, tok in zip(args.train_tasks, train_sets, args.train_batch_tokens)]
+    train_iters = [(task, make_data_loader(x, field, tok, device, train=True))
+                   for task, x, tok in zip(args.train_tasks, train_sets, args.train_batch_tokens)]
     train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]
 
-    val_iters = [(task, to_iter(args, world_size, tok, x, device, train=False, token_testing=args.token_testing, sort=False if 'sql' in task.name else None))
-                 for task, x, tok in zip(args.val_tasks, val_sets, args.val_batch_size)]
-
+    val_iters = [(task, make_data_loader(x, field, bs, device, train=False))
+                 for task, x, bs in zip(args.val_tasks, val_sets, args.val_batch_size)]
 
     if args.use_curriculum:
-        aux_iters = [(name, to_iter(args, world_size, tok, x, device, token_testing=args.token_testing))
-                     for name, x, tok in zip(args.train_tasks, aux_sets, args.train_batch_tokens)]
+        aux_iters = [(name, make_data_loader(x, field, tok, device, train=True))
+                     for name, x, tok in zip(args.train_tasks, aux_sets, args.train_batch_tokens)]
         aux_iters = [(task, iter(aux_iter)) for task, aux_iter in aux_iters]
 
     zero_loss = 0
diff --git a/decanlp/util.py b/decanlp/util.py
index b4479151..7b7918c6 100644
--- a/decanlp/util.py
+++ b/decanlp/util.py
@@ -36,10 +36,14 @@ import random
 import numpy as np
 import ujson as json
 import logging
+from .data.numericalizer import SimpleNumericalizer
+from .data.example import Batch
+from .data.iterator import Iterator
 
 logger = logging.getLogger(__name__)
 
+
 def tokenizer(s):
     return s.split()
 
 
@@ -153,6 +157,21 @@ def batch_fn(new, i, sofar):
     return max(len(new.context), 5*len(new.answer), prev_max_len) * i
 
 
+def make_data_loader(dataset, field, batch_size, device=None, train=False):
+    numericalizer = SimpleNumericalizer(field.vocab, init_token=field.init_token, eos_token=field.eos_token,
+                                        pad_token=field.pad_token, unk_token=field.unk_token,
+                                        fix_length=field.fix_length, pad_first=field.pad_first)
+
+    iterator = Iterator(dataset, batch_size,
+                        batch_size_fn=batch_fn if train else None,
+                        shuffle=train,
+                        repeat=train,
+                        bucket_by_sort_key=train)
+    return torch.utils.data.DataLoader(iterator, batch_size=None,
+                                       collate_fn=lambda minibatch: Batch.from_examples(minibatch, numericalizer,
+                                                                                        field.decoder_vocab,
+                                                                                        device=device))
+
 def pad(x, new_channel, dim, val=None):
     if x.size(dim) > new_channel:
         x = x.narrow(dim, 0, new_channel)
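Taken together, make_data_loader is what replaces the old to_iter helpers in train.py and predict.py. A hedged usage sketch (val_set and field are assumed to come out of prepare_data, with field.vocab and field.decoder_vocab already built):

    from decanlp.util import make_data_loader

    loader = make_data_loader(val_set, field, batch_size=16, device='cpu', train=False)
    for batch in loader:
        # each element is a Batch carrying its own cloned decoder_vocab
        print(batch.example_id[0], batch.context.value.shape, len(batch.decoder_vocab))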