This commit is contained in:
Giovanni Campagna 2020-01-16 15:34:16 -08:00
parent bbcdb1a3e2
commit 1f7edc7b24
22 changed files with 1115 additions and 782 deletions

View File

@ -82,7 +82,6 @@ def parse(argv):
parser.add_argument('--vocab_tasks', nargs='+', type=str, help='tasks to use in the construction of the vocabulary')
parser.add_argument('--max_output_length', default=100, type=int, help='maximum output length for generation')
parser.add_argument('--max_effective_vocab', default=int(1e6), type=int, help='max effective vocabulary size for pretrained embeddings')
parser.add_argument('--max_generative_vocab', default=50000, type=int, help='max vocabulary for the generative softmax')
parser.add_argument('--max_train_context_length', default=500, type=int, help='maximum length of the contexts during training')
parser.add_argument('--max_val_context_length', default=500, type=int, help='maximum length of the contexts during validation')
@ -90,19 +89,22 @@ def parse(argv):
parser.add_argument('--subsample', default=20000000, type=int, help='subsample the datasets')
parser.add_argument('--preserve_case', action='store_false', dest='lower', help='whether to preserve casing for all text')
parser.add_argument('--model', type=str, default='MultitaskQuestionAnsweringNetwork', help='which model to import')
parser.add_argument('--model', type=str, choices=['Seq2Seq'], default='Seq2Seq', help='which model to import')
parser.add_argument('--seq2seq_encoder', type=str, choices=['MQANEncoder', 'Identity'], default='MQANEncoder',
help='which encoder to use for the Seq2Seq model')
parser.add_argument('--seq2seq_decoder', type=str, choices=['MQANDecoder'], default='MQANDecoder',
help='which decoder to use for the Seq2Seq model')
parser.add_argument('--dimension', default=200, type=int, help='output dimensions for all layers')
parser.add_argument('--rnn_layers', default=1, type=int, help='number of layers for RNN modules')
parser.add_argument('--transformer_layers', default=2, type=int, help='number of layers for transformer modules')
parser.add_argument('--transformer_hidden', default=150, type=int, help='hidden size of the transformer modules')
parser.add_argument('--transformer_heads', default=3, type=int, help='number of heads for transformer modules')
parser.add_argument('--dropout_ratio', default=0.2, type=float, help='dropout for the model')
parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings')
parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
parser.add_argument('--retrain_encoder_embedding', default=False, action='store_true', help='whether to retrain encoder embeddings')
parser.add_argument('--trainable_decoder_embedding', default=0, type=int, help='size of trainable portion of decoder embedding (0 or omit to disable)')
parser.add_argument('--no_glove_decoder', action='store_false', dest='glove_decoder', help='turn off GloVe embeddings from decoder')
parser.add_argument('--pretrained_decoder_lm', help='pretrained language model to use as embedding layer for the decoder (omit to disable)')
parser.add_argument('--encoder_embeddings', default='glove+char', help='which word embedding to use on the encoder side; use a bert-* pretrained model for BERT; multiple embeddings can be concatenated with +')
parser.add_argument('--train_encoder_embeddings', action='store_true', default=False, help='back propagate into pretrained encoder embedding (recommended for BERT)')
parser.add_argument('--decoder_embeddings', default='glove+char', help='which pretrained word embedding to use on the decoder side')
parser.add_argument('--trainable_decoder_embeddings', default=0, type=int, help='size of trainable portion of decoder embedding (0 or omit to disable)')
parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate')
parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')
@ -123,8 +125,6 @@ def parse(argv):
parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether to use existing cached splits or generate new ones')
parser.add_argument('--lr_rate', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--small_glove', action='store_true', help='Use glove.6B.50d instead of glove.840B.300d')
parser.add_argument('--almond_type_embeddings', action='store_true', help='Add type-based word embeddings for Almond task')
parser.add_argument('--use_curriculum', action='store_true', help='Use curriculum learning')
parser.add_argument('--aux_dataset', default='', type=str, help='path to auxiliary dataset (ignored if curriculum is not used)')
parser.add_argument('--curriculum_max_frac', default=1.0, type=float, help='max fraction of harder dataset to keep for curriculum')
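As a reference for the new options, a hypothetical flag combination (values are illustrative, not from this commit; whether parse() expects a program name at the front of argv depends on the caller):

# hypothetical invocation exercising the new embedding options
argv = ['train',
        '--encoder_embeddings', 'bert-base-uncased',   # BERT on the encoder side
        '--train_encoder_embeddings',                  # recommended for BERT
        '--decoder_embeddings', 'glove+char',
        '--trainable_decoder_embeddings', '50',
        '--seq2seq_encoder', 'Identity',               # feed BERT states straight to the decoder
        '--rnn_layers', '0',                           # the Identity encoder rejects RNN layers
        '--dimension', '768']                          # must equal the total embedding size
args = parse(argv)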

View File

@ -29,14 +29,12 @@
from argparse import ArgumentParser
import torch
import numpy as np
import random
import logging
import sys
from pprint import pformat
from .utils.embeddings import load_embeddings
from .util import set_seed
from .data.embeddings import load_embeddings
logger = logging.getLogger(__name__)
@ -44,8 +42,8 @@ logger = logging.getLogger(__name__)
def get_args(argv):
parser = ArgumentParser(prog=argv[0])
parser.add_argument('--seed', default=123, type=int, help='Random seed.')
parser.add_argument('--embeddings', default='./decaNLP/.embeddings', type=str, help='where to save embeddings.')
parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
parser.add_argument('-d', '--destdir', default='./decaNLP/.embeddings', type=str, help='where to save embeddings.')
parser.add_argument('--embeddings', default='glove+char', help='which embeddings to download')
args = parser.parse_args(argv[1:])
return args
@ -55,9 +53,5 @@ def main(argv=sys.argv):
args = get_args(argv)
logger.info(f'Arguments:\n{pformat(vars(args))}')
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
load_embeddings(args, load_almond_embeddings=False)
set_seed(args.seed)
load_embeddings(args.destdir, args.embeddings, '')
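A hedged sketch of the updated download entry point; the flags are the ones added above, but the module name in the import is an assumption:

# the import path is a guess; only the argument names come from this file
from decanlp.download_embeddings import main

main(['decanlp-download-embeddings',      # argv[0] is skipped by parse_args(argv[1:])
      '-d', './decaNLP/.embeddings',      # --destdir: where to save embeddings
      '--embeddings', 'glove+char'])      # same syntax as --encoder/decoder_embeddings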

View File

@ -30,17 +30,15 @@
import torch
import logging
from ..data import word_vectors
_logger = logging.getLogger(__name__)
from .word_vectors import Vectors
ENTITIES = ['DATE', 'DURATION', 'EMAIL_ADDRESS', 'HASHTAG',
'LOCATION', 'NUMBER', 'PHONE_NUMBER', 'QUOTED_STRING',
'TIME', 'URL', 'USERNAME', 'PATH_NAME', 'CURRENCY']
MAX_ARG_VALUES = 5
class AlmondEmbeddings(word_vectors.Vectors):
class AlmondEmbeddings(Vectors):
def __init__(self, name=None, cache=None, **kw):
super().__init__(name, cache, **kw)
@ -63,27 +61,4 @@ class AlmondEmbeddings(word_vectors.Vectors):
self.itos = itos
self.stoi = {word: i for i, word in enumerate(itos)}
self.vectors = torch.stack(vectors, dim=0).view(-1, dim)
self.dim = dim
def load_embeddings(args, logger=_logger, load_almond_embeddings=True):
logger.info(f'Getting pretrained word vectors')
language = args.locale.split('-')[0]
if language == 'en':
char_vectors = word_vectors.CharNGram(cache=args.embeddings)
if args.small_glove:
glove_vectors = word_vectors.GloVe(cache=args.embeddings, name="6B", dim=50)
else:
glove_vectors = word_vectors.GloVe(cache=args.embeddings)
vectors = [char_vectors, glove_vectors]
# elif args.locale == 'zh':
# Chinese word embeddings
else:
# default to fastText
vectors = [word_vectors.FastText(cache=args.embeddings, language=language)]
if load_almond_embeddings and args.almond_type_embeddings:
vectors.append(AlmondEmbeddings())
return vectors
self.dim = dim

decanlp/data/embeddings.py Normal file
View File

@ -0,0 +1,232 @@
#
# Copyright (c) 2018-2019, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import torch
import os
from collections import defaultdict
import logging
from transformers import AutoTokenizer, AutoModel, BertConfig
from typing import NamedTuple, List
from .numericalizer import SimpleNumericalizer, BertNumericalizer
from . import word_vectors
from .almond_embeddings import AlmondEmbeddings
from .pretrained_lstm_lm import PretrainedLTSMLM
_logger = logging.getLogger(__name__)
class EmbeddingOutput(NamedTuple):
all_layers: List[torch.Tensor]
last_layer: torch.Tensor
class WordVectorEmbedding(torch.nn.Module):
def __init__(self, vec_collection):
super().__init__()
self._vec_collection = vec_collection
self.dim = vec_collection.dim
self.num_layers = 0
self.embedding = None
def init_for_vocab(self, vocab):
vectors = torch.empty(len(vocab), self.dim, device=torch.device('cpu'))
for ti, token in enumerate(vocab.itos):
vectors[ti] = self._vec_collection[token.strip()]
self.embedding = torch.nn.Embedding(len(vocab.itos), self.dim)
self.embedding.weight.data = vectors
def grow_for_vocab(self, vocab, new_words):
if not new_words:
return
new_vectors = []
for word in new_words:
new_vector = self._vec_collection[word]
# charNgram returns a [1, D] tensor, while Glove returns a [D] tensor
# normalize to [1, D] so we can concat along the second dimension
# and later concat all vectors along the first
new_vector = new_vector if new_vector.dim() > 1 else new_vector.unsqueeze(0)
new_vectors.append(new_vector)
self.embedding.weight.data = torch.cat([self.embedding.weight.data.cpu()] + new_vectors, dim=0)
def forward(self, input : torch.Tensor, padding=None):
last_layer = self.embedding(input.cpu()).to(input.device)
return EmbeddingOutput(all_layers=[last_layer], last_layer=last_layer)
def to(self, *args, **kwargs):
# ignore attempts to move the word embedding, which should stay on CPU
kwargs['device'] = torch.device('cpu')
return super().to(*args, **kwargs)
def cuda(self, device=None):
# ignore attempts to move the word embedding
pass
class BertEmbedding(torch.nn.Module):
def __init__(self, model):
super().__init__()
model.config.output_hidden_states = True
self.dim = model.config.hidden_size
self.num_layers = model.config.num_hidden_layers
self.model = model
def init_for_vocab(self, vocab):
self.model.resize_token_embeddings(len(vocab))
def grow_for_vocab(self, vocab, new_words):
self.model.resize_token_embeddings(len(vocab))
def forward(self, input : torch.Tensor, padding=None):
last_hidden_state, _pooled, hidden_states = self.model(input, attention_mask=(~padding).to(dtype=torch.float))
return EmbeddingOutput(all_layers=hidden_states, last_layer=last_hidden_state)
class PretrainedLMEmbedding(torch.nn.Module):
def __init__(self, model_name, cachedir):
super().__init__()
# map to CPU first, we will be moved to the right place later
pretrained_save_dict = torch.load(os.path.join(cachedir, model_name), map_location=torch.device('cpu'))
self.itos = pretrained_save_dict['vocab']
self.stoi = defaultdict(lambda: 0, {
w: i for i, w in enumerate(self.itos)
})
self.dim = pretrained_save_dict['settings']['nhid']
self.num_layers = 1
self.model = PretrainedLTSMLM(rnn_type=pretrained_save_dict['settings']['rnn_type'],
ntoken=len(self.itos),
emsize=pretrained_save_dict['settings']['emsize'],
nhid=pretrained_save_dict['settings']['nhid'],
nlayers=pretrained_save_dict['settings']['nlayers'],
dropout=0.0)
self.model.load_state_dict(pretrained_save_dict['model'], strict=True)
self.vocab_to_pretrained = None
def init_for_vocab(self, vocab):
self.vocab_to_pretrained = torch.empty(len(vocab), dtype=torch.int64)
unk_id = self.stoi['<unk>']
for ti, token in enumerate(vocab.itos):
if token in self.stoi:
self.vocab_to_pretrained[ti] = self.stoi[token]
else:
self.vocab_to_pretrained[ti] = unk_id
def grow_for_vocab(self, vocab, new_words):
self.init_for_vocab(vocab)
def forward(self, input : torch.Tensor, padding=None):
pretrained_indices = torch.gather(self.vocab_to_pretrained, dim=0, index=input)
rnn_output = self.model(pretrained_indices)
return EmbeddingOutput(all_layers=[rnn_output], last_layer=rnn_output)
def _is_bert(embedding_name):
return embedding_name.startswith('bert-')
def _name_to_vector(emb_name, cachedir):
if emb_name == 'glove':
return WordVectorEmbedding(word_vectors.GloVe(cache=cachedir))
elif emb_name == 'small_glove':
return WordVectorEmbedding(word_vectors.GloVe(cache=cachedir, name="6B", dim=50))
elif emb_name == 'char':
return WordVectorEmbedding(word_vectors.CharNGram(cache=cachedir))
elif emb_name == 'almond_type':
return AlmondEmbeddings()
elif emb_name.startswith('fasttext/'):
# FIXME this should use the fasttext library
return WordVectorEmbedding(word_vectors.FastText(cache=cachedir, language=emb_name[len('fasttext/'):]))
elif emb_name.startswith('pretrained_lstm/'):
return PretrainedLMEmbedding(emb_name[len('pretrained_lstm/'):], cachedir=cachedir)
else:
raise ValueError(f'Unrecognized embedding name {emb_name}')
def load_embeddings(cachedir, encoder_emb_names, decoder_emb_names, max_generative_vocab=50000, logger=_logger):
logger.info(f'Getting pretrained word vectors and pretrained models')
encoder_emb_names = encoder_emb_names.split('+')
decoder_emb_names = decoder_emb_names.split('+')
all_vectors = {}
encoder_vectors = []
decoder_vectors = []
numericalizer = None
for emb_name in encoder_emb_names:
if not emb_name:
continue
if _is_bert(emb_name):
if numericalizer is not None:
raise ValueError('Cannot specify multiple BERT embeddings')
config = BertConfig.from_pretrained(emb_name, cache_dir=cachedir)
config.output_hidden_states = True
numericalizer = BertNumericalizer(config, emb_name, max_generative_vocab=max_generative_vocab, cache=cachedir)
# load the tokenizer once to ensure all files are downloaded
AutoTokenizer.from_pretrained(emb_name, cache_dir=cachedir)
encoder_vectors.append(BertEmbedding(AutoModel.from_pretrained(emb_name, config=config, cache_dir=cachedir)))
else:
if numericalizer is not None:
logger.warning('Combining BERT embeddings with other pretrained embeddings is unlikely to work')
if emb_name in all_vectors:
encoder_vectors.append(all_vectors[emb_name])
else:
vec = _name_to_vector(emb_name, cachedir)
all_vectors[emb_name] = vec
encoder_vectors.append(vec)
for emb_name in decoder_emb_names:
if not emb_name:
continue
if _is_bert(emb_name):
raise ValueError('BERT embeddings cannot be specified in the decoder')
if emb_name in all_vectors:
decoder_vectors.append(all_vectors[emb_name])
else:
vec = _name_to_vector(emb_name, cachedir)
all_vectors[emb_name] = vec
decoder_vectors.append(vec)
if numericalizer is None:
numericalizer = SimpleNumericalizer(max_generative_vocab=max_generative_vocab, pad_first=False)
return numericalizer, encoder_vectors, decoder_vectors
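A small usage sketch of the new load_embeddings() API, based only on the signatures in this file (the cache directory is illustrative; running it downloads the pretrained files):

from decanlp.data.embeddings import load_embeddings

# BERT on the encoder side, GloVe + CharNGram on the decoder side
numericalizer, encoder_vectors, decoder_vectors = load_embeddings(
    './.embeddings',          # cachedir
    'bert-base-uncased',      # encoder: BertNumericalizer + BertEmbedding
    'glove+char',             # decoder: two WordVectorEmbedding instances
    max_generative_vocab=50000)

# after the numericalizer builds its vocabulary, each embedding is initialized
# for it through the init_for_vocab() hook defined above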

View File

@ -41,7 +41,8 @@ class BertNumericalizer(object):
Numericalizer that uses BertTokenizer from huggingface's transformers library.
"""
def __init__(self, pretrained_tokenizer, max_generative_vocab, cache=None, fix_length=None):
def __init__(self, config, pretrained_tokenizer, max_generative_vocab, cache=None, fix_length=None):
self.config = config
self._pretrained_name = pretrained_tokenizer
self.max_generative_vocab = max_generative_vocab
self._cache = cache
@ -50,13 +51,18 @@ class BertNumericalizer(object):
self.fix_length = fix_length
@property
def vocab(self):
return self._tokenizer
@property
def num_tokens(self):
return self._tokenizer.vocab_size
return len(self._tokenizer)
def load(self, save_dir):
self.config = BertConfig.from_pretrained(os.path.join(save_dir, 'bert-config.json'), cache_dir=self._cache)
self._tokenizer = MaskedBertTokenizer.from_pretrained(save_dir, config=self.config, cache_dir=self._cache)
# HACK we cannot save the tokenizer without this
del self._tokenizer.init_kwargs['config']
with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'r') as fp:
self._decoder_words = [line.rstrip('\n') for line in fp]
@ -64,15 +70,15 @@ class BertNumericalizer(object):
self._init()
def save(self, save_dir):
self.config.save_pretrained(os.path.join(save_dir, 'bert-config.json'))
self._tokenizer.save_pretrained(os.path.join(save_dir))
self._tokenizer.save_pretrained(save_dir)
with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'w') as fp:
for word in self._decoder_words:
fp.write(word + '\n')
def build_vocab(self, vectors, vocab_fields, vocab_sets):
self.config = BertConfig.from_pretrained(self._pretrained_name, cache_dir=self._cache)
def build_vocab(self, vocab_fields, vocab_sets):
self._tokenizer = MaskedBertTokenizer.from_pretrained(self._pretrained_name, config=self.config, cache_dir=self._cache)
# HACK we cannot save the tokenizer without this
del self._tokenizer.init_kwargs['config']
# ensure that init, eos, unk and pad are set
# this method has no effect if the tokens are already set according to the tokenizer class
@ -83,19 +89,31 @@ class BertNumericalizer(object):
'pad_token': '[PAD]'
})
# do a pass over all the answers in the dataset, and construct a counter of wordpieces
# do a pass over all the data in the dataset
# in this pass, we
# 1) tokenize everything, to ensure we account for all added tokens
# 2) we construct a counter of wordpieces in the answers, for the decoder vocabulary
decoder_words = collections.Counter()
for dataset in vocab_sets:
for example in dataset:
self._tokenizer.tokenize(example.context, example.context_word_mask)
self._tokenizer.tokenize(example.question, example.question_word_mask)
tokens = self._tokenizer.tokenize(example.answer, example.answer_word_mask)
decoder_words.update(tokens)
self._decoder_words = decoder_words.most_common(self.max_generative_vocab)
self._decoder_words = [word for word, _freq in decoder_words.most_common(self.max_generative_vocab)]
self._init()
def grow_vocab(self, examples, vectors):
# TODO
def grow_vocab(self, examples):
# do a pass over all the data in the dataset and tokenize everything
# this will add any new tokens that are not to be converted into word-pieces
for example in examples:
self._tokenizer.tokenize(example.context, example.context_word_mask)
self._tokenizer.tokenize(example.question, example.question_word_mask)
# return no new words - BertEmbedding will resize the embedding regardless
return []
def _init(self):
@ -128,9 +146,9 @@ class BertNumericalizer(object):
wp_tokenized.append(self._tokenizer.tokenize(tokens, mask))
if self.fix_length is None:
max_len = max(len(x) for x in minibatch) + 2
max_len = max(len(x) for x in wp_tokenized)
else:
max_len = self.fix_length + 2
max_len = self.fix_length
padded = []
lengths = []
numerical = []
@ -148,7 +166,7 @@ class BertNumericalizer(object):
[self.pad_token] * max(0, max_len - len(wp_tokens))
padded.append(padded_example)
lengths.append(len(padded_example) - max(0, max_len - len(wp_tokens)))
lengths.append(len(wp_tokens) + 2)
numerical.append(self._tokenizer.convert_tokens_to_ids(padded_example))
decoder_numerical.append([decoder_vocab.encode(word) for word in padded_example])
@ -157,7 +175,7 @@ class BertNumericalizer(object):
numerical = torch.tensor(numerical, dtype=torch.int64, device=device)
decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device)
return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical)
return SequentialField(length=length, value=numerical, limited=decoder_numerical)
def decode(self, tensor):
return self._tokenizer.convert_ids_to_tokens(tensor)
@ -176,7 +194,10 @@ class BertNumericalizer(object):
if token in (self.init_token, self.pad_token):
continue
if token.startswith('##'):
tokens[-1] += token[2:]
if len(tokens) == 0:
tokens.append(token[2:])
else:
tokens[-1] += token[2:]
else:
tokens.append(token)
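The '##' branch above reverses wordpiece splitting; a standalone sketch of the same merge rule (ignoring the init/pad filtering the method also does):

def merge_wordpieces(wordpieces):
    # glue '##'-prefixed pieces back onto the previous token,
    # tolerating a stray leading '##' piece as in the code above
    tokens = []
    for piece in wordpieces:
        if piece.startswith('##'):
            if tokens:
                tokens[-1] += piece[2:]
            else:
                tokens.append(piece[2:])
        else:
            tokens.append(piece)
    return tokens

assert merge_wordpieces(['em', '##bed', '##ding', 'layer']) == ['embedding', 'layer']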

View File

@ -54,10 +54,10 @@ class MaskedWordPieceTokenizer:
def tokenize(self, tokens, mask):
output_tokens = []
for token, should_word_split in tokens, mask:
for token, should_word_split in zip(tokens, mask):
if not should_word_split:
if token not in self.vocab and token not in self.added_tokens_encoder:
token_id = len(self.added_tokens_encoder)
token_id = len(self.vocab) + len(self.added_tokens_encoder)
self.added_tokens_encoder[token] = token_id
self.added_tokens_decoder[token_id] = token
output_tokens.append(token)
@ -95,11 +95,60 @@ class MaskedWordPieceTokenizer:
return output_tokens
class IToSWrapper:
"""Wrap the ordered dict vocabs to look like a list int -> str"""
def __init__(self, base_vocab, added_tokens):
self.base_vocab = base_vocab
self.added_tokens = added_tokens
def __getitem__(self, key):
if isinstance(key, slice):
return [self[i] for i in range(key.start or 0, key.stop or len(self), key.step or 1)]
if key < len(self.base_vocab):
return self.base_vocab[key]
else:
return self.added_tokens[key]
def __len__(self):
return len(self.base_vocab) + len(self.added_tokens)
def __iter__(self):
for key in range(len(self.base_vocab)):
yield self.base_vocab[key]
for key in range(len(self.base_vocab), len(self.base_vocab) + len(self.added_tokens)):
yield self.added_tokens[key]
class SToIWrapper:
"""Wrap the ordered dict vocabs to look like a single dictionary"""
def __init__(self, base_vocab, added_tokens):
self.base_vocab = base_vocab
self.added_tokens = added_tokens
def __getitem__(self, key):
if key in self.base_vocab:
return self.base_vocab[key]
else:
return self.added_tokens[key]
def __len__(self):
return len(self.base_vocab) + len(self.added_tokens)
def __iter__(self):
for key in self.base_vocab:
yield key
for key in self.added_tokens:
yield key
class MaskedBertTokenizer(BertTokenizer):
"""
A modified BertTokenizer that respects a mask deciding whether a token should be split or not.
"""
def __init__(self, *args, do_lower_case, do_basic_tokenize, **kwargs):
def __init__(self, *args, do_lower_case=False, do_basic_tokenize=False, **kwargs):
# override do_lower_case and do_basic_tokenize unconditionally
super().__init__(*args, do_lower_case=False, do_basic_tokenize=False, **kwargs)
@ -109,14 +158,21 @@ class MaskedBertTokenizer(BertTokenizer):
added_tokens_decoder=self.added_tokens_decoder,
unk_token=self.unk_token)
self._itos = IToSWrapper(self.ids_to_tokens, self.added_tokens_decoder)
self._stoi = SToIWrapper(self.vocab, self.added_tokens_encoder)
def tokenize(self, tokens, mask=None):
return self.wordpiece_tokenizer.tokenize(tokens, mask)
# provide an interface that DecoderVocabulary can like
# provide an interface similar to Vocab
def __len__(self):
return len(self.vocab) + len(self.added_tokens_encoder)
@property
def stoi(self):
return self.vocab
return self._stoi
@property
def itos(self):
return self.ids_to_tokens
return self._itos
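A toy sketch of the contract the two wrappers and the overridden properties implement: added tokens get ids starting at len(base vocab), and itos/stoi present the union as one list / one dict (the vocabulary contents are made up):

from collections import OrderedDict

# stand-ins for BertTokenizer's vocab / ids_to_tokens / added_tokens_* dicts
base_vocab = OrderedDict([('[PAD]', 0), ('hello', 1), ('world', 2)])
ids_to_tokens = OrderedDict((i, t) for t, i in base_vocab.items())
added_tokens_encoder = {'@org.example': 3}     # id = len(base_vocab) + 0
added_tokens_decoder = {3: '@org.example'}

itos = IToSWrapper(ids_to_tokens, added_tokens_decoder)
stoi = SToIWrapper(base_vocab, added_tokens_encoder)

assert len(itos) == len(stoi) == 4
assert itos[3] == '@org.example' and stoi['@org.example'] == 3
assert list(itos) == ['[PAD]', 'hello', 'world', '@org.example']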

View File

@ -34,5 +34,4 @@ from typing import NamedTuple, List
class SequentialField(NamedTuple):
value : torch.tensor
length : torch.tensor
limited : torch.tensor
tokens : List[List[str]]
limited : torch.tensor

View File

@ -35,8 +35,7 @@ from .sequential_field import SequentialField
from .decoder_vocab import DecoderVocabulary
class SimpleNumericalizer(object):
def __init__(self, max_effective_vocab, max_generative_vocab, fix_length=None, pad_first=False):
self.max_effective_vocab = max_effective_vocab
def __init__(self, max_generative_vocab, fix_length=None, pad_first=False):
self.max_generative_vocab = max_generative_vocab
self.init_token = '<init>'
@ -58,17 +57,15 @@ class SimpleNumericalizer(object):
def save(self, save_dir):
torch.save(self.vocab, os.path.join(save_dir, 'vocab.pth'))
def build_vocab(self, vectors, vocab_fields, vocab_sets):
def build_vocab(self, vocab_fields, vocab_sets):
self.vocab = Vocab.build_from_data(vocab_fields, *vocab_sets,
unk_token=self.unk_token,
init_token=self.init_token,
eos_token=self.eos_token,
pad_token=self.pad_token,
max_size=self.max_effective_vocab,
vectors=vectors)
pad_token=self.pad_token)
self._init_vocab()
def _grow_vocab_one(self, sentence, vectors, new_vectors):
def _grow_vocab_one(self, sentence, new_words):
assert isinstance(sentence, list)
# check if all the words are in the vocabulary, and if not
@ -77,22 +74,15 @@ class SimpleNumericalizer(object):
if word not in self.vocab.stoi:
self.vocab.stoi[word] = len(self.vocab.itos)
self.vocab.itos.append(word)
new_words.append(word)
new_vector = [vec[word] for vec in vectors]
# charNgram returns a [1, D] tensor, while Glove returns a [D] tensor
# normalize to [1, D] so we can concat along the second dimension
# and later concat all vectors along the first
new_vector = [vec if vec.dim() > 1 else vec.unsqueeze(0) for vec in new_vector]
new_vectors.append(torch.cat(new_vector, dim=1))
def grow_vocab(self, examples, vectors):
new_vectors = []
def grow_vocab(self, examples):
new_words = []
for ex in examples:
self._grow_vocab_one(ex.context, vectors, new_vectors)
self._grow_vocab_one(ex.question, vectors, new_vectors)
self._grow_vocab_one(ex.answer, vectors, new_vectors)
return new_vectors
self._grow_vocab_one(ex.context, new_words)
self._grow_vocab_one(ex.question, new_words)
self._grow_vocab_one(ex.answer, new_words)
return new_words
def _init_vocab(self):
self.init_id = self.vocab.stoi[self.init_token]
@ -113,7 +103,7 @@ class SimpleNumericalizer(object):
if self.fix_length is None:
max_len = max(len(x[0]) for x in minibatch)
else:
max_len = self.fix_length + 2
max_len = self.fix_length
padded = []
lengths = []
numerical = []
@ -140,7 +130,7 @@ class SimpleNumericalizer(object):
numerical = torch.tensor(numerical, dtype=torch.int64, device=device)
decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device)
return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical)
return SequentialField(length=length, value=numerical, limited=decoder_numerical)
def decode(self, tensor):
return [self.vocab.itos[idx] for idx in tensor]

View File

@ -20,8 +20,7 @@ class Vocab(object):
numerical identifiers.
itos: A list of token strings indexed by their numerical identifiers.
"""
def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>',),
vectors=None, cat_vectors=True):
def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>',)):
"""Create a Vocab object from a collections.Counter.
Arguments:
@ -60,10 +59,6 @@ class Vocab(object):
self.itos.append(word)
self.stoi[word] = len(self.itos) - 1
self.vectors = None
if vectors is not None:
self.load_vectors(vectors, cat=cat_vectors)
def __eq__(self, other):
if self.freqs != other.freqs:
return False
@ -71,8 +66,6 @@ class Vocab(object):
return False
if self.itos != other.itos:
return False
if self.vectors != other.vectors:
return False
return True
def __len__(self):
@ -85,54 +78,6 @@ class Vocab(object):
self.itos.append(w)
self.stoi[w] = len(self.itos) - 1
def load_vectors(self, vectors, cat=True):
"""
Arguments:
vectors: one of or a list containing instantiations of the
GloVe, CharNGram, or Vectors classes.
"""
if not isinstance(vectors, list):
vectors = [vectors]
if cat:
tot_dim = sum(v.dim for v in vectors)
self.vectors = torch.Tensor(len(self), tot_dim)
for ti, token in enumerate(self.itos):
start_dim = 0
for v in vectors:
end_dim = start_dim + v.dim
self.vectors[ti][start_dim:end_dim] = v[token.strip()]
start_dim = end_dim
assert(start_dim == tot_dim)
else:
self.vectors = [torch.Tensor(len(self), v.dim) for v in vectors]
for ti, t in enumerate(self.itos):
for vi, v in enumerate(vectors):
self.vectors[vi][ti] = v[t.strip()]
def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_):
"""
Set the vectors for the Vocab instance from a collection of Tensors.
Arguments:
stoi: A dictionary of string to the index of the associated vector
in the `vectors` input argument.
vectors: An indexed iterable (or other structure supporting __getitem__) that
given an input index, returns a FloatTensor representing the vector
for the token associated with the index. For example,
vector[stoi["string"]] should return the vector for "string".
dim: The dimensionality of the vectors.
unk_init (callback): by default, initialize out-of-vocabulary word vectors
to zero vectors; can be any function that takes in a Tensor and
returns a Tensor of the same size. Default: torch.Tensor.zero_
"""
self.vectors = torch.Tensor(len(self), dim)
for i, token in enumerate(self.itos):
wv_index = stoi.get(token, None)
if wv_index is not None:
self.vectors[i] = vectors[wv_index]
else:
self.vectors[i] = unk_init(self.vectors[i])
@staticmethod
def build_from_data(field_names, *args, unk_token=None, pad_token=None, init_token=None, eos_token=None, **kwargs):
"""Construct the Vocab object for this field from one or more datasets.

View File

@ -0,0 +1,93 @@
# The following code was copied and adapted from github.com/floyhub/world-language-model
#
# BSD 3-Clause License
#
# Copyright (c) 2017,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from torch import nn
class PretrainedLTSMLM(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""
def __init__(self, rnn_type, ntoken, emsize, nhid, nlayers, dropout=0.5, tie_weights=False):
super(PretrainedLTSMLM, self).__init__()
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, emsize) # Token2Embeddings
if rnn_type in ['LSTM', 'GRU']:
self.rnn = getattr(nn, rnn_type)(emsize, nhid, nlayers, dropout=dropout)
else:
try:
nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
except KeyError:
raise ValueError( """An invalid option for `--model` was supplied,
options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
self.rnn = nn.RNN(emsize, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
self.decoder = nn.Linear(nhid, ntoken)
# Optionally tie weights as in:
# "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
# https://arxiv.org/abs/1608.05859
# and
# "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
# https://arxiv.org/abs/1611.01462
if tie_weights:
if nhid != emsize:
raise ValueError('When using the tied flag, nhid must be equal to emsize')
self.decoder.weight = self.encoder.weight
self.init_weights()
self.rnn_type = rnn_type
self.nhid = nhid
self.nlayers = nlayers
def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.fill_(0)
self.decoder.weight.data.uniform_(-initrange, initrange)
def encode(self, input, hidden=None):
emb = self.drop(self.encoder(input))
output, hidden = self.rnn(emb, hidden)
output = self.drop(output)
return output, hidden
def forward(self, input, hidden=None):
encoded, hidden = self.encode(input, hidden)
decoded = self.decoder(encoded.view(encoded.size(0)*encoded.size(1), encoded.size(2)))
return decoded.view(encoded.size(0), encoded.size(1), decoded.size(1)), hidden
def init_hidden(self, bsz):
weight = next(self.parameters()).data
if self.rnn_type == 'LSTM':
return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
weight.new(self.nlayers, bsz, self.nhid).zero_())
else:
return weight.new(self.nlayers, bsz, self.nhid).zero_()
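A minimal smoke test for the class above (sizes are arbitrary; input uses the default (seq_len, batch) layout of torch RNNs, and the import path follows the relative import in embeddings.py):

import torch
from decanlp.data.pretrained_lstm_lm import PretrainedLTSMLM

model = PretrainedLTSMLM(rnn_type='LSTM', ntoken=100, emsize=16, nhid=16,
                         nlayers=2, dropout=0.0, tie_weights=True)
tokens = torch.randint(0, 100, (7, 3))     # (seq_len=7, batch=3) token ids
logits, hidden = model(tokens)             # logits: (seq_len, batch, ntoken)
assert logits.shape == (7, 3, 100)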

View File

@ -27,4 +27,4 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from .multitask_question_answering_network import MultitaskQuestionAnsweringNetwork
from .general_seq2seq import Seq2Seq

View File

@ -37,18 +37,24 @@ import os
import sys
import numpy as np
import torch.nn as nn
from typing import NamedTuple, List
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack
class EmbeddingOutput(NamedTuple):
all_layers: List[torch.Tensor]
last_layer: torch.Tensor
INF = 1e10
EPSILON = 1e-10
class LSTMDecoder(nn.Module):
class MultiLSTMCell(nn.Module):
def __init__(self, num_layers, input_size, rnn_size, dropout):
super(LSTMDecoder, self).__init__()
super(MultiLSTMCell, self).__init__()
self.dropout = nn.Dropout(dropout)
self.num_layers = num_layers
self.layers = nn.ModuleList()
@ -304,6 +310,7 @@ class LinearFeedforward(nn.Module):
def forward(self, x):
return self.dropout(self.linear(self.feedforward(x)))
class PackedLSTM(nn.Module):
def __init__(self, d_in, d_out, bidirectional=False, num_layers=1,
@ -354,32 +361,23 @@ class Feedforward(nn.Module):
return self.activation(self.linear(self.dropout(x)))
class Embedding(nn.Module):
class CombinedEmbedding(nn.Module):
def __init__(self, numericalizer, output_dimension, include_pretrained=True, trained_dimension=0, dropout=0.0, project=True, requires_grad=False):
def __init__(self, numericalizer, pretrained_embeddings,
output_dimension,
finetune_pretrained=False,
trained_dimension=0,
project=True):
super().__init__()
self.project = project
self.requires_grad = requires_grad
self.finetune_pretrained = finetune_pretrained
self.pretrained_embeddings = tuple(pretrained_embeddings)
dimension = 0
pretrained_dimension = numericalizer.vocab.vectors.size(-1)
for idx, embedding in enumerate(self.pretrained_embeddings):
dimension += embedding.dim
self.add_module('pretrained_' + str(idx), embedding)
if include_pretrained:
# NOTE: this must be a list so that pytorch will not iterate into the module when
# traversing this module
# in turn, this means that moving this Embedding() to the GPU will not move the
# actual embedding, which will stay on CPU; this is necessary because a) we call
# set_embeddings() sometimes with CPU-only tensors, and b) the embedding tensor
# is too big for the GPU anyway
self.pretrained_embeddings = [nn.Embedding(numericalizer.num_tokens, pretrained_dimension)]
self.pretrained_embeddings[0].weight.data = numericalizer.vocab.vectors
self.pretrained_embeddings[0].weight.requires_grad = self.requires_grad
dimension += pretrained_dimension
else:
self.pretrained_embeddings = None
# OTOH, if we have a trained embedding, we move it around together with the module
# (ie, potentially on GPU), because the saving when applying gradient outweights
# the cost, and hopefully the embedding is small enough to fit in GPU memory
if trained_dimension > 0:
self.trained_embeddings = nn.Embedding(numericalizer.num_tokens, trained_dimension)
dimension += trained_dimension
@ -387,34 +385,53 @@ class Embedding(nn.Module):
self.trained_embeddings = None
if self.project:
self.projection = Feedforward(dimension, output_dimension)
self.dropout = nn.Dropout(dropout)
else:
assert dimension == output_dimension
self.dimension = output_dimension
def forward(self, x, lengths=None, device=-1):
def _combine_embeddings(self, embeddings):
if len(embeddings) == 1:
all_layers = embeddings[0].all_layers
last_layer = embeddings[0].last_layer
if self.project:
last_layer = self.projection(last_layer)
return EmbeddingOutput(all_layers=all_layers, last_layer=last_layer)
all_layers = None
last_layer = []
for emb in embeddings:
if all_layers is None:
all_layers = [[layer] for layer in emb.all_layers]
elif len(all_layers) != len(emb.all_layers):
raise ValueError('Cannot combine embeddings that use different numbers of layers')
else:
for layer_list, layer in zip(all_layers, emb.all_layers):
layer_list.append(layer)
last_layer.append(emb.last_layer)
all_layers = [torch.cat(layer, dim=2) for layer in all_layers]
last_layer = torch.cat(last_layer, dim=2)
if self.project:
last_layer = self.projection(last_layer)
return EmbeddingOutput(all_layers=all_layers, last_layer=last_layer)
def forward(self, x, padding=None):
embedded = []
if self.pretrained_embeddings is not None:
pretrained_embeddings = self.pretrained_embeddings[0](x.cpu()).to(x.device).detach()
else:
pretrained_embeddings = None
if self.finetune_pretrained:
embedded += [emb(x, padding=padding) for emb in self.pretrained_embeddings]
else:
with torch.no_grad():
embedded += [emb(x, padding=padding) for emb in self.pretrained_embeddings]
if self.trained_embeddings is not None:
trained_vocabulary_size = self.trained_embeddings.weight.size()[0]
valid_x = torch.lt(x, trained_vocabulary_size)
masked_x = torch.where(valid_x, x, torch.zeros_like(x))
trained_embeddings = self.trained_embeddings(masked_x)
else:
trained_embeddings = None
if pretrained_embeddings is not None and trained_embeddings is not None:
embeddings = torch.cat((pretrained_embeddings, trained_embeddings), dim=2)
elif pretrained_embeddings is not None:
embeddings = pretrained_embeddings
else:
embeddings = trained_embeddings
output = self.trained_embeddings(masked_x)
embedded.append(EmbeddingOutput(all_layers=[output], last_layer=output))
return self.projection(embeddings) if self.project else embeddings
def set_embeddings(self, w):
if self.pretrained_embeddings is not None:
self.pretrained_embeddings[0].weight.data = w
self.pretrained_embeddings[0].weight.requires_grad = self.requires_grad
return self._combine_embeddings(embedded)
class SemanticFusionUnit(nn.Module):
@ -443,23 +460,38 @@ class LSTMDecoderAttention(nn.Module):
self.dot = dot
def applyMasks(self, context_mask):
self.context_mask = context_mask
# context_mask is batch x encoder_time, convert it to batch x 1 x encoder_time
self.context_mask = context_mask.unsqueeze(1)
def forward(self, input : torch.Tensor, context : torch.Tensor):
# input is batch x decoder_time x dim
# context is batch x encoder_time x dim
# output will be batch x decoder_time x dim
# context_attention will be batch x decoder_time x encoder_time
def forward(self, input, context):
if not self.dot:
targetT = self.linear_in(input).unsqueeze(2) # batch x dim x 1
targetT = self.linear_in(input) # batch x decoder_time x dim
else:
targetT = input.unsqueeze(2)
targetT = input
context_scores = torch.bmm(context, targetT).squeeze(2)
transposed_context = torch.transpose(context, 2, 1)
context_scores = torch.matmul(targetT, transposed_context)
context_scores.masked_fill_(self.context_mask, -float('inf'))
context_attention = F.softmax(context_scores, dim=-1) + EPSILON
context_alignment = torch.bmm(context_attention.unsqueeze(1), context).squeeze(1)
combined_representation = torch.cat([input, context_alignment], 1)
# convert context_attention to batch x decoder_time x 1 x encoder_time
# convert context to batch x 1 x encoder_time x dim
# context_alignment will be batch x decoder_time x 1 x dim
context_alignment = torch.matmul(context_attention.unsqueeze(2), context.unsqueeze(1))
# squeeze out the extra dimension
context_alignment = context_alignment.squeeze(2)
combined_representation = torch.cat([input, context_alignment], 2)
output = self.tanh(self.linear_out(combined_representation))
return output, context_attention, context_alignment
return output, context_attention
class CoattentiveLayer(nn.Module):
@ -471,8 +503,8 @@ class CoattentiveLayer(nn.Module):
self.dropout = nn.Dropout(dropout)
def forward(self, context, question, context_padding, question_padding):
context_padding = torch.cat([context.new_zeros((context.size(0), 1), dtype=torch.long)==1, context_padding], 1)
question_padding = torch.cat([question.new_zeros((question.size(0), 1), dtype=torch.long)==1, question_padding], 1)
context_padding = torch.cat([context.new_zeros((context.size(0), 1), dtype=torch.bool), context_padding], 1)
question_padding = torch.cat([question.new_zeros((question.size(0), 1), dtype=torch.bool), question_padding], 1)
context_sentinel = self.embed_sentinel(context.new_zeros((context.size(0), 1), dtype=torch.long))
context = torch.cat([context_sentinel, self.dropout(context)], 1) # batch_size x (context_length + 1) x features
@ -503,94 +535,3 @@ class CoattentiveLayer(nn.Module):
return F.softmax(raw_scores, dim=1)
# The following code was copied and adapted from github.com/floyhub/world-language-model
#
# BSD 3-Clause License
#
# Copyright (c) 2017,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
class PretrainedDecoderLM(nn.Module):
"""Container module with an encoder, a recurrent module, and a decoder."""
def __init__(self, rnn_type, ntoken, emsize, nhid, nlayers, dropout=0.5, tie_weights=False):
super(PretrainedDecoderLM, self).__init__()
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, emsize) # Token2Embeddings
if rnn_type in ['LSTM', 'GRU']:
self.rnn = getattr(nn, rnn_type)(emsize, nhid, nlayers, dropout=dropout)
else:
try:
nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
except KeyError:
raise ValueError( """An invalid option for `--model` was supplied,
options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
self.rnn = nn.RNN(emsize, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
self.decoder = nn.Linear(nhid, ntoken)
# Optionally tie weights as in:
# "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
# https://arxiv.org/abs/1608.05859
# and
# "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
# https://arxiv.org/abs/1611.01462
if tie_weights:
if nhid != emsize:
raise ValueError('When using the tied flag, nhid must be equal to emsize')
self.decoder.weight = self.encoder.weight
self.init_weights()
self.rnn_type = rnn_type
self.nhid = nhid
self.nlayers = nlayers
def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.fill_(0)
self.decoder.weight.data.uniform_(-initrange, initrange)
def encode(self, input, hidden=None):
emb = self.drop(self.encoder(input))
output, hidden = self.rnn(emb, hidden)
output = self.drop(output)
return output, hidden
def forward(self, input, hidden=None):
encoded, hidden = self.encode(input, hidden)
decoded = self.decoder(encoded.view(encoded.size(0)*encoded.size(1), encoded.size(2)))
return decoded.view(encoded.size(0), encoded.size(1), decoded.size(1)), hidden
def init_hidden(self, bsz):
weight = next(self.parameters()).data
if self.rnn_type == 'LSTM':
return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
weight.new(self.nlayers, bsz, self.nhid).zero_())
else:
return weight.new(self.nlayers, bsz, self.nhid).zero_()
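The rewritten LSTMDecoderAttention above operates on full (batch, decoder_time, dim) tensors instead of one step at a time; a plain-tensor sketch of the new shapes (toy sizes, dot-product variant, no learned projections):

import torch
import torch.nn.functional as F

B, T_dec, T_enc, D = 2, 5, 7, 4
query = torch.randn(B, T_dec, D)                                      # decoder states
context = torch.randn(B, T_enc, D)                                    # encoder states
context_mask = torch.zeros(B, T_enc, dtype=torch.bool).unsqueeze(1)   # batch x 1 x T_enc

scores = torch.matmul(query, context.transpose(2, 1))                 # batch x T_dec x T_enc
scores.masked_fill_(context_mask, -float('inf'))
attention = F.softmax(scores, dim=-1)
alignment = torch.matmul(attention.unsqueeze(2),                      # batch x T_dec x 1 x T_enc
                         context.unsqueeze(1)).squeeze(2)             # -> batch x T_dec x D
assert alignment.shape == (B, T_dec, D)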

View File

@ -0,0 +1,59 @@
#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from torch import nn
from .mqan_encoder import MQANEncoder
from .identity_encoder import IdentityEncoder
from .mqan_decoder import MQANDecoder
ENCODERS = {
'MQANEncoder': MQANEncoder,
'Identity': IdentityEncoder
}
DECODERS = {
'MQANDecoder': MQANDecoder
}
class Seq2Seq(nn.Module):
def __init__(self, numericalizer, args, encoder_embeddings, decoder_embeddings):
super().__init__()
self.args = args
self.encoder = ENCODERS[args.seq2seq_encoder](numericalizer, args, encoder_embeddings)
self.decoder = DECODERS[args.seq2seq_decoder](numericalizer, args, decoder_embeddings)
def forward(self, batch, iteration):
self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state = self.encoder(batch)
loss, predictions = self.decoder(batch, self_attended_context, final_context, context_rnn_state,
final_question, question_rnn_state)
return loss, predictions

View File

@ -0,0 +1,68 @@
#
# Copyright (c) 2018, The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from torch import nn
from .common import CombinedEmbedding
class IdentityEncoder(nn.Module):
def __init__(self, numericalizer, args, encoder_embeddings):
super().__init__()
self.args = args
self.pad_idx = numericalizer.pad_id
if sum(emb.dim for emb in encoder_embeddings) != args.dimension:
raise ValueError('Hidden dimension must be equal to the sum of the embedding sizes to use IdentityEncoder')
if args.rnn_layers > 0:
raise ValueError('Cannot have RNN layers with IdentityEncoder')
self.encoder_embeddings = CombinedEmbedding(numericalizer, encoder_embeddings, args.dimension,
trained_dimension=0,
project=False,
finetune_pretrained=args.train_encoder_embeddings)
def forward(self, batch):
context, context_lengths = batch.context.value, batch.context.length
question, question_lengths = batch.question.value, batch.question.length
context_padding = context.data == self.pad_idx
question_padding = question.data == self.pad_idx
context_embedded = self.encoder_embeddings(context, padding=context_padding)
question_embedded = self.encoder_embeddings(question, padding=question_padding)
# pick the top-most N transformer layers to pass to the decoder for cross-attention
# (add 1 to account for the embedding layer - the decoder will drop it later)
self_attended_context = context_embedded.all_layers[-(self.args.transformer_layers+1):]
final_context = context_embedded.last_layer
final_question = question_embedded.last_layer
context_rnn_state = None
question_rnn_state = None
return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state

View File

@ -0,0 +1,277 @@
#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from .common import *
class MQANDecoder(nn.Module):
def __init__(self, numericalizer, args, decoder_embeddings):
super().__init__()
self.numericalizer = numericalizer
self.pad_idx = numericalizer.pad_id
self.init_idx = numericalizer.init_id
self.args = args
self.decoder_embeddings = CombinedEmbedding(numericalizer, decoder_embeddings, args.dimension,
trained_dimension=args.trainable_decoder_embeddings,
project=True,
finetune_pretrained=False)
self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads,
args.transformer_hidden,
args.transformer_layers,
args.dropout_ratio)
if args.rnn_layers > 0:
self.rnn_decoder = LSTMDecoder(args.dimension, args.dimension,
dropout=args.dropout_ratio, num_layers=args.rnn_layers)
switch_input_len = 3 * args.dimension
else:
self.context_attn = LSTMDecoderAttention(args.dimension, dot=True)
self.question_attn = LSTMDecoderAttention(args.dimension, dot=True)
self.dropout = nn.Dropout(args.dropout_ratio)
switch_input_len = 2 * args.dimension
self.vocab_pointer_switch = nn.Sequential(Feedforward(switch_input_len, 1), nn.Sigmoid())
self.context_question_switch = nn.Sequential(Feedforward(switch_input_len, 1), nn.Sigmoid())
self.generative_vocab_size = numericalizer.generative_vocab_size
self.out = nn.Linear(args.dimension, self.generative_vocab_size)
def set_embeddings(self, embeddings):
if self.decoder_embeddings is not None:
self.decoder_embeddings.set_embeddings(embeddings)
def forward(self, batch, self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state):
context, context_lengths, context_limited = batch.context.value, batch.context.length, batch.context.limited
question, question_lengths, question_limited = batch.question.value, batch.question.length, batch.question.limited
answer, answer_lengths, answer_limited = batch.answer.value, batch.answer.length, batch.answer.limited
decoder_vocab = batch.decoder_vocab
self.map_to_full = decoder_vocab.decode
context_indices = context_limited if context_limited is not None else context
question_indices = question_limited if question_limited is not None else question
answer_indices = answer_limited if answer_limited is not None else answer
context_padding = context_indices.data == self.pad_idx
question_padding = question_indices.data == self.pad_idx
if self.args.rnn_layers > 0:
self.rnn_decoder.applyMasks(context_padding, question_padding)
else:
self.context_attn.applyMasks(context_padding)
self.question_attn.applyMasks(question_padding)
if self.training:
answer_padding = (answer_indices.data == self.pad_idx)[:, :-1]
answer_embedded = self.decoder_embeddings(answer[:, :-1], padding=answer_padding).last_layer
self_attended_decoded = self.self_attentive_decoder(answer_embedded,
self_attended_context,
context_padding=context_padding,
answer_padding=answer_padding,
positional_encodings=True)
if self.args.rnn_layers > 0:
rnn_decoder_outputs = self.rnn_decoder(self_attended_decoded, final_context, final_question,
hidden=context_rnn_state)
decoder_output, vocab_pointer_switch_input, context_question_switch_input, context_attention, \
question_attention, rnn_state = rnn_decoder_outputs
else:
context_decoder_output, context_attention = self.context_attn(self_attended_decoded, final_context)
question_decoder_output, question_attention = self.question_attn(self_attended_decoded, final_question)
vocab_pointer_switch_input = torch.cat((context_decoder_output, self_attended_decoded), dim=-1)
context_question_switch_input = torch.cat((question_decoder_output, self_attended_decoded), dim=-1)
decoder_output = self.dropout(context_decoder_output)
vocab_pointer_switch = self.vocab_pointer_switch(vocab_pointer_switch_input)
context_question_switch = self.context_question_switch(context_question_switch_input)
probs = self.probs(self.out, decoder_output, vocab_pointer_switch, context_question_switch,
context_attention, question_attention,
context_indices, question_indices,
decoder_vocab)
probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=self.pad_idx)
loss = F.nll_loss(probs.log(), targets)
return loss, None
else:
return None, self.greedy(self_attended_context, final_context, final_question,
context_indices, question_indices,
decoder_vocab, rnn_state=context_rnn_state).data
def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
context_attention, question_attention,
context_indices, question_indices,
decoder_vocab):
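# Pointer-generator mixture over the extended (limited decoder) vocabulary:
#   p(w) = switch_v * p_vocab(w)
#        + (1 - switch_v) * switch_cq * p_copy_context(w)
#        + (1 - switch_v) * (1 - switch_cq) * p_copy_question(w)
# where switch_v = vocab_pointer_switches and switch_cq = context_question_switches;
# copy probabilities are scattered into the positions of the context/question token indices.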
size = list(outputs.size())
size[-1] = self.generative_vocab_size
scores = generator(outputs.view(-1, outputs.size(-1))).view(size)
p_vocab = F.softmax(scores, dim=scores.dim() - 1)
scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab
effective_vocab_size = len(decoder_vocab)
if self.generative_vocab_size < effective_vocab_size:
size[-1] = effective_vocab_size - self.generative_vocab_size
buff = scaled_p_vocab.new_full(size, EPSILON)
scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim() - 1)
# p_context_ptr
scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1, context_indices.unsqueeze(1).expand_as(context_attention),
(context_question_switches * (1 - vocab_pointer_switches)).expand_as(
context_attention) * context_attention)
# p_question_ptr
scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1,
question_indices.unsqueeze(1).expand_as(question_attention),
((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as(
question_attention) * question_attention)
return scaled_p_vocab
def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab,
rnn_state=None):
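# Greedy decoding: emit one token per step up to max_output_length, reusing the cached
# transformer decoder states for the prefix; predictions are in the limited decoder
# vocabulary and are mapped back to full vocabulary ids before being returned.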
B, TC, C = context.size()
T = self.args.max_output_length
outs = context.new_full((B, T), self.pad_idx, dtype=torch.long)
hiddens = [self_attended_context[0].new_zeros((B, T, C))
for l in range(len(self.self_attentive_decoder.layers) + 1)]
hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
eos_yet = context.new_zeros((B,)).byte()
decoder_output = None
for t in range(T):
if t == 0:
init_token = self_attended_context[-1].new_full((B, 1), self.init_idx,
dtype=torch.long)
embedding = self.decoder_embeddings(init_token).last_layer
else:
current_token_id = outs[:, t - 1].unsqueeze(1)
embedding = self.decoder_embeddings(current_token_id).last_layer
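# Add the scaled token embedding at position t on top of the precomputed positional
# encoding, then run every decoder layer causally over the prefix [0, t].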
hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(
1)
for l in range(len(self.self_attentive_decoder.layers)):
hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward(
self.self_attentive_decoder.layers[l].attention(
self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1],
hiddens[l][:, :t + 1])
, self_attended_context[l], self_attended_context[l]))
self_attended_decoded = hiddens[-1][:, t].unsqueeze(1)
if self.args.rnn_layers > 0:
rnn_decoder_outputs = self.rnn_decoder(self_attended_decoded, context, question,
hidden=rnn_state, output=decoder_output)
decoder_output, vocab_pointer_switch_input, context_question_switch_input, context_attention, \
question_attention, rnn_state = rnn_decoder_outputs
else:
context_decoder_output, context_attention = self.context_attn(self_attended_decoded, context)
question_decoder_output, question_attention = self.question_attn(self_attended_decoded, question)
vocab_pointer_switch_input = torch.cat((context_decoder_output, self_attended_decoded), dim=-1)
context_question_switch_input = torch.cat((question_decoder_output, self_attended_decoded), dim=-1)
decoder_output = self.dropout(context_decoder_output)
vocab_pointer_switch = self.vocab_pointer_switch(vocab_pointer_switch_input)
context_question_switch = self.context_question_switch(context_question_switch_input)
probs = self.probs(self.out, decoder_output, vocab_pointer_switch, context_question_switch,
context_attention, question_attention,
context_indices, question_indices, decoder_vocab)
pred_probs, preds = probs.max(-1)
preds = preds.squeeze(1)
eos_yet = eos_yet | (preds == self.numericalizer.eos_id).byte()
outs[:, t] = preds.cpu().apply_(self.map_to_full)
if eos_yet.all():
break
return outs
class LSTMDecoder(nn.Module):
def __init__(self, d_in, d_hid, dropout=0.0, num_layers=1):
super().__init__()
self.d_hid = d_hid
self.d_in = d_in
self.num_layers = num_layers
self.dropout = nn.Dropout(dropout)
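# Input feeding: the previous step's attentional context vector is concatenated
# to the token embedding before it enters the LSTM.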
self.input_feed = True
if self.input_feed:
d_in += 1 * d_hid
self.rnn = MultiLSTMCell(self.num_layers, d_in, d_hid, dropout)
self.context_attn = LSTMDecoderAttention(d_hid, dot=True)
self.question_attn = LSTMDecoderAttention(d_hid, dot=True)
def applyMasks(self, context_mask, question_mask):
self.context_attn.applyMasks(context_mask)
self.question_attn.applyMasks(question_mask)
def forward(self, input: torch.Tensor, context, question, output=None, hidden=None):
context_output = output if output is not None else self.make_init_output(context)
context_outputs, vocab_pointer_switch_inputs, context_question_switch_inputs, context_attentions, question_attentions = [], [], [], [], []
for decoder_input in input.split(1, dim=1):
context_output = self.dropout(context_output)
if self.input_feed:
rnn_input = torch.cat([decoder_input, context_output], 2)
else:
rnn_input = decoder_input
rnn_input = rnn_input.squeeze(1)
dec_state, hidden = self.rnn(rnn_input, hidden)
dec_state = dec_state.unsqueeze(1)
context_output, context_attention = self.context_attn(dec_state, context)
question_output, question_attention = self.question_attn(dec_state, question)
vocab_pointer_switch_inputs.append(torch.cat([dec_state, context_output, decoder_input], -1))
context_question_switch_inputs.append(torch.cat([dec_state, question_output, decoder_input], -1))
context_output = self.dropout(context_output)
context_outputs.append(context_output)
context_attentions.append(context_attention)
question_attentions.append(question_attention)
return [torch.cat(x, dim=1) for x in (context_outputs,
vocab_pointer_switch_inputs,
context_question_switch_inputs,
context_attentions,
question_attentions)] + [hidden]
def make_init_output(self, context):
batch_size = context.size(0)
h_size = (batch_size, 1, self.d_hid)
return context.new_zeros(h_size)

View File

@@ -0,0 +1,108 @@
#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from .common import *
class MQANEncoder(nn.Module):
def __init__(self, numericalizer, args, encoder_embeddings):
super().__init__()
self.args = args
self.pad_idx = numericalizer.pad_id
self.encoder_embeddings = CombinedEmbedding(numericalizer, encoder_embeddings, args.dimension,
trained_dimension=0,
project=True,
finetune_pretrained=args.train_encoder_embeddings)
def dp(args):
return args.dropout_ratio if args.rnn_layers > 1 else 0.
self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
batch_first=True, bidirectional=True, num_layers=1, dropout=0)
self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
dim = 2 * args.dimension + args.dimension + args.dimension
self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
batch_first=True, dropout=dp(args), bidirectional=True,
num_layers=args.rnn_layers)
self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads,
args.transformer_hidden, args.transformer_layers,
args.dropout_ratio)
self.bilstm_context = PackedLSTM(args.dimension, args.dimension,
batch_first=True, dropout=dp(args), bidirectional=True,
num_layers=args.rnn_layers)
self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
batch_first=True, dropout=dp(args), bidirectional=True,
num_layers=args.rnn_layers)
self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads,
args.transformer_hidden, args.transformer_layers,
args.dropout_ratio)
self.bilstm_question = PackedLSTM(args.dimension, args.dimension,
batch_first=True, dropout=dp(args), bidirectional=True,
num_layers=args.rnn_layers)
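# Encoding pipeline, applied to both context and question: embed -> shared BiLSTM ->
# coattention -> BiLSTM -> transformer self-attention -> final BiLSTM, whose hidden
# states also seed the decoder RNN.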
def forward(self, batch):
context, context_lengths = batch.context.value, batch.context.length
question, question_lengths = batch.question.value, batch.question.length
context_padding = context.data == self.pad_idx
question_padding = question.data == self.pad_idx
context_embedded = self.encoder_embeddings(context, padding=context_padding).last_layer
question_embedded = self.encoder_embeddings(question, padding=question_padding).last_layer
context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
coattended_context, coattended_question = self.coattention(context_encoded, question_encoded,
context_padding, question_padding)
context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1)
condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths)
self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding)
final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1],
context_lengths)
context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)]
question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1)
condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths)
self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding)
final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1],
question_lengths)
question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)]
return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state
def reshape_rnn_state(self, h):
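# Fold the two directions of the bidirectional LSTM state into the feature dimension
# so it can initialize a unidirectional decoder LSTM.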
return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
.transpose(1, 2).contiguous() \
.view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous()

View File

@@ -1,420 +0,0 @@
#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from collections import defaultdict
from ..util import get_trainable_params
from .common import *
class MQANEncoder(nn.Module):
def __init__(self, numericalizer, args):
super().__init__()
self.args = args
self.pad_idx = numericalizer.pad_id
if self.args.glove_and_char:
self.encoder_embeddings = Embedding(numericalizer, args.dimension,
trained_dimension=0,
dropout=args.dropout_ratio,
project=True,
requires_grad=args.retrain_encoder_embedding)
def dp(args):
return args.dropout_ratio if args.rnn_layers > 1 else 0.
self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
batch_first=True, bidirectional=True, num_layers=1, dropout=0)
self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
dim = 2 * args.dimension + args.dimension + args.dimension
self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
batch_first=True, dropout=dp(args), bidirectional=True,
num_layers=args.rnn_layers)
self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads,
args.transformer_hidden, args.transformer_layers,
args.dropout_ratio)
self.bilstm_context = PackedLSTM(args.dimension, args.dimension,
batch_first=True, dropout=dp(args), bidirectional=True,
num_layers=args.rnn_layers)
self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
batch_first=True, dropout=dp(args), bidirectional=True,
num_layers=args.rnn_layers)
self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads,
args.transformer_hidden, args.transformer_layers,
args.dropout_ratio)
self.bilstm_question = PackedLSTM(args.dimension, args.dimension,
batch_first=True, dropout=dp(args), bidirectional=True,
num_layers=args.rnn_layers)
def set_embeddings(self, embeddings):
self.encoder_embeddings.set_embeddings(embeddings)
def forward(self, batch):
context, context_lengths = batch.context.value, batch.context.length
question, question_lengths = batch.question.value, batch.question.length
context_embedded = self.encoder_embeddings(context)
question_embedded = self.encoder_embeddings(question)
context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
context_padding = context.data == self.pad_idx
question_padding = question.data == self.pad_idx
coattended_context, coattended_question = self.coattention(context_encoded, question_encoded,
context_padding, question_padding)
context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1)
condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths)
self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding)
final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1],
context_lengths)
context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)]
question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1)
condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths)
self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding)
final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1],
question_lengths)
question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)]
return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state
def reshape_rnn_state(self, h):
return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
.transpose(1, 2).contiguous() \
.view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous()
class MQANDecoder(nn.Module):
def __init__(self, numericalizer, args, devices):
super().__init__()
self.numericalizer = numericalizer
self.pad_idx = numericalizer.pad_id
self.init_idx = numericalizer.init_id
self.args = args
self.devices = devices
if args.pretrained_decoder_lm:
pretrained_save_dict = torch.load(os.path.join(args.embeddings, args.pretrained_decoder_lm), map_location=devices[0])
self.pretrained_decoder_vocab_itos = pretrained_save_dict['vocab']
self.pretrained_decoder_vocab_stoi = defaultdict(lambda: 0, {
w: i for i, w in enumerate(self.pretrained_decoder_vocab_itos)
})
self.pretrained_decoder_embeddings = PretrainedDecoderLM(rnn_type=pretrained_save_dict['settings']['rnn_type'],
ntoken=len(self.pretrained_decoder_vocab_itos),
emsize=pretrained_save_dict['settings']['emsize'],
nhid=pretrained_save_dict['settings']['nhid'],
nlayers=pretrained_save_dict['settings']['nlayers'],
dropout=0.0)
self.pretrained_decoder_embeddings.load_state_dict(pretrained_save_dict['model'], strict=True)
pretrained_lm_params = get_trainable_params(self.pretrained_decoder_embeddings)
for p in pretrained_lm_params:
p.requires_grad = False
if self.pretrained_decoder_embeddings.nhid != args.dimension:
self.pretrained_decoder_embedding_projection = Feedforward(self.pretrained_decoder_embeddings.nhid,
args.dimension)
else:
self.pretrained_decoder_embedding_projection = None
self.decoder_embeddings = None
else:
self.pretrained_decoder_vocab_itos = None
self.pretrained_decoder_vocab_stoi = None
self.pretrained_decoder_embeddings = None
self.decoder_embeddings = Embedding(self.numericalizer, args.dimension,
include_pretrained=args.glove_decoder,
trained_dimension=args.trainable_decoder_embedding,
dropout=args.dropout_ratio, project=True)
self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
self.dual_ptr_rnn_decoder = DualPtrRNNDecoder(args.dimension, args.dimension,
dropout=args.dropout_ratio, num_layers=args.rnn_layers)
self.generative_vocab_size = numericalizer.generative_vocab_size
self.out = nn.Linear(args.dimension, self.generative_vocab_size)
def set_embeddings(self, embeddings):
if self.decoder_embeddings is not None:
self.decoder_embeddings.set_embeddings(embeddings)
def forward(self, batch, self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state):
context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens
answer, answer_lengths, answer_limited, answer_tokens = batch.answer.value, batch.answer.length, batch.answer.limited, batch.answer.tokens
decoder_vocab = batch.decoder_vocab
self.map_to_full = decoder_vocab.decode
context_indices = context_limited if context_limited is not None else context
question_indices = question_limited if question_limited is not None else question
answer_indices = answer_limited if answer_limited is not None else answer
context_padding = context_indices.data == self.pad_idx
question_padding = question_indices.data == self.pad_idx
self.dual_ptr_rnn_decoder.applyMasks(context_padding, question_padding)
if self.training:
answer_padding = (answer_indices.data == self.pad_idx)[:, :-1]
if self.args.pretrained_decoder_lm:
# note that pretrained_decoder_embeddings is time first
answer_pretrained_numerical = [
[self.pretrained_decoder_vocab_stoi[sentence[time]] for sentence in answer_tokens] for time in
range(len(answer_tokens[0]))
]
answer_pretrained_numerical = torch.tensor(answer_pretrained_numerical, dtype=torch.long)
with torch.no_grad():
answer_embedded, _ = self.pretrained_decoder_embeddings.encode(answer_pretrained_numerical)
answer_embedded.transpose_(0, 1)
if self.pretrained_decoder_embedding_projection is not None:
answer_embedded = self.pretrained_decoder_embedding_projection(answer_embedded)
else:
answer_embedded = self.decoder_embeddings(answer)
self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(),
self_attended_context, context_padding=context_padding,
answer_padding=answer_padding,
positional_encodings=True)
decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded,
final_context, final_question, hidden=context_rnn_state)
rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
context_attention, question_attention,
context_indices, question_indices,
decoder_vocab)
probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=self.pad_idx)
loss = F.nll_loss(probs.log(), targets)
return loss, None
else:
return None, self.greedy(self_attended_context, final_context, final_question,
context_indices, question_indices,
decoder_vocab, rnn_state=context_rnn_state).data
def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
context_attention, question_attention,
context_indices, question_indices,
decoder_vocab):
size = list(outputs.size())
size[-1] = self.generative_vocab_size
scores = generator(outputs.view(-1, outputs.size(-1))).view(size)
p_vocab = F.softmax(scores, dim=scores.dim() - 1)
scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab
effective_vocab_size = len(decoder_vocab)
if self.generative_vocab_size < effective_vocab_size:
size[-1] = effective_vocab_size - self.generative_vocab_size
buff = scaled_p_vocab.new_full(size, EPSILON)
scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim() - 1)
# p_context_ptr
scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1, context_indices.unsqueeze(1).expand_as(context_attention),
(context_question_switches * (1 - vocab_pointer_switches)).expand_as(
context_attention) * context_attention)
# p_question_ptr
scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1,
question_indices.unsqueeze(1).expand_as(question_attention),
((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as(
question_attention) * question_attention)
return scaled_p_vocab
def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab,
rnn_state=None):
B, TC, C = context.size()
T = self.args.max_output_length
outs = context.new_full((B, T), self.pad_idx, dtype=torch.long)
hiddens = [self_attended_context[0].new_zeros((B, T, C))
for l in range(len(self.self_attentive_decoder.layers) + 1)]
hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
eos_yet = context.new_zeros((B,)).byte()
pretrained_lm_hidden = None
if self.args.pretrained_decoder_lm:
pretrained_lm_hidden = self.pretrained_decoder_embeddings.init_hidden(B)
rnn_output, context_alignment, question_alignment = None, None, None
for t in range(T):
if t == 0:
if self.args.pretrained_decoder_lm:
init_token = self_attended_context[-1].new_full((1, B),
self.pretrained_decoder_vocab_stoi[self.numericalizer.init_token],
dtype=torch.long)
# note that pretrained_decoder_embeddings is time first
embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(init_token,
pretrained_lm_hidden)
embedding.transpose_(0, 1)
if self.pretrained_decoder_embedding_projection is not None:
embedding = self.pretrained_decoder_embedding_projection(embedding)
else:
init_token = self_attended_context[-1].new_full((B, 1), self.init_idx,
dtype=torch.long)
embedding = self.decoder_embeddings(init_token, [1] * B)
else:
if self.args.pretrained_decoder_lm:
current_token = [self.numericalizer.decode([x])[0] for x in outs[:, t - 1]]
current_token_id = torch.tensor([[self.pretrained_decoder_vocab_stoi[x] for x in current_token]],
dtype=torch.long, requires_grad=False)
embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(current_token_id,
pretrained_lm_hidden)
# note that pretrained_decoder_embeddings is time first
embedding.transpose_(0, 1)
if self.pretrained_decoder_embedding_projection is not None:
embedding = self.pretrained_decoder_embedding_projection(embedding)
else:
current_token_id = outs[:, t - 1].unsqueeze(1)
embedding = self.decoder_embeddings(current_token_id, [1] * B)
hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(
1)
for l in range(len(self.self_attentive_decoder.layers)):
hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward(
self.self_attentive_decoder.layers[l].attention(
self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1],
hiddens[l][:, :t + 1])
, self_attended_context[l], self_attended_context[l]))
decoder_outputs = self.dual_ptr_rnn_decoder(hiddens[-1][:, t].unsqueeze(1),
context, question,
context_alignment=context_alignment,
question_alignment=question_alignment,
hidden=rnn_state, output=rnn_output)
rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
context_attention, question_attention,
context_indices, question_indices,
decoder_vocab)
pred_probs, preds = probs.max(-1)
preds = preds.squeeze(1)
eos_yet = eos_yet | (preds == self.numericalizer.eos_id).byte()
outs[:, t] = preds.cpu().apply_(self.map_to_full)
if eos_yet.all():
break
return outs
class MultitaskQuestionAnsweringNetwork(nn.Module):
def __init__(self, numericalizer, args, devices):
super().__init__()
self.args = args
self.encoder = MQANEncoder(numericalizer, args)
self.decoder = MQANDecoder(numericalizer, args, devices)
def set_embeddings(self, embeddings):
self.encoder.set_embeddings(embeddings)
self.decoder.set_embeddings(embeddings)
def forward(self, batch, iteration):
self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state = self.encoder(batch)
loss, predictions = self.decoder(batch, self_attended_context, final_context, context_rnn_state,
final_question, question_rnn_state)
return loss, predictions
class DualPtrRNNDecoder(nn.Module):
def __init__(self, d_in, d_hid, dropout=0.0, num_layers=1):
super().__init__()
self.d_hid = d_hid
self.d_in = d_in
self.num_layers = num_layers
self.dropout = nn.Dropout(dropout)
self.input_feed = True
if self.input_feed:
d_in += 1 * d_hid
self.rnn = LSTMDecoder(self.num_layers, d_in, d_hid, dropout)
self.context_attn = LSTMDecoderAttention(d_hid, dot=True)
self.question_attn = LSTMDecoderAttention(d_hid, dot=True)
self.vocab_pointer_switch = nn.Sequential(Feedforward(2 * self.d_hid + d_in, 1), nn.Sigmoid())
self.context_question_switch = nn.Sequential(Feedforward(2 * self.d_hid + d_in, 1), nn.Sigmoid())
def forward(self, input, context, question, output=None, hidden=None, context_alignment=None, question_alignment=None):
context_output = output.squeeze(1) if output is not None else self.make_init_output(context)
context_alignment = context_alignment if context_alignment is not None else self.make_init_output(context)
question_alignment = question_alignment if question_alignment is not None else self.make_init_output(question)
context_outputs, vocab_pointer_switches, context_question_switches, context_attentions, question_attentions, context_alignments, question_alignments = [], [], [], [], [], [], []
for emb_t in input.split(1, dim=1):
emb_t = emb_t.squeeze(1)
context_output = self.dropout(context_output)
if self.input_feed:
emb_t = torch.cat([emb_t, context_output], 1)
dec_state, hidden = self.rnn(emb_t, hidden)
context_output, context_attention, context_alignment = self.context_attn(dec_state, context)
question_output, question_attention, question_alignment = self.question_attn(dec_state, question)
vocab_pointer_switch = self.vocab_pointer_switch(torch.cat([dec_state, context_output, emb_t], -1))
context_question_switch = self.context_question_switch(torch.cat([dec_state, question_output, emb_t], -1))
context_output = self.dropout(context_output)
context_outputs.append(context_output)
vocab_pointer_switches.append(vocab_pointer_switch)
context_question_switches.append(context_question_switch)
context_attentions.append(context_attention)
context_alignments.append(context_alignment)
question_attentions.append(question_attention)
question_alignments.append(question_alignment)
context_outputs, vocab_pointer_switches, context_question_switches, context_attention, question_attention = [self.package_outputs(x) for x in [context_outputs, vocab_pointer_switches, context_question_switches, context_attentions, question_attentions]]
return context_outputs, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switches, context_question_switches, hidden
def applyMasks(self, context_mask, question_mask):
self.context_attn.applyMasks(context_mask)
self.question_attn.applyMasks(question_mask)
def make_init_output(self, context):
batch_size = context.size(0)
h_size = (batch_size, self.d_hid)
return context.new_zeros(h_size)
def package_outputs(self, outputs):
outputs = torch.stack(outputs, dim=1)
return outputs

View File

@@ -36,10 +36,9 @@ import sys
import logging
from pprint import pformat
from .util import set_seed, preprocess_examples, load_config_json, make_data_loader, log_model_size, init_devices, \
make_numericalizer
from .util import set_seed, preprocess_examples, load_config_json, make_data_loader, log_model_size, init_devices
from .metrics import compute_metrics
from .utils.embeddings import load_embeddings
from .data.embeddings import load_embeddings
from .tasks.registry import get_tasks
from . import models
@@ -67,18 +66,16 @@ def get_all_splits(args):
return splits
def prepare_data(args, numericalizer):
def prepare_data(args, numericalizer, embeddings):
splits = get_all_splits(args)
vectors = load_embeddings(args)
logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens from training')
new_vectors = []
new_words = []
for split in splits:
new_vectors += numericalizer.grow_vocab(split, vectors)
logger.info(f'Vocabulary has expanded to {numericalizer.num_tokens} tokens')
if new_vectors:
# concat the old embedding matrix and all the new vector along the first dimension
new_embedding_matrix = torch.cat([numericalizer.vocab.vectors.cpu()] + new_vectors, dim=0)
numericalizer.vocab.vectors = new_embedding_matrix
new_words += numericalizer.grow_vocab(split)
logger.info(f'Vocabulary has expanded to {numericalizer.num_tokens} tokens')
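# grow each embedding matrix to cover the words that were just added to the vocabulary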
for emb in embeddings:
emb.grow_for_vocab(numericalizer.vocab, new_words)
return splits
@@ -186,17 +183,20 @@ def main(argv=sys.argv):
devices = init_devices(args)
save_dict = torch.load(args.best_checkpoint, map_location=devices[0])
numericalizer = make_numericalizer(args)
numericalizer, encoder_embeddings, decoder_embeddings = load_embeddings(args.embeddings, args.encoder_embeddings,
args.decoder_embeddings,
args.max_generative_vocab,
logger)
numericalizer.load(args.path)
for emb in set(encoder_embeddings + decoder_embeddings):
emb.init_for_vocab(numericalizer.vocab)
logger.info(f'Initializing Model')
Model = getattr(models, args.model)
model = Model(numericalizer, args, devices)
model = Model(numericalizer, args, encoder_embeddings, decoder_embeddings)
model_dict = save_dict['model_state_dict']
model.load_state_dict(model_dict)
splits = prepare_data(args, numericalizer)
if args.model != 'MultiLingualTranslationModel':
model.set_embeddings(numericalizer.vocab.vectors)
splits = prepare_data(args, numericalizer, set(encoder_embeddings + decoder_embeddings))
run(args, numericalizer, splits, model, devices[0])

View File

@@ -32,17 +32,15 @@
from argparse import ArgumentParser
import ujson as json
import torch
import numpy as np
import random
import asyncio
import logging
import sys
from pprint import pformat
from .data.example import Batch
from .util import set_seed, init_devices, load_config_json, log_model_size, make_numericalizer
from .util import set_seed, init_devices, load_config_json, log_model_size
from . import models
from .utils.embeddings import load_embeddings
from .data.embeddings import load_embeddings
from .tasks.registry import get_tasks
from .tasks.generic_dataset import Example
@@ -52,27 +50,24 @@ class ProcessedExample():
pass
class Server():
def __init__(self, args, numericalizer, model, device):
def __init__(self, args, numericalizer, embeddings, model, device):
self.args = args
self.device = device
self.numericalizer = numericalizer
self.model = model
logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens from training')
self._vector_collections = load_embeddings(args)
self._embeddings = embeddings
self._cached_tasks = dict()
def numericalize_example(self, ex):
new_vectors = self.numericalizer.grow_vocab([ex], self._vector_collections)
if new_vectors:
# concat the old embedding matrix and all the new vector along the first dimension
new_embedding_matrix = torch.cat([self.numericalizer.vocab.vectors.cpu()] + new_vectors, dim=0)
self.numericalizer.vocab.vectors = new_embedding_matrix
self.model.set_embeddings(new_embedding_matrix)
new_words = self.numericalizer.grow_vocab([ex])
for emb in self._embeddings:
emb.grow_for_vocab(self.numericalizer.vocab, new_words)
# batch of size 1
return Batch.from_examples([ex], self.numericalizer, self.numericalizer.decoder_vocab, device=self.device)
return Batch.from_examples([ex], self.numericalizer, device=self.device)
def handle_request(self, line):
request = json.loads(line)
@@ -174,15 +169,19 @@ def main(argv=sys.argv):
devices = init_devices(args)
save_dict = torch.load(args.best_checkpoint, map_location=devices[0])
numericalizer = make_numericalizer(args)
numericalizer, encoder_embeddings, decoder_embeddings = load_embeddings(args.embeddings, args.encoder_embeddings,
args.decoder_embeddings,
args.max_generative_vocab)
numericalizer.load(args.path)
for emb in set(encoder_embeddings + decoder_embeddings):
emb.init_for_vocab(numericalizer.vocab)
logger.info(f'Initializing Model')
Model = getattr(models, args.model)
model = Model(numericalizer, args, devices)
model = Model(numericalizer, args, encoder_embeddings, decoder_embeddings)
model_dict = save_dict['model_state_dict']
model.load_state_dict(model_dict)
server = Server(args, numericalizer, model, devices[0])
server = Server(args, numericalizer, encoder_embeddings + decoder_embeddings, model, devices[0])
server.run()

View File

@@ -47,9 +47,9 @@ from . import arguments
from . import models
from .validate import validate
from .util import elapsed_time, set_seed, preprocess_examples, get_trainable_params, make_data_loader, log_model_size, \
init_devices, make_numericalizer
init_devices
from .utils.saver import Saver
from .utils.embeddings import load_embeddings
from .data.embeddings import load_embeddings
def initialize_logger(args):
@@ -111,16 +111,21 @@ def prepare_data(args, logger):
if args.vocab_tasks is not None and task.name in args.vocab_tasks:
vocab_sets.extend(split)
numericalizer = make_numericalizer(args)
numericalizer, encoder_embeddings, decoder_embeddings = load_embeddings(args.embeddings, args.encoder_embeddings,
args.decoder_embeddings,
args.max_generative_vocab,
logger)
if args.load is not None:
numericalizer.load(args.save)
else:
vectors = load_embeddings(args, logger)
vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets
logger.info(f'Building vocabulary')
numericalizer.build_vocab(vectors, Example.vocab_fields, vocab_sets)
numericalizer.build_vocab(Example.vocab_fields, vocab_sets)
numericalizer.save(args.save)
for vec in set(encoder_embeddings + decoder_embeddings):
vec.init_for_vocab(numericalizer.vocab)
logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens')
logger.debug(f'The first 200 tokens:')
logger.debug(numericalizer.vocab.itos[:200])
@@ -133,7 +138,7 @@ def prepare_data(args, logger):
logger.info('Preprocessing validation data')
preprocess_examples(args, args.val_tasks, val_sets, logger, train=args.val_filter)
return numericalizer, train_sets, val_sets, aux_sets
return numericalizer, encoder_embeddings, decoder_embeddings, train_sets, val_sets, aux_sets
def get_learning_rate(i, args):
@@ -392,11 +397,11 @@ def train(args, devices, model, opt, train_sets, train_iterations, numericalizer
break
def init_model(args, numericalizer, devices, logger):
def init_model(args, numericalizer, encoder_embeddings, decoder_embeddings, devices, logger):
model_name = args.model
logger.info(f'Initializing {model_name}')
Model = getattr(models, model_name)
model = Model(numericalizer, args, devices)
model = Model(numericalizer, args, encoder_embeddings, decoder_embeddings)
params = get_trainable_params(model)
log_model_size(logger, model, model_name)
@@ -432,7 +437,7 @@ def main(argv=sys.argv):
if args.load is not None:
logger.info(f'Loading vocab from {os.path.join(args.save, args.load)}')
save_dict = torch.load(os.path.join(args.save, args.load))
numericalizer, train_sets, val_sets, aux_sets = prepare_data(args, logger)
numericalizer, encoder_embeddings, decoder_embeddings, train_sets, val_sets, aux_sets = prepare_data(args, logger)
if (args.use_curriculum and aux_sets is None) or (not args.use_curriculum and len(aux_sets)):
logging.error('something unpleasant is happening with curriculum')
@@ -445,7 +450,7 @@ def main(argv=sys.argv):
else:
writer = None
model = init_model(args, numericalizer, devices, logger)
model = init_model(args, numericalizer, encoder_embeddings, decoder_embeddings, devices, logger)
opt = init_opt(args, model)
start_iteration = 1

View File

@@ -187,34 +187,23 @@ def load_config_json(args):
args.almond_type_embeddings = False
with open(os.path.join(args.path, 'config.json')) as config_file:
config = json.load(config_file)
retrieve = ['model', 'transformer_layers', 'rnn_layers', 'transformer_hidden', 'dimension',
'load', 'max_val_context_length', 'val_batch_size', 'transformer_heads', 'max_output_length',
'max_effective_vocab', 'max_generative_vocab', 'lower', 'glove_and_char',
'small_glove', 'almond_type_embeddings', 'almond_grammar',
'trainable_decoder_embedding', 'glove_decoder', 'pretrained_decoder_lm',
'retrain_encoder_embedding', 'question', 'locale', 'use_google_translate']
retrieve = ['model', 'seq2seq_encoder', 'seq2seq_decoder', 'transformer_layers', 'rnn_layers',
'transformer_hidden', 'dimension', 'load', 'max_val_context_length', 'val_batch_size',
'transformer_heads', 'max_output_length', 'max_generative_vocab', 'lower', 'encoder_embeddings',
'decoder_embeddings', 'trainable_decoder_embeddings', 'train_encoder_embeddings',
'question', 'locale', 'use_google_translate']
for r in retrieve:
if r in config:
setattr(args, r, config[r])
elif r == 'locale':
setattr(args, r, 'en')
elif r in ('small_glove', 'almond_type_embbedings'):
setattr(args, r, False)
elif r in ('glove_decoder', 'glove_and_char'):
setattr(args, r, True)
elif r == 'trainable_decoder_embedding':
setattr(args, r, 0)
elif r == 'retrain_encoder_embedding':
elif r == 'train_encoder_embeddings':
setattr(args, r, False)
else:
setattr(args, r, None)
args.dropout_ratio = 0.0
args.best_checkpoint = os.path.join(args.path, args.checkpoint_name)
def make_numericalizer(args):
return SimpleNumericalizer(max_effective_vocab=args.max_effective_vocab,
max_generative_vocab=args.max_generative_vocab,
pad_first=False)
args.best_checkpoint = os.path.join(args.path, args.checkpoint_name)

View File

@@ -23,10 +23,12 @@ workdir=`mktemp -d $TMPDIR/decaNLP-tests-XXXXXX`
trap on_error ERR INT TERM
i=0
for hparams in "" "--use_curriculum"; do
for hparams in "--encoder_embeddings=small_glove+char --decoder_embeddings=small_glove+char" \
"--encoder_embeddings=bert-base-uncased --decoder_embeddings= --trainable_decoder_embeddings=50" \
"--encoder_embeddings=bert-base-uncased --decoder_embeddings= --trainable_decoder_embeddings=50 --seq2seq_encoder=Identity --dimension=768 --rnn_layers=0" ; do
# train
pipenv run python3 -m decanlp train --train_tasks almond --train_iterations 4 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $SRCDIR/embeddings --small_glove --no_commit
pipenv run python3 -m decanlp train --train_tasks almond --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $SRCDIR/embeddings --no_commit
# greedy decode
pipenv run python3 -m decanlp predict --tasks almond --evaluate test --path $workdir/model_$i --overwrite --eval_dir $workdir/model_$i/eval_results/ --data $SRCDIR/dataset/ --embeddings $SRCDIR/embeddings