BERT

commit 1f7edc7b24
parent bbcdb1a3e2
@@ -82,7 +82,6 @@ def parse(argv):

    parser.add_argument('--vocab_tasks', nargs='+', type=str, help='tasks to use in the construction of the vocabulary')
    parser.add_argument('--max_output_length', default=100, type=int, help='maximum output length for generation')
    parser.add_argument('--max_effective_vocab', default=int(1e6), type=int, help='max effective vocabulary size for pretrained embeddings')
    parser.add_argument('--max_generative_vocab', default=50000, type=int, help='max vocabulary for the generative softmax')
    parser.add_argument('--max_train_context_length', default=500, type=int, help='maximum length of the contexts during training')
    parser.add_argument('--max_val_context_length', default=500, type=int, help='maximum length of the contexts during validation')
@@ -90,19 +89,22 @@ def parse(argv):
    parser.add_argument('--subsample', default=20000000, type=int, help='subsample the datasets')
    parser.add_argument('--preserve_case', action='store_false', dest='lower', help='whether to preserve casing for all text')

    parser.add_argument('--model', type=str, default='MultitaskQuestionAnsweringNetwork', help='which model to import')
    parser.add_argument('--model', type=str, choices=['Seq2Seq'], default='Seq2Seq', help='which model to import')
    parser.add_argument('--seq2seq_encoder', type=str, choices=['MQANEncoder', 'Identity'], default='MQANEncoder',
                        help='which encoder to use for the Seq2Seq model')
    parser.add_argument('--seq2seq_decoder', type=str, choices=['MQANDecoder'], default='MQANDecoder',
                        help='which decoder to use for the Seq2Seq model')
    parser.add_argument('--dimension', default=200, type=int, help='output dimensions for all layers')
    parser.add_argument('--rnn_layers', default=1, type=int, help='number of layers for RNN modules')
    parser.add_argument('--transformer_layers', default=2, type=int, help='number of layers for transformer modules')
    parser.add_argument('--transformer_hidden', default=150, type=int, help='hidden size of the transformer modules')
    parser.add_argument('--transformer_heads', default=3, type=int, help='number of heads for transformer modules')
    parser.add_argument('--dropout_ratio', default=0.2, type=float, help='dropout for the model')
    parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings')
    parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
    parser.add_argument('--retrain_encoder_embedding', default=False, action='store_true', help='whether to retrain encoder embeddings')
    parser.add_argument('--trainable_decoder_embedding', default=0, type=int, help='size of trainable portion of decoder embedding (0 or omit to disable)')
    parser.add_argument('--no_glove_decoder', action='store_false', dest='glove_decoder', help='turn off GloVe embeddings from decoder')
    parser.add_argument('--pretrained_decoder_lm', help='pretrained language model to use as embedding layer for the decoder (omit to disable)')

    parser.add_argument('--encoder_embeddings', default='glove+char', help='which word embedding to use on the encoder side; use a bert-* pretrained model for BERT; multiple embeddings can be concatenated with +')
    parser.add_argument('--train_encoder_embeddings', action='store_true', default=False, help='back propagate into pretrained encoder embedding (recommended for BERT)')
    parser.add_argument('--decoder_embeddings', default='glove+char', help='which pretrained word embedding to use on the decoder side')
    parser.add_argument('--trainable_decoder_embeddings', default=0, type=int, help='size of trainable portion of decoder embedding (0 or omit to disable)')

    parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate')
    parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')
@@ -123,8 +125,6 @@ def parse(argv):

    parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether to use existing cached splits or generate new ones')
    parser.add_argument('--lr_rate', default=0.001, type=float, help='initial learning rate')
    parser.add_argument('--small_glove', action='store_true', help='use glove.6B.50d instead of glove.840B.300d')
    parser.add_argument('--almond_type_embeddings', action='store_true', help='add type-based word embeddings for the Almond task')
    parser.add_argument('--use_curriculum', action='store_true', help='use curriculum learning')
    parser.add_argument('--aux_dataset', default='', type=str, help='path to auxiliary dataset (ignored if curriculum is not used)')
    parser.add_argument('--curriculum_max_frac', default=1.0, type=float, help='max fraction of harder dataset to keep for curriculum')
@@ -29,14 +29,12 @@

from argparse import ArgumentParser
import torch
import numpy as np
import random
import logging
import sys
from pprint import pformat

from .utils.embeddings import load_embeddings
from .util import set_seed
from .data.embeddings import load_embeddings

logger = logging.getLogger(__name__)
@@ -44,8 +42,8 @@ logger = logging.getLogger(__name__)
def get_args(argv):
    parser = ArgumentParser(prog=argv[0])
    parser.add_argument('--seed', default=123, type=int, help='Random seed.')
    parser.add_argument('--embeddings', default='./decaNLP/.embeddings', type=str, help='where to save embeddings.')
    parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
    parser.add_argument('-d', '--destdir', default='./decaNLP/.embeddings', type=str, help='where to save embeddings.')
    parser.add_argument('--embeddings', default='glove+char', help='which embeddings to download')

    args = parser.parse_args(argv[1:])
    return args
@@ -55,9 +53,5 @@ def main(argv=sys.argv):
    args = get_args(argv)
    logger.info(f'Arguments:\n{pformat(vars(args))}')

    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    load_embeddings(args, load_almond_embeddings=False)
    set_seed(args.seed)
    load_embeddings(args.destdir, args.embeddings, '')
@@ -30,17 +30,15 @@

import torch

import logging
from ..data import word_vectors

_logger = logging.getLogger(__name__)
from .word_vectors import Vectors

ENTITIES = ['DATE', 'DURATION', 'EMAIL_ADDRESS', 'HASHTAG',
            'LOCATION', 'NUMBER', 'PHONE_NUMBER', 'QUOTED_STRING',
            'TIME', 'URL', 'USERNAME', 'PATH_NAME', 'CURRENCY']
MAX_ARG_VALUES = 5

class AlmondEmbeddings(word_vectors.Vectors):

class AlmondEmbeddings(Vectors):
    def __init__(self, name=None, cache=None, **kw):
        super().__init__(name, cache, **kw)
@@ -63,27 +61,4 @@ class AlmondEmbeddings(word_vectors.Vectors):
        self.itos = itos
        self.stoi = {word: i for i, word in enumerate(itos)}
        self.vectors = torch.stack(vectors, dim=0).view(-1, dim)
        self.dim = dim


def load_embeddings(args, logger=_logger, load_almond_embeddings=True):
    logger.info(f'Getting pretrained word vectors')

    language = args.locale.split('-')[0]

    if language == 'en':
        char_vectors = word_vectors.CharNGram(cache=args.embeddings)
        if args.small_glove:
            glove_vectors = word_vectors.GloVe(cache=args.embeddings, name="6B", dim=50)
        else:
            glove_vectors = word_vectors.GloVe(cache=args.embeddings)
        vectors = [char_vectors, glove_vectors]
    # elif args.locale == 'zh':
    #     Chinese word embeddings
    else:
        # default to fastText
        vectors = [word_vectors.FastText(cache=args.embeddings, language=language)]

    if load_almond_embeddings and args.almond_type_embeddings:
        vectors.append(AlmondEmbeddings())
    return vectors
        self.dim = dim
@@ -0,0 +1,232 @@
#
# Copyright (c) 2018-2019, Salesforce, Inc.
#                          The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
import os
from collections import defaultdict
import logging
from transformers import AutoTokenizer, AutoModel, BertConfig
from typing import NamedTuple, List

from .numericalizer import SimpleNumericalizer, BertNumericalizer
from . import word_vectors
from .almond_embeddings import AlmondEmbeddings
from .pretrained_lstm_lm import PretrainedLTSMLM

_logger = logging.getLogger(__name__)


class EmbeddingOutput(NamedTuple):
    all_layers: List[torch.Tensor]
    last_layer: torch.Tensor


class WordVectorEmbedding(torch.nn.Module):
    def __init__(self, vec_collection):
        super().__init__()
        self._vec_collection = vec_collection
        self.dim = vec_collection.dim
        self.num_layers = 0
        self.embedding = None

    def init_for_vocab(self, vocab):
        vectors = torch.empty(len(vocab), self.dim, device=torch.device('cpu'))
        for ti, token in enumerate(vocab.itos):
            vectors[ti] = self._vec_collection[token.strip()]

        self.embedding = torch.nn.Embedding(len(vocab.itos), self.dim)
        self.embedding.weight.data = vectors

    def grow_for_vocab(self, vocab, new_words):
        if not new_words:
            return
        new_vectors = []
        for word in new_words:
            new_vector = self._vec_collection[word]

            # charNgram returns a [1, D] tensor, while GloVe returns a [D] tensor;
            # normalize everything to [1, D] so we can concatenate all vectors
            # along the first dimension
            new_vector = new_vector if new_vector.dim() > 1 else new_vector.unsqueeze(0)
            new_vectors.append(new_vector)

        self.embedding.weight.data = torch.cat([self.embedding.weight.data.cpu()] + new_vectors, dim=0)

    def forward(self, input: torch.Tensor, padding=None):
        last_layer = self.embedding(input.cpu()).to(input.device)
        return EmbeddingOutput(all_layers=[last_layer], last_layer=last_layer)

    def to(self, *args, **kwargs):
        # ignore attempts to move the word embedding, which should stay on CPU
        kwargs['device'] = torch.device('cpu')
        return super().to(*args, **kwargs)

    def cuda(self, device=None):
        # ignore attempts to move the word embedding
        pass
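The grow_for_vocab comment above glosses over why the unsqueeze is needed; a minimal sketch (illustration only, not part of the commit):

# Illustration: GloVe-style lookups return a [D] tensor while CharNGram-style
# lookups return [1, D]; normalizing to [1, D] lets new rows be appended to an
# [N, D] embedding matrix with a single torch.cat along dim 0.
import torch

matrix = torch.zeros(10, 300)           # existing [N, D] embedding matrix
new_vector = torch.zeros(300)           # a [D] lookup result
new_vector = new_vector.unsqueeze(0)    # -> [1, D]
matrix = torch.cat([matrix, new_vector], dim=0)  # -> [11, D]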

class BertEmbedding(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        model.config.output_hidden_states = True
        self.dim = model.config.hidden_size
        self.num_layers = model.config.num_hidden_layers
        self.model = model

    def init_for_vocab(self, vocab):
        self.model.resize_token_embeddings(len(vocab))

    def grow_for_vocab(self, vocab, new_words):
        self.model.resize_token_embeddings(len(vocab))

    def forward(self, input: torch.Tensor, padding=None):
        last_hidden_state, _pooled, hidden_states = self.model(input, attention_mask=(~padding).to(dtype=torch.float))

        return EmbeddingOutput(all_layers=hidden_states, last_layer=last_hidden_state)
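A usage sketch for BertEmbedding (illustration only, not part of the commit; it assumes a transformers version whose model forward returns (last_hidden_state, pooler_output, hidden_states) tuples, which is what the forward above unpacks):

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('bert-base-uncased')
embedding = BertEmbedding(model)
tokens = torch.randint(0, model.config.vocab_size, (2, 8))  # batch x time
padding = torch.zeros(2, 8, dtype=torch.bool)               # True where padded
output = embedding(tokens, padding=padding)
print(output.last_layer.shape)  # torch.Size([2, 8, 768])
print(len(output.all_layers))   # num_hidden_layers + 1 (embedding output included)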

class PretrainedLMEmbedding(torch.nn.Module):
    def __init__(self, model_name, cachedir):
        super().__init__()
        # map to CPU first, we will be moved to the right place later
        pretrained_save_dict = torch.load(os.path.join(cachedir, model_name), map_location=torch.device('cpu'))

        self.itos = pretrained_save_dict['vocab']
        self.stoi = defaultdict(lambda: 0, {
            w: i for i, w in enumerate(self.itos)
        })
        self.dim = pretrained_save_dict['settings']['nhid']
        self.num_layers = 1
        self.model = PretrainedLTSMLM(rnn_type=pretrained_save_dict['settings']['rnn_type'],
                                      ntoken=len(self.itos),
                                      emsize=pretrained_save_dict['settings']['emsize'],
                                      nhid=pretrained_save_dict['settings']['nhid'],
                                      nlayers=pretrained_save_dict['settings']['nlayers'],
                                      dropout=0.0)
        self.model.load_state_dict(pretrained_save_dict['model'], strict=True)

        self.vocab_to_pretrained = None

    def init_for_vocab(self, vocab):
        self.vocab_to_pretrained = torch.empty(len(vocab), dtype=torch.int64)

        unk_id = self.stoi['<unk>']
        for ti, token in enumerate(vocab.itos):
            if token in self.stoi:
                self.vocab_to_pretrained[ti] = self.stoi[token]
            else:
                self.vocab_to_pretrained[ti] = unk_id

    def grow_for_vocab(self, vocab, new_words):
        self.init_for_vocab(vocab)

    def forward(self, input: torch.Tensor, padding=None):
        pretrained_indices = torch.gather(self.vocab_to_pretrained, dim=0, index=input)
        rnn_output = self.model(pretrained_indices)
        return EmbeddingOutput(all_layers=[rnn_output], last_layer=rnn_output)
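The vocab_to_pretrained table maps this model's token ids onto the pretrained LM's ids; a toy sketch of the remapping idea (illustration only, not part of the commit; both vocabularies are hypothetical):

import torch

vocab_to_pretrained = torch.tensor([0, 5, 2, 7])  # our id -> pretrained id
batch = torch.tensor([[1, 3], [2, 0]])            # batch x time, our ids
print(vocab_to_pretrained[batch])                 # tensor([[5, 7], [2, 0]])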

def _is_bert(embedding_name):
    return embedding_name.startswith('bert-')


def _name_to_vector(emb_name, cachedir):
    if emb_name == 'glove':
        return WordVectorEmbedding(word_vectors.GloVe(cache=cachedir))
    elif emb_name == 'small_glove':
        return WordVectorEmbedding(word_vectors.GloVe(cache=cachedir, name="6B", dim=50))
    elif emb_name == 'char':
        return WordVectorEmbedding(word_vectors.CharNGram(cache=cachedir))
    elif emb_name == 'almond_type':
        return AlmondEmbeddings()
    elif emb_name.startswith('fasttext/'):
        # FIXME this should use the fasttext library
        return WordVectorEmbedding(word_vectors.FastText(cache=cachedir, language=emb_name[len('fasttext/'):]))
    elif emb_name.startswith('pretrained_lstm/'):
        return PretrainedLMEmbedding(emb_name[len('pretrained_lstm/'):], cachedir=cachedir)
    else:
        raise ValueError(f'Unrecognized embedding name {emb_name}')

def load_embeddings(cachedir, encoder_emb_names, decoder_emb_names, max_generative_vocab=50000, logger=_logger):
    logger.info(f'Getting pretrained word vectors and pretrained models')

    encoder_emb_names = encoder_emb_names.split('+')
    decoder_emb_names = decoder_emb_names.split('+')

    all_vectors = {}
    encoder_vectors = []
    decoder_vectors = []

    numericalizer = None
    for emb_name in encoder_emb_names:
        if not emb_name:
            continue
        if _is_bert(emb_name):
            if numericalizer is not None:
                raise ValueError('Cannot specify multiple BERT embeddings')

            config = BertConfig.from_pretrained(emb_name, cache_dir=cachedir)
            config.output_hidden_states = True
            numericalizer = BertNumericalizer(config, emb_name, max_generative_vocab=max_generative_vocab, cache=cachedir)

            # load the tokenizer once to ensure all files are downloaded
            AutoTokenizer.from_pretrained(emb_name, cache_dir=cachedir)

            encoder_vectors.append(BertEmbedding(AutoModel.from_pretrained(emb_name, config=config, cache_dir=cachedir)))
        else:
            if numericalizer is not None:
                logger.warning('Combining BERT embeddings with other pretrained embeddings is unlikely to work')

            if emb_name in all_vectors:
                encoder_vectors.append(all_vectors[emb_name])
            else:
                vec = _name_to_vector(emb_name, cachedir)
                all_vectors[emb_name] = vec
                encoder_vectors.append(vec)

    for emb_name in decoder_emb_names:
        if not emb_name:
            continue
        if _is_bert(emb_name):
            raise ValueError('BERT embeddings cannot be specified in the decoder')

        if emb_name in all_vectors:
            decoder_vectors.append(all_vectors[emb_name])
        else:
            vec = _name_to_vector(emb_name, cachedir)
            all_vectors[emb_name] = vec
            decoder_vectors.append(vec)

    if numericalizer is None:
        numericalizer = SimpleNumericalizer(max_generative_vocab=max_generative_vocab, pad_first=False)

    return numericalizer, encoder_vectors, decoder_vectors
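A sketch of how callers are expected to use the new load_embeddings() entry point (illustration only, not part of the commit; './.embeddings' is a hypothetical cache directory):

numericalizer, encoder_vectors, decoder_vectors = load_embeddings(
    './.embeddings',       # cachedir
    'bert-base-uncased',   # encoder: a bert-* name selects BertNumericalizer
    'glove+char')          # decoder: '+'-concatenated word vectors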
@@ -41,7 +41,8 @@ class BertNumericalizer(object):
    Numericalizer that uses BertTokenizer from huggingface's transformers library.
    """

    def __init__(self, pretrained_tokenizer, max_generative_vocab, cache=None, fix_length=None):
    def __init__(self, config, pretrained_tokenizer, max_generative_vocab, cache=None, fix_length=None):
        self.config = config
        self._pretrained_name = pretrained_tokenizer
        self.max_generative_vocab = max_generative_vocab
        self._cache = cache
@@ -50,13 +51,18 @@ class BertNumericalizer(object):

        self.fix_length = fix_length

    @property
    def vocab(self):
        return self._tokenizer

    @property
    def num_tokens(self):
        return self._tokenizer.vocab_size
        return len(self._tokenizer)

    def load(self, save_dir):
        self.config = BertConfig.from_pretrained(os.path.join(save_dir, 'bert-config.json'), cache_dir=self._cache)
        self._tokenizer = MaskedBertTokenizer.from_pretrained(save_dir, config=self.config, cache_dir=self._cache)
        # HACK we cannot save the tokenizer without this
        del self._tokenizer.init_kwargs['config']

        with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'r') as fp:
            self._decoder_words = [line.rstrip('\n') for line in fp]
@@ -64,15 +70,15 @@ class BertNumericalizer(object):
        self._init()

    def save(self, save_dir):
        self.config.save_pretrained(os.path.join(save_dir, 'bert-config.json'))
        self._tokenizer.save_pretrained(os.path.join(save_dir))
        self._tokenizer.save_pretrained(save_dir)
        with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'w') as fp:
            for word in self._decoder_words:
                fp.write(word + '\n')

    def build_vocab(self, vectors, vocab_fields, vocab_sets):
        self.config = BertConfig.from_pretrained(self._pretrained_name, cache_dir=self._cache)
    def build_vocab(self, vocab_fields, vocab_sets):
        self._tokenizer = MaskedBertTokenizer.from_pretrained(self._pretrained_name, config=self.config, cache_dir=self._cache)
        # HACK we cannot save the tokenizer without this
        del self._tokenizer.init_kwargs['config']

        # ensure that init, eos, unk and pad are set
        # this method has no effect if the tokens are already set according to the tokenizer class
@@ -83,19 +89,31 @@ class BertNumericalizer(object):
            'pad_token': '[PAD]'
        })

        # do a pass over all the answers in the dataset, and construct a counter of wordpieces
        # do a pass over all the data in the dataset
        # in this pass, we
        # 1) tokenize everything, to ensure we account for all added tokens
        # 2) we construct a counter of wordpieces in the answers, for the decoder vocabulary
        decoder_words = collections.Counter()
        for dataset in vocab_sets:
            for example in dataset:
                self._tokenizer.tokenize(example.context, example.context_word_mask)
                self._tokenizer.tokenize(example.question, example.question_word_mask)

                tokens = self._tokenizer.tokenize(example.answer, example.answer_word_mask)
                decoder_words.update(tokens)

        self._decoder_words = decoder_words.most_common(self.max_generative_vocab)
        self._decoder_words = [word for word, _freq in decoder_words.most_common(self.max_generative_vocab)]

        self._init()

    def grow_vocab(self, examples, vectors):
        # TODO
    def grow_vocab(self, examples):
        # do a pass over all the data in the dataset and tokenize everything
        # this will add any new tokens that are not to be converted into word-pieces
        for example in examples:
            self._tokenizer.tokenize(example.context, example.context_word_mask)
            self._tokenizer.tokenize(example.question, example.question_word_mask)

        # return no new words - BertEmbedding will resize the embedding regardless
        return []

    def _init(self):
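A toy sketch of the decoder-vocabulary construction described in the comments above (illustration only, not part of the commit):

import collections

decoder_words = collections.Counter()
decoder_words.update(['show', 'me', 'restaurant', '##s'])   # wordpieces of answer 1
decoder_words.update(['show', 'me', 'weather'])             # wordpieces of answer 2
print([w for w, _freq in decoder_words.most_common(3)])     # ['show', 'me', 'restaurant']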
@@ -128,9 +146,9 @@ class BertNumericalizer(object):
            wp_tokenized.append(self._tokenizer.tokenize(tokens, mask))

        if self.fix_length is None:
            max_len = max(len(x) for x in minibatch) + 2
            max_len = max(len(x) for x in wp_tokenized)
        else:
            max_len = self.fix_length + 2
            max_len = self.fix_length
        padded = []
        lengths = []
        numerical = []
@@ -148,7 +166,7 @@ class BertNumericalizer(object):
                [self.pad_token] * max(0, max_len - len(wp_tokens))

            padded.append(padded_example)
            lengths.append(len(padded_example) - max(0, max_len - len(wp_tokens)))
            lengths.append(len(wp_tokens) + 2)

            numerical.append(self._tokenizer.convert_tokens_to_ids(padded_example))
            decoder_numerical.append([decoder_vocab.encode(word) for word in padded_example])
@@ -157,7 +175,7 @@ class BertNumericalizer(object):
        numerical = torch.tensor(numerical, dtype=torch.int64, device=device)
        decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device)

        return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical)
        return SequentialField(length=length, value=numerical, limited=decoder_numerical)

    def decode(self, tensor):
        return self._tokenizer.convert_ids_to_tokens(tensor)
@@ -176,7 +194,10 @@ class BertNumericalizer(object):
            if token in (self.init_token, self.pad_token):
                continue
            if token.startswith('##'):
                tokens[-1] += token[2:]
                if len(tokens) == 0:
                    tokens.append(token[2:])
                else:
                    tokens[-1] += token[2:]
            else:
                tokens.append(token)
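A standalone sketch of the '##' merge logic above (illustration only, not part of the commit; '[CLS]' and '[PAD]' stand in for init_token and pad_token):

pieces = ['[CLS]', 'restaurant', '##s', 'nearby', '[SEP]']
tokens = []
for token in pieces:
    if token in ('[CLS]', '[PAD]'):
        continue
    if token.startswith('##'):
        if len(tokens) == 0:
            tokens.append(token[2:])
        else:
            tokens[-1] += token[2:]   # glue the continuation piece onto the previous word
    else:
        tokens.append(token)
print(tokens)  # ['restaurants', 'nearby', '[SEP]']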
@@ -54,10 +54,10 @@ class MaskedWordPieceTokenizer:

    def tokenize(self, tokens, mask):
        output_tokens = []
        for token, should_word_split in tokens, mask:
        for token, should_word_split in zip(tokens, mask):
            if not should_word_split:
                if token not in self.vocab and token not in self.added_tokens_encoder:
                    token_id = len(self.added_tokens_encoder)
                    token_id = len(self.vocab) + len(self.added_tokens_encoder)
                    self.added_tokens_encoder[token] = token_id
                    self.added_tokens_decoder[token_id] = token
                output_tokens.append(token)
@@ -95,11 +95,60 @@ class MaskedWordPieceTokenizer:
        return output_tokens


class IToSWrapper:
    """Wrap the ordered dict vocabs to look like a list int -> str"""

    def __init__(self, base_vocab, added_tokens):
        self.base_vocab = base_vocab
        self.added_tokens = added_tokens

    def __getitem__(self, key):
        if isinstance(key, slice):
            return [self[index] for index in range(key.start or 0, key.stop or len(self), key.step or 1)]

        if key < len(self.base_vocab):
            return self.base_vocab[key]
        else:
            return self.added_tokens[key]

    def __len__(self):
        return len(self.base_vocab) + len(self.added_tokens)

    def __iter__(self):
        for key in range(len(self.base_vocab)):
            yield self.base_vocab[key]
        for key in range(len(self.base_vocab), len(self.base_vocab) + len(self.added_tokens)):
            yield self.added_tokens[key]


class SToIWrapper:
    """Wrap the ordered dict vocabs to look like a single dictionary"""

    def __init__(self, base_vocab, added_tokens):
        self.base_vocab = base_vocab
        self.added_tokens = added_tokens

    def __getitem__(self, key):
        if key in self.base_vocab:
            return self.base_vocab[key]
        else:
            return self.added_tokens[key]

    def __len__(self):
        return len(self.base_vocab) + len(self.added_tokens)

    def __iter__(self):
        for key in self.base_vocab:
            yield key
        for key in self.added_tokens:
            yield key


class MaskedBertTokenizer(BertTokenizer):
    """
    A modified BertTokenizer that respects a mask deciding whether a token should be split or not.
    """
    def __init__(self, *args, do_lower_case, do_basic_tokenize, **kwargs):
    def __init__(self, *args, do_lower_case=False, do_basic_tokenize=False, **kwargs):
        # override do_lower_case and do_basic_tokenize unconditionally
        super().__init__(*args, do_lower_case=False, do_basic_tokenize=False, **kwargs)
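A toy sketch of the wrappers above, which present the base vocabulary and the added tokens as one contiguous id space (illustration only, not part of the commit; values are hypothetical):

base_itos = {0: '[PAD]', 1: 'show'}   # ids_to_tokens-style dict
added_itos = {2: '@QUOTED_STRING'}    # added_tokens_decoder-style dict
itos = IToSWrapper(base_itos, added_itos)
print(len(itos))   # 3
print(itos[2])     # '@QUOTED_STRING' (falls through to the added tokens)
print(itos[0:3])   # ['[PAD]', 'show', '@QUOTED_STRING']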
@@ -109,14 +158,21 @@ class MaskedBertTokenizer(BertTokenizer):
                                added_tokens_decoder=self.added_tokens_decoder,
                                unk_token=self.unk_token)

        self._itos = IToSWrapper(self.ids_to_tokens, self.added_tokens_decoder)
        self._stoi = SToIWrapper(self.vocab, self.added_tokens_encoder)

    def tokenize(self, tokens, mask=None):
        return self.wordpiece_tokenizer.tokenize(tokens, mask)

    # provide an interface that DecoderVocabulary can like
    # provide an interface similar to Vocab

    def __len__(self):
        return len(self.vocab) + len(self.added_tokens_encoder)

    @property
    def stoi(self):
        return self.vocab
        return self._stoi

    @property
    def itos(self):
        return self.ids_to_tokens
        return self._itos
@@ -34,5 +34,4 @@ from typing import NamedTuple, List
class SequentialField(NamedTuple):
    value: torch.tensor
    length: torch.tensor
    limited: torch.tensor
    tokens: List[List[str]]
    limited: torch.tensor
@@ -35,8 +35,7 @@ from .sequential_field import SequentialField
from .decoder_vocab import DecoderVocabulary

class SimpleNumericalizer(object):
    def __init__(self, max_effective_vocab, max_generative_vocab, fix_length=None, pad_first=False):
        self.max_effective_vocab = max_effective_vocab
    def __init__(self, max_generative_vocab, fix_length=None, pad_first=False):
        self.max_generative_vocab = max_generative_vocab

        self.init_token = '<init>'
@@ -58,17 +57,15 @@ class SimpleNumericalizer(object):
    def save(self, save_dir):
        torch.save(self.vocab, os.path.join(save_dir, 'vocab.pth'))

    def build_vocab(self, vectors, vocab_fields, vocab_sets):
    def build_vocab(self, vocab_fields, vocab_sets):
        self.vocab = Vocab.build_from_data(vocab_fields, *vocab_sets,
                                           unk_token=self.unk_token,
                                           init_token=self.init_token,
                                           eos_token=self.eos_token,
                                           pad_token=self.pad_token,
                                           max_size=self.max_effective_vocab,
                                           vectors=vectors)
                                           pad_token=self.pad_token)
        self._init_vocab()

    def _grow_vocab_one(self, sentence, vectors, new_vectors):
    def _grow_vocab_one(self, sentence, new_words):
        assert isinstance(sentence, list)

        # check if all the words are in the vocabulary, and if not
@@ -77,22 +74,15 @@ class SimpleNumericalizer(object):
            if word not in self.vocab.stoi:
                self.vocab.stoi[word] = len(self.vocab.itos)
                self.vocab.itos.append(word)
                new_words.append(word)

                new_vector = [vec[word] for vec in vectors]

                # charNgram returns a [1, D] tensor, while GloVe returns a [D] tensor
                # normalize to [1, D] so we can concat along the second dimension
                # and later concat all vectors along the first
                new_vector = [vec if vec.dim() > 1 else vec.unsqueeze(0) for vec in new_vector]
                new_vectors.append(torch.cat(new_vector, dim=1))

    def grow_vocab(self, examples, vectors):
        new_vectors = []
    def grow_vocab(self, examples):
        new_words = []
        for ex in examples:
            self._grow_vocab_one(ex.context, vectors, new_vectors)
            self._grow_vocab_one(ex.question, vectors, new_vectors)
            self._grow_vocab_one(ex.answer, vectors, new_vectors)
        return new_vectors
            self._grow_vocab_one(ex.context, new_words)
            self._grow_vocab_one(ex.question, new_words)
            self._grow_vocab_one(ex.answer, new_words)
        return new_words

    def _init_vocab(self):
        self.init_id = self.vocab.stoi[self.init_token]
@@ -113,7 +103,7 @@ class SimpleNumericalizer(object):
        if self.fix_length is None:
            max_len = max(len(x[0]) for x in minibatch)
        else:
            max_len = self.fix_length + 2
            max_len = self.fix_length
        padded = []
        lengths = []
        numerical = []
@@ -140,7 +130,7 @@ class SimpleNumericalizer(object):
        numerical = torch.tensor(numerical, dtype=torch.int64, device=device)
        decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device)

        return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical)
        return SequentialField(length=length, value=numerical, limited=decoder_numerical)

    def decode(self, tensor):
        return [self.vocab.itos[idx] for idx in tensor]
@@ -20,8 +20,7 @@ class Vocab(object):
            numerical identifiers.
        itos: A list of token strings indexed by their numerical identifiers.
    """
    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>',),
                 vectors=None, cat_vectors=True):
    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>',)):
        """Create a Vocab object from a collections.Counter.

        Arguments:
@@ -60,10 +59,6 @@ class Vocab(object):
            self.itos.append(word)
            self.stoi[word] = len(self.itos) - 1

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, cat=cat_vectors)

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
@@ -71,8 +66,6 @@ class Vocab(object):
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
@@ -85,54 +78,6 @@ class Vocab(object):
            self.itos.append(w)
            self.stoi[w] = len(self.itos) - 1

    def load_vectors(self, vectors, cat=True):
        """
        Arguments:
            vectors: one of or a list containing instantiations of the
                GloVe, CharNGram, or Vectors classes.
        """
        if not isinstance(vectors, list):
            vectors = [vectors]
        if cat:
            tot_dim = sum(v.dim for v in vectors)
            self.vectors = torch.Tensor(len(self), tot_dim)
            for ti, token in enumerate(self.itos):
                start_dim = 0
                for v in vectors:
                    end_dim = start_dim + v.dim
                    self.vectors[ti][start_dim:end_dim] = v[token.strip()]
                    start_dim = end_dim
                assert start_dim == tot_dim
        else:
            self.vectors = [torch.Tensor(len(self), v.dim) for v in vectors]
            for ti, t in enumerate(self.itos):
                for vi, v in enumerate(vectors):
                    self.vectors[vi][ti] = v[t.strip()]

    def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_):
        """
        Set the vectors for the Vocab instance from a collection of Tensors.

        Arguments:
            stoi: A dictionary of string to the index of the associated vector
                in the `vectors` input argument.
            vectors: An indexed iterable (or other structure supporting __getitem__) that
                given an input index, returns a FloatTensor representing the vector
                for the token associated with the index. For example,
                vector[stoi["string"]] should return the vector for "string".
            dim: The dimensionality of the vectors.
            unk_init (callback): by default, initialize out-of-vocabulary word vectors
                to zero vectors; can be any function that takes in a Tensor and
                returns a Tensor of the same size. Default: torch.Tensor.zero_
        """
        self.vectors = torch.Tensor(len(self), dim)
        for i, token in enumerate(self.itos):
            wv_index = stoi.get(token, None)
            if wv_index is not None:
                self.vectors[i] = vectors[wv_index]
            else:
                self.vectors[i] = unk_init(self.vectors[i])

    @staticmethod
    def build_from_data(field_names, *args, unk_token=None, pad_token=None, init_token=None, eos_token=None, **kwargs):
        """Construct the Vocab object for this field from one or more datasets.
@@ -0,0 +1,93 @@
# The following code was copied and adapted from github.com/floyhub/world-language-model
#
# BSD 3-Clause License
#
# Copyright (c) 2017,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from torch import nn

class PretrainedLTSMLM(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, emsize, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(PretrainedLTSMLM, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, emsize)  # Token2Embeddings
        if rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(nn, rnn_type)(emsize, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            except KeyError:
                raise ValueError("""An invalid option for `--model` was supplied,
                                 options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
            self.rnn = nn.RNN(emsize, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != emsize:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def encode(self, input, hidden=None):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        return output, hidden

    def forward(self, input, hidden=None):
        encoded, hidden = self.encode(input, hidden)
        decoded = self.decoder(encoded.view(encoded.size(0)*encoded.size(1), encoded.size(2)))
        return decoded.view(encoded.size(0), encoded.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
                    weight.new(self.nlayers, bsz, self.nhid).zero_())
        else:
            return weight.new(self.nlayers, bsz, self.nhid).zero_()
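A minimal instantiation sketch (illustration only, not part of the commit); note that tie_weights requires emsize == nhid, as enforced above:

lm = PretrainedLTSMLM(rnn_type='LSTM', ntoken=10000, emsize=400, nhid=400,
                      nlayers=2, dropout=0.5, tie_weights=True)
hidden = lm.init_hidden(bsz=16)  # (h, c) zero states for an LSTM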
@@ -27,4 +27,4 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from .multitask_question_answering_network import MultitaskQuestionAnsweringNetwork
from .general_seq2seq import Seq2Seq
@@ -37,18 +37,24 @@ import os
import sys
import numpy as np
import torch.nn as nn
from typing import NamedTuple, List

from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack


class EmbeddingOutput(NamedTuple):
    all_layers: List[torch.Tensor]
    last_layer: torch.Tensor


INF = 1e10
EPSILON = 1e-10

class LSTMDecoder(nn.Module):

class MultiLSTMCell(nn.Module):
    def __init__(self, num_layers, input_size, rnn_size, dropout):
        super(LSTMDecoder, self).__init__()
        super(MultiLSTMCell, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
@@ -304,6 +310,7 @@ class LinearFeedforward(nn.Module):
    def forward(self, x):
        return self.dropout(self.linear(self.feedforward(x)))


class PackedLSTM(nn.Module):

    def __init__(self, d_in, d_out, bidirectional=False, num_layers=1,
@@ -354,32 +361,23 @@ class Feedforward(nn.Module):
        return self.activation(self.linear(self.dropout(x)))


class Embedding(nn.Module):
class CombinedEmbedding(nn.Module):

    def __init__(self, numericalizer, output_dimension, include_pretrained=True, trained_dimension=0, dropout=0.0, project=True, requires_grad=False):
    def __init__(self, numericalizer, pretrained_embeddings,
                 output_dimension,
                 finetune_pretrained=False,
                 trained_dimension=0,
                 project=True):
        super().__init__()
        self.project = project
        self.requires_grad = requires_grad
        self.finetune_pretrained = finetune_pretrained
        self.pretrained_embeddings = tuple(pretrained_embeddings)

        dimension = 0
        pretrained_dimension = numericalizer.vocab.vectors.size(-1)
        for idx, embedding in enumerate(self.pretrained_embeddings):
            dimension += embedding.dim
            self.add_module('pretrained_' + str(idx), embedding)

        if include_pretrained:
            # NOTE: this must be a list so that pytorch will not iterate into the module when
            # traversing this module
            # in turn, this means that moving this Embedding() to the GPU will not move the
            # actual embedding, which will stay on CPU; this is necessary because a) we call
            # set_embeddings() sometimes with CPU-only tensors, and b) the embedding tensor
            # is too big for the GPU anyway
            self.pretrained_embeddings = [nn.Embedding(numericalizer.num_tokens, pretrained_dimension)]
            self.pretrained_embeddings[0].weight.data = numericalizer.vocab.vectors
            self.pretrained_embeddings[0].weight.requires_grad = self.requires_grad
            dimension += pretrained_dimension
        else:
            self.pretrained_embeddings = None

        # OTOH, if we have a trained embedding, we move it around together with the module
        # (i.e., potentially on GPU), because the saving when applying the gradient outweighs
        # the cost, and hopefully the embedding is small enough to fit in GPU memory
        if trained_dimension > 0:
            self.trained_embeddings = nn.Embedding(numericalizer.num_tokens, trained_dimension)
            dimension += trained_dimension
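A sketch of the PyTorch behavior the NOTE above relies on (illustration only, not part of the commit): only attributes that are nn.Module instances (or ModuleList) are registered as children, so a module hidden inside a plain Python list is invisible to .to()/.cuda() and stays on the CPU.

import torch.nn as nn

class Holder(nn.Module):
    def __init__(self):
        super().__init__()
        self.registered = nn.Embedding(10, 4)  # moves with the module
        self.hidden = [nn.Embedding(10, 4)]    # stays where it was created

holder = Holder()
print(list(holder.children()))  # only the registered embedding appears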
@@ -387,34 +385,53 @@ class Embedding(nn.Module):
            self.trained_embeddings = None
        if self.project:
            self.projection = Feedforward(dimension, output_dimension)
            self.dropout = nn.Dropout(dropout)
        else:
            assert dimension == output_dimension
        self.dimension = output_dimension

    def forward(self, x, lengths=None, device=-1):
    def _combine_embeddings(self, embeddings):
        if len(embeddings) == 1:
            all_layers = embeddings[0].all_layers
            last_layer = embeddings[0].last_layer
            if self.project:
                last_layer = self.projection(last_layer)
            return EmbeddingOutput(all_layers=all_layers, last_layer=last_layer)

        all_layers = None
        last_layer = []
        for emb in embeddings:
            if all_layers is None:
                all_layers = [[layer] for layer in emb.all_layers]
            elif len(all_layers) != len(emb.all_layers):
                raise ValueError('Cannot combine embeddings that use different numbers of layers')
            else:
                for layer_list, layer in zip(all_layers, emb.all_layers):
                    layer_list.append(layer)
            last_layer.append(emb.last_layer)

        all_layers = [torch.cat(layer, dim=2) for layer in all_layers]
        last_layer = torch.cat(last_layer, dim=2)
        if self.project:
            last_layer = self.projection(last_layer)
        return EmbeddingOutput(all_layers=all_layers, last_layer=last_layer)

    def forward(self, x, padding=None):
        embedded = []
        if self.pretrained_embeddings is not None:
            pretrained_embeddings = self.pretrained_embeddings[0](x.cpu()).to(x.device).detach()
        else:
            pretrained_embeddings = None
        if self.finetune_pretrained:
            embedded += [emb(x, padding=padding) for emb in self.pretrained_embeddings]
        else:
            with torch.no_grad():
                embedded += [emb(x, padding=padding) for emb in self.pretrained_embeddings]

        if self.trained_embeddings is not None:
            trained_vocabulary_size = self.trained_embeddings.weight.size()[0]
            valid_x = torch.lt(x, trained_vocabulary_size)
            masked_x = torch.where(valid_x, x, torch.zeros_like(x))
            trained_embeddings = self.trained_embeddings(masked_x)
        else:
            trained_embeddings = None
        if pretrained_embeddings is not None and trained_embeddings is not None:
            embeddings = torch.cat((pretrained_embeddings, trained_embeddings), dim=2)
        elif pretrained_embeddings is not None:
            embeddings = pretrained_embeddings
        else:
            embeddings = trained_embeddings
            output = self.trained_embeddings(masked_x)
            embedded.append(EmbeddingOutput(all_layers=[output], last_layer=output))

        return self.projection(embeddings) if self.project else embeddings

    def set_embeddings(self, w):
        if self.pretrained_embeddings is not None:
            self.pretrained_embeddings[0].weight.data = w
            self.pretrained_embeddings[0].weight.requires_grad = self.requires_grad
        return self._combine_embeddings(embedded)


class SemanticFusionUnit(nn.Module):
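A shape sketch of what _combine_embeddings does (illustration only, not part of the commit): per-layer outputs of several embedders are concatenated along the feature dimension.

import torch

a = torch.zeros(2, 8, 300)  # batch x time x dim from one embedder
b = torch.zeros(2, 8, 100)  # batch x time x dim from another
print(torch.cat([a, b], dim=2).shape)  # torch.Size([2, 8, 400])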
@@ -443,23 +460,38 @@ class LSTMDecoderAttention(nn.Module):
        self.dot = dot

    def applyMasks(self, context_mask):
        self.context_mask = context_mask
        # context_mask is batch x encoder_time, convert it to batch x 1 x encoder_time
        self.context_mask = context_mask.unsqueeze(1)

    def forward(self, input: torch.Tensor, context: torch.Tensor):
        # input is batch x decoder_time x dim
        # context is batch x encoder_time x dim
        # output will be batch x decoder_time x dim
        # context_attention will be batch x decoder_time x encoder_time

    def forward(self, input, context):
        if not self.dot:
            targetT = self.linear_in(input).unsqueeze(2)  # batch x dim x 1
            targetT = self.linear_in(input)  # batch x decoder_time x dim
        else:
            targetT = input.unsqueeze(2)
            targetT = input

        context_scores = torch.bmm(context, targetT).squeeze(2)
        x = input.shape
        transposed_context = torch.transpose(context, 2, 1)
        x = transposed_context.shape
        context_scores = torch.matmul(targetT, transposed_context)
        context_scores.masked_fill_(self.context_mask, -float('inf'))
        context_attention = F.softmax(context_scores, dim=-1) + EPSILON
        context_alignment = torch.bmm(context_attention.unsqueeze(1), context).squeeze(1)

        combined_representation = torch.cat([input, context_alignment], 1)
        # convert context_attention to batch x decoder_time x 1 x encoder_time
        # convert context to batch x 1 x encoder_time x dim
        # context_alignment will be batch x decoder_time x 1 x dim
        context_alignment = torch.matmul(context_attention.unsqueeze(2), context.unsqueeze(1))
        # squeeze out the extra dimension
        context_alignment = context_alignment.squeeze(2)

        combined_representation = torch.cat([input, context_alignment], 2)
        output = self.tanh(self.linear_out(combined_representation))

        return output, context_attention, context_alignment
        return output, context_attention


class CoattentiveLayer(nn.Module):
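A shape walk-through of the time-parallel attention above (illustration only, not part of the commit; toy sizes batch=2, decoder_time=3, encoder_time=5, dim=4):

import torch
import torch.nn.functional as F

decoder_states = torch.randn(2, 3, 4)  # batch x decoder_time x dim
context = torch.randn(2, 5, 4)         # batch x encoder_time x dim
scores = torch.matmul(decoder_states, context.transpose(2, 1))  # batch x decoder_time x encoder_time
attention = F.softmax(scores, dim=-1)
alignment = torch.matmul(attention.unsqueeze(2), context.unsqueeze(1)).squeeze(2)
print(alignment.shape)  # torch.Size([2, 3, 4]) - one context summary per decoder step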
@@ -471,8 +503,8 @@ class CoattentiveLayer(nn.Module):
        self.dropout = nn.Dropout(dropout)

    def forward(self, context, question, context_padding, question_padding):
        context_padding = torch.cat([context.new_zeros((context.size(0), 1), dtype=torch.long)==1, context_padding], 1)
        question_padding = torch.cat([question.new_zeros((question.size(0), 1), dtype=torch.long)==1, question_padding], 1)
        context_padding = torch.cat([context.new_zeros((context.size(0), 1), dtype=torch.bool), context_padding], 1)
        question_padding = torch.cat([question.new_zeros((question.size(0), 1), dtype=torch.bool), question_padding], 1)

        context_sentinel = self.embed_sentinel(context.new_zeros((context.size(0), 1), dtype=torch.long))
        context = torch.cat([context_sentinel, self.dropout(context)], 1)  # batch_size x (context_length + 1) x features
@@ -503,94 +535,3 @@ class CoattentiveLayer(nn.Module):
        return F.softmax(raw_scores, dim=1)


# The following code was copied and adapted from github.com/floyhub/world-language-model
#
# BSD 3-Clause License
#
# Copyright (c) 2017,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

class PretrainedDecoderLM(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, rnn_type, ntoken, emsize, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(PretrainedDecoderLM, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, emsize)  # Token2Embeddings
        if rnn_type in ['LSTM', 'GRU']:
            self.rnn = getattr(nn, rnn_type)(emsize, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            except KeyError:
                raise ValueError("""An invalid option for `--model` was supplied,
                                 options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
            self.rnn = nn.RNN(emsize, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != emsize:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def encode(self, input, hidden=None):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        return output, hidden

    def forward(self, input, hidden=None):
        encoded, hidden = self.encode(input, hidden)
        decoded = self.decoder(encoded.view(encoded.size(0)*encoded.size(1), encoded.size(2)))
        return decoded.view(encoded.size(0), encoded.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
                    weight.new(self.nlayers, bsz, self.nhid).zero_())
        else:
            return weight.new(self.nlayers, bsz, self.nhid).zero_()
@ -0,0 +1,59 @@
|
|||
#
# Copyright (c) 2018, Salesforce, Inc.
#                     The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from torch import nn

from .mqan_encoder import MQANEncoder
from .identity_encoder import IdentityEncoder
from .mqan_decoder import MQANDecoder

ENCODERS = {
    'MQANEncoder': MQANEncoder,
    'Identity': IdentityEncoder,
}
DECODERS = {
    'MQANDecoder': MQANDecoder,
}


class Seq2Seq(nn.Module):
    def __init__(self, numericalizer, args, encoder_embeddings, decoder_embeddings):
        super().__init__()
        self.args = args

        self.encoder = ENCODERS[args.seq2seq_encoder](numericalizer, args, encoder_embeddings)
        self.decoder = DECODERS[args.seq2seq_decoder](numericalizer, args, decoder_embeddings)

    def forward(self, batch, iteration):
        self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state = self.encoder(batch)

        loss, predictions = self.decoder(batch, self_attended_context, final_context, context_rnn_state,
                                         final_question, question_rnn_state)

        return loss, predictions
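The ENCODERS/DECODERS dictionaries above are a plain string-to-class registry, so the --seq2seq_encoder and --seq2seq_decoder flags select components by name. A self-contained sketch of the pattern, with a toy module standing in for the real encoders (an assumption for illustration, not part of this commit):

import torch
from torch import nn

class ToyEncoder(nn.Module):
    """Stand-in for MQANEncoder/IdentityEncoder; just a linear projection."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

ENCODERS = {'Toy': ToyEncoder}        # flag value -> class
encoder = ENCODERS['Toy'](dim=16)     # instantiate from the CLI choice
print(encoder(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])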
@@ -0,0 +1,68 @@
#
# Copyright (c) 2018, The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# [BSD 3-Clause license text identical to the header above]

from torch import nn

from .common import CombinedEmbedding


class IdentityEncoder(nn.Module):
    def __init__(self, numericalizer, args, encoder_embeddings):
        super().__init__()
        self.args = args
        self.pad_idx = numericalizer.pad_id

        if sum(emb.dim for emb in encoder_embeddings) != args.dimension:
            raise ValueError('Hidden dimension must be equal to the sum of the embedding sizes to use IdentityEncoder')
        if args.rnn_layers > 0:
            raise ValueError('Cannot have RNN layers with IdentityEncoder')

        self.encoder_embeddings = CombinedEmbedding(numericalizer, encoder_embeddings, args.dimension,
                                                    trained_dimension=0,
                                                    project=False,
                                                    finetune_pretrained=args.train_encoder_embeddings)

    def forward(self, batch):
        context, context_lengths = batch.context.value, batch.context.length
        question, question_lengths = batch.question.value, batch.question.length

        context_padding = context.data == self.pad_idx
        question_padding = question.data == self.pad_idx

        context_embedded = self.encoder_embeddings(context, padding=context_padding)
        question_embedded = self.encoder_embeddings(question, padding=question_padding)

        # pick the top-most N transformer layers to pass to the decoder for cross-attention
        # (add 1 to account for the embedding layer - the decoder will drop it later)
        self_attended_context = context_embedded.all_layers[-(self.args.transformer_layers + 1):]
        final_context = context_embedded.last_layer
        final_question = question_embedded.last_layer
        context_rnn_state = None
        question_rnn_state = None

        return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state
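Two details above are easy to get backwards: the padding mask marks exactly the positions equal to pad_id, and a negative slice keeps the top-most layers of a per-layer list. A standalone illustration of both:

import torch

pad_idx = 0
context = torch.tensor([[5, 9, 2, 0, 0],
                        [7, 3, 0, 0, 0]])
print(context == pad_idx)   # True exactly at the padding positions

# Keeping the top-most N+1 entries of a per-layer list with a negative slice:
all_layers = ['emb', 'layer1', 'layer2', 'layer3', 'layer4']
N = 2
print(all_layers[-(N + 1):])  # ['layer2', 'layer3', 'layer4']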
@@ -0,0 +1,277 @@
#
# Copyright (c) 2018, Salesforce, Inc.
#                     The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# [BSD 3-Clause license text identical to the header above]

from .common import *


class MQANDecoder(nn.Module):
    def __init__(self, numericalizer, args, decoder_embeddings):
        super().__init__()
        self.numericalizer = numericalizer
        self.pad_idx = numericalizer.pad_id
        self.init_idx = numericalizer.init_id
        self.args = args

        self.decoder_embeddings = CombinedEmbedding(numericalizer, decoder_embeddings, args.dimension,
                                                    trained_dimension=args.trainable_decoder_embeddings,
                                                    project=True,
                                                    finetune_pretrained=False)

        self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads,
                                                         args.transformer_hidden,
                                                         args.transformer_layers,
                                                         args.dropout_ratio)

        if args.rnn_layers > 0:
            self.rnn_decoder = LSTMDecoder(args.dimension, args.dimension,
                                           dropout=args.dropout_ratio, num_layers=args.rnn_layers)
            switch_input_len = 3 * args.dimension
        else:
            self.context_attn = LSTMDecoderAttention(args.dimension, dot=True)
            self.question_attn = LSTMDecoderAttention(args.dimension, dot=True)
            self.dropout = nn.Dropout(args.dropout_ratio)
            switch_input_len = 2 * args.dimension
        self.vocab_pointer_switch = nn.Sequential(Feedforward(switch_input_len, 1), nn.Sigmoid())
        self.context_question_switch = nn.Sequential(Feedforward(switch_input_len, 1), nn.Sigmoid())

        self.generative_vocab_size = numericalizer.generative_vocab_size
        self.out = nn.Linear(args.dimension, self.generative_vocab_size)

    def set_embeddings(self, embeddings):
        if self.decoder_embeddings is not None:
            self.decoder_embeddings.set_embeddings(embeddings)

    def forward(self, batch, self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state):
        context, context_lengths, context_limited = batch.context.value, batch.context.length, batch.context.limited
        question, question_lengths, question_limited = batch.question.value, batch.question.length, batch.question.limited
        answer, answer_lengths, answer_limited = batch.answer.value, batch.answer.length, batch.answer.limited
        decoder_vocab = batch.decoder_vocab

        self.map_to_full = decoder_vocab.decode

        context_indices = context_limited if context_limited is not None else context
        question_indices = question_limited if question_limited is not None else question
        answer_indices = answer_limited if answer_limited is not None else answer

        context_padding = context_indices.data == self.pad_idx
        question_padding = question_indices.data == self.pad_idx

        if self.args.rnn_layers > 0:
            self.rnn_decoder.applyMasks(context_padding, question_padding)
        else:
            self.context_attn.applyMasks(context_padding)
            self.question_attn.applyMasks(question_padding)

        if self.training:
            answer_padding = (answer_indices.data == self.pad_idx)[:, :-1]

            answer_embedded = self.decoder_embeddings(answer[:, :-1], padding=answer_padding).last_layer
            self_attended_decoded = self.self_attentive_decoder(answer_embedded,
                                                                self_attended_context,
                                                                context_padding=context_padding,
                                                                answer_padding=answer_padding,
                                                                positional_encodings=True)

            if self.args.rnn_layers > 0:
                rnn_decoder_outputs = self.rnn_decoder(self_attended_decoded, final_context, final_question,
                                                       hidden=context_rnn_state)
                decoder_output, vocab_pointer_switch_input, context_question_switch_input, context_attention, \
                    question_attention, rnn_state = rnn_decoder_outputs
            else:
                context_decoder_output, context_attention = self.context_attn(self_attended_decoded, final_context)
                question_decoder_output, question_attention = self.question_attn(self_attended_decoded, final_question)

                vocab_pointer_switch_input = torch.cat((context_decoder_output, self_attended_decoded), dim=-1)
                context_question_switch_input = torch.cat((question_decoder_output, self_attended_decoded), dim=-1)

                decoder_output = self.dropout(context_decoder_output)

            vocab_pointer_switch = self.vocab_pointer_switch(vocab_pointer_switch_input)
            context_question_switch = self.context_question_switch(context_question_switch_input)

            probs = self.probs(self.out, decoder_output, vocab_pointer_switch, context_question_switch,
                               context_attention, question_attention,
                               context_indices, question_indices,
                               decoder_vocab)

            probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=self.pad_idx)
            loss = F.nll_loss(probs.log(), targets)
            return loss, None

        else:
            return None, self.greedy(self_attended_context, final_context, final_question,
                                     context_indices, question_indices,
                                     decoder_vocab, rnn_state=context_rnn_state).data

    def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
              context_attention, question_attention,
              context_indices, question_indices,
              decoder_vocab):

        size = list(outputs.size())

        size[-1] = self.generative_vocab_size
        scores = generator(outputs.view(-1, outputs.size(-1))).view(size)
        p_vocab = F.softmax(scores, dim=scores.dim() - 1)
        scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab

        effective_vocab_size = len(decoder_vocab)
        if self.generative_vocab_size < effective_vocab_size:
            size[-1] = effective_vocab_size - self.generative_vocab_size
            buff = scaled_p_vocab.new_full(size, EPSILON)
            scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim() - 1)

        # p_context_ptr
        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1,
                                    context_indices.unsqueeze(1).expand_as(context_attention),
                                    (context_question_switches * (1 - vocab_pointer_switches)).expand_as(
                                        context_attention) * context_attention)

        # p_question_ptr
        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1,
                                    question_indices.unsqueeze(1).expand_as(question_attention),
                                    ((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as(
                                        question_attention) * question_attention)

        return scaled_p_vocab
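probs() mixes a generative softmax (weighted by the vocab-pointer switch) with copy distributions over the context and question, scattered into an extended vocabulary so copied out-of-vocabulary tokens get probability mass. A self-contained toy version of that mixture, collapsing the two copy distributions into one for brevity (all shapes are assumptions of the sketch: batch 1, one time step, generative vocab 4, extended vocab 6, context length 3):

import torch
import torch.nn.functional as F

EPSILON = 1e-8  # stands in for the small constant used in the real code
p_vocab = F.softmax(torch.randn(1, 1, 4), dim=-1)
switch = torch.sigmoid(torch.randn(1, 1, 1))                 # vocab-pointer switch
p = switch * p_vocab                                         # generative share
p = torch.cat([p, p.new_full((1, 1, 2), EPSILON)], dim=-1)   # room for copy-only ids
attn = F.softmax(torch.randn(1, 1, 3), dim=-1)               # attention over context
ctx_idx = torch.tensor([[[1, 4, 5]]])                        # context ids in extended vocab
p.scatter_add_(-1, ctx_idx, (1 - switch).expand_as(attn) * attn)
print(p.sum(-1))  # ~1.0: a valid distribution over the extended vocabulary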
    def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab,
               rnn_state=None):
        B, TC, C = context.size()
        T = self.args.max_output_length
        outs = context.new_full((B, T), self.pad_idx, dtype=torch.long)
        hiddens = [self_attended_context[0].new_zeros((B, T, C))
                   for l in range(len(self.self_attentive_decoder.layers) + 1)]
        hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
        eos_yet = context.new_zeros((B,)).byte()

        decoder_output = None
        for t in range(T):
            if t == 0:
                init_token = self_attended_context[-1].new_full((B, 1), self.init_idx,
                                                                dtype=torch.long)
                embedding = self.decoder_embeddings(init_token).last_layer
            else:
                current_token_id = outs[:, t - 1].unsqueeze(1)
                embedding = self.decoder_embeddings(current_token_id).last_layer

            hiddens[0][:, t] = hiddens[0][:, t] + \
                (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1)
            for l in range(len(self.self_attentive_decoder.layers)):
                hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward(
                    self.self_attentive_decoder.layers[l].attention(
                        self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1],
                                                                       hiddens[l][:, :t + 1]),
                        self_attended_context[l], self_attended_context[l]))

            self_attended_decoded = hiddens[-1][:, t].unsqueeze(1)
            if self.args.rnn_layers > 0:
                rnn_decoder_outputs = self.rnn_decoder(self_attended_decoded, context, question,
                                                       hidden=rnn_state, output=decoder_output)
                decoder_output, vocab_pointer_switch_input, context_question_switch_input, context_attention, \
                    question_attention, rnn_state = rnn_decoder_outputs
            else:
                context_decoder_output, context_attention = self.context_attn(self_attended_decoded, context)
                question_decoder_output, question_attention = self.question_attn(self_attended_decoded, question)

                vocab_pointer_switch_input = torch.cat((context_decoder_output, self_attended_decoded), dim=-1)
                context_question_switch_input = torch.cat((question_decoder_output, self_attended_decoded), dim=-1)

                decoder_output = self.dropout(context_decoder_output)

            vocab_pointer_switch = self.vocab_pointer_switch(vocab_pointer_switch_input)
            context_question_switch = self.context_question_switch(context_question_switch_input)

            probs = self.probs(self.out, decoder_output, vocab_pointer_switch, context_question_switch,
                               context_attention, question_attention,
                               context_indices, question_indices, decoder_vocab)
            pred_probs, preds = probs.max(-1)
            preds = preds.squeeze(1)
            eos_yet = eos_yet | (preds == self.numericalizer.eos_id).byte()
            outs[:, t] = preds.cpu().apply_(self.map_to_full)
            if eos_yet.all():
                break
        return outs
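The loop above is the classic greedy decode: feed back the previous prediction, mark which batch elements have emitted EOS, and stop once all have. A stripped-down, runnable skeleton of just that control flow (a random scorer stands in for the real decoder; B, T, V and eos_id are toy values):

import torch

B, T, V, eos_id = 2, 5, 7, 1
outs = torch.zeros(B, T, dtype=torch.long)
eos_yet = torch.zeros(B, dtype=torch.bool)
for t in range(T):
    scores = torch.randn(B, V)          # stand-in for one decoder step
    preds = scores.argmax(-1)
    eos_yet |= preds == eos_id          # remember sequences that finished
    outs[:, t] = preds
    if eos_yet.all():
        break
print(outs)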
class LSTMDecoder(nn.Module):
    def __init__(self, d_in, d_hid, dropout=0.0, num_layers=1):
        super().__init__()
        self.d_hid = d_hid
        self.d_in = d_in
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)

        self.input_feed = True
        if self.input_feed:
            d_in += 1 * d_hid

        self.rnn = MultiLSTMCell(self.num_layers, d_in, d_hid, dropout)
        self.context_attn = LSTMDecoderAttention(d_hid, dot=True)
        self.question_attn = LSTMDecoderAttention(d_hid, dot=True)

    def applyMasks(self, context_mask, question_mask):
        self.context_attn.applyMasks(context_mask)
        self.question_attn.applyMasks(question_mask)

    def forward(self, input: torch.Tensor, context, question, output=None, hidden=None):
        context_output = output if output is not None else self.make_init_output(context)

        context_outputs, vocab_pointer_switch_inputs, context_question_switch_inputs, \
            context_attentions, question_attentions = [], [], [], [], []
        for decoder_input in input.split(1, dim=1):
            context_output = self.dropout(context_output)
            if self.input_feed:
                rnn_input = torch.cat([decoder_input, context_output], 2)
            else:
                rnn_input = decoder_input

            rnn_input = rnn_input.squeeze(1)
            dec_state, hidden = self.rnn(rnn_input, hidden)
            dec_state = dec_state.unsqueeze(1)

            context_output, context_attention = self.context_attn(dec_state, context)
            question_output, question_attention = self.question_attn(dec_state, question)
            vocab_pointer_switch_inputs.append(torch.cat([dec_state, context_output, decoder_input], -1))
            context_question_switch_inputs.append(torch.cat([dec_state, question_output, decoder_input], -1))

            context_output = self.dropout(context_output)
            context_outputs.append(context_output)
            context_attentions.append(context_attention)
            question_attentions.append(question_attention)

        return [torch.cat(x, dim=1) for x in (context_outputs,
                                              vocab_pointer_switch_inputs,
                                              context_question_switch_inputs,
                                              context_attentions,
                                              question_attentions)] + [hidden]

    def make_init_output(self, context):
        batch_size = context.size(0)
        h_size = (batch_size, 1, self.d_hid)
        return context.new_zeros(h_size)
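The input_feed flag above implements Luong-style input feeding: the previous step's attentional output is concatenated onto the current token embedding before the recurrent step. A minimal runnable sketch with a plain nn.LSTMCell (toy sizes assumed; the hidden state itself stands in for the attention output):

import torch
from torch import nn

d_in, d_hid, B = 8, 8, 2
cell = nn.LSTMCell(d_in + d_hid, d_hid)   # input grows by d_hid for input feeding
h = c = torch.zeros(B, d_hid)
context_output = torch.zeros(B, d_hid)    # initial attentional output
for emb_t in torch.randn(B, 3, d_in).split(1, dim=1):
    rnn_input = torch.cat([emb_t.squeeze(1), context_output], dim=1)
    h, c = cell(rnn_input, (h, c))
    context_output = h                    # real code: attention over the context
print(context_output.shape)  # torch.Size([2, 8])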
@@ -0,0 +1,108 @@
#
# Copyright (c) 2018, Salesforce, Inc.
#                     The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# [BSD 3-Clause license text identical to the header above]

from .common import *


class MQANEncoder(nn.Module):
    def __init__(self, numericalizer, args, encoder_embeddings):
        super().__init__()
        self.args = args
        self.pad_idx = numericalizer.pad_id

        self.encoder_embeddings = CombinedEmbedding(numericalizer, encoder_embeddings, args.dimension,
                                                    trained_dimension=0,
                                                    project=True,
                                                    finetune_pretrained=args.train_encoder_embeddings)

        def dp(args):
            return args.dropout_ratio if args.rnn_layers > 1 else 0.

        self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
                                                    batch_first=True, bidirectional=True, num_layers=1, dropout=0)
        self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
        dim = 2 * args.dimension + args.dimension + args.dimension

        self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
                                                           batch_first=True, dropout=dp(args), bidirectional=True,
                                                           num_layers=args.rnn_layers)
        self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads,
                                                                 args.transformer_hidden, args.transformer_layers,
                                                                 args.dropout_ratio)
        self.bilstm_context = PackedLSTM(args.dimension, args.dimension,
                                         batch_first=True, dropout=dp(args), bidirectional=True,
                                         num_layers=args.rnn_layers)

        self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
                                                            batch_first=True, dropout=dp(args), bidirectional=True,
                                                            num_layers=args.rnn_layers)
        self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads,
                                                                  args.transformer_hidden, args.transformer_layers,
                                                                  args.dropout_ratio)
        self.bilstm_question = PackedLSTM(args.dimension, args.dimension,
                                          batch_first=True, dropout=dp(args), bidirectional=True,
                                          num_layers=args.rnn_layers)

    def forward(self, batch):
        context, context_lengths = batch.context.value, batch.context.length
        question, question_lengths = batch.question.value, batch.question.length

        context_padding = context.data == self.pad_idx
        question_padding = question.data == self.pad_idx

        context_embedded = self.encoder_embeddings(context, padding=context_padding).last_layer
        question_embedded = self.encoder_embeddings(question, padding=question_padding).last_layer

        context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
        question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]

        coattended_context, coattended_question = self.coattention(context_encoded, question_encoded,
                                                                   context_padding, question_padding)

        context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1)
        condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths)
        self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding)
        final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1],
                                                                            context_lengths)
        context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)]

        question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1)
        condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths)
        self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding)
        final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1],
                                                                                question_lengths)
        question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)]

        return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state

    def reshape_rnn_state(self, h):
        return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
                .transpose(1, 2).contiguous() \
                .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous()
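reshape_rnn_state folds the forward/backward directions of the bidirectional LSTM state into the feature dimension, turning (layers*2, B, H) into (layers, B, 2H) while keeping the two directions of each layer adjacent. A standalone check of the same reshape:

import torch

layers, B, H = 2, 3, 4
h = torch.randn(layers * 2, B, H)  # what a bidirectional LSTM returns
merged = (h.view(layers, 2, B, H).transpose(1, 2).contiguous()
           .view(layers, B, 2 * H))
print(merged.shape)  # torch.Size([2, 3, 8])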
@@ -1,420 +0,0 @@
#
# Copyright (c) 2018, Salesforce, Inc.
#                     The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# [BSD 3-Clause license text identical to the header above]

from collections import defaultdict

from ..util import get_trainable_params

from .common import *

class MQANEncoder(nn.Module):
    def __init__(self, numericalizer, args):
        super().__init__()
        self.args = args
        self.pad_idx = numericalizer.pad_id

        if self.args.glove_and_char:
            self.encoder_embeddings = Embedding(numericalizer, args.dimension,
                                                trained_dimension=0,
                                                dropout=args.dropout_ratio,
                                                project=True,
                                                requires_grad=args.retrain_encoder_embedding)

        def dp(args):
            return args.dropout_ratio if args.rnn_layers > 1 else 0.

        self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
                                                    batch_first=True, bidirectional=True, num_layers=1, dropout=0)
        self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
        dim = 2 * args.dimension + args.dimension + args.dimension

        self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
                                                           batch_first=True, dropout=dp(args), bidirectional=True,
                                                           num_layers=args.rnn_layers)
        self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads,
                                                                 args.transformer_hidden, args.transformer_layers,
                                                                 args.dropout_ratio)
        self.bilstm_context = PackedLSTM(args.dimension, args.dimension,
                                         batch_first=True, dropout=dp(args), bidirectional=True,
                                         num_layers=args.rnn_layers)

        self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
                                                            batch_first=True, dropout=dp(args), bidirectional=True,
                                                            num_layers=args.rnn_layers)
        self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads,
                                                                  args.transformer_hidden, args.transformer_layers,
                                                                  args.dropout_ratio)
        self.bilstm_question = PackedLSTM(args.dimension, args.dimension,
                                          batch_first=True, dropout=dp(args), bidirectional=True,
                                          num_layers=args.rnn_layers)

    def set_embeddings(self, embeddings):
        self.encoder_embeddings.set_embeddings(embeddings)

    def forward(self, batch):
        context, context_lengths = batch.context.value, batch.context.length
        question, question_lengths = batch.question.value, batch.question.length

        context_embedded = self.encoder_embeddings(context)
        question_embedded = self.encoder_embeddings(question)

        context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
        question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]

        context_padding = context.data == self.pad_idx
        question_padding = question.data == self.pad_idx

        coattended_context, coattended_question = self.coattention(context_encoded, question_encoded,
                                                                   context_padding, question_padding)

        context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1)
        condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths)
        self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding)
        final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1],
                                                                            context_lengths)
        context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)]

        question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1)
        condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths)
        self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding)
        final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1],
                                                                                question_lengths)
        question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)]

        return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state

    def reshape_rnn_state(self, h):
        return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
                .transpose(1, 2).contiguous() \
                .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous()


class MQANDecoder(nn.Module):
    def __init__(self, numericalizer, args, devices):
        super().__init__()
        self.numericalizer = numericalizer
        self.pad_idx = numericalizer.pad_id
        self.init_idx = numericalizer.init_id
        self.args = args
        self.devices = devices

        if args.pretrained_decoder_lm:
            pretrained_save_dict = torch.load(os.path.join(args.embeddings, args.pretrained_decoder_lm), map_location=devices[0])

            self.pretrained_decoder_vocab_itos = pretrained_save_dict['vocab']
            self.pretrained_decoder_vocab_stoi = defaultdict(lambda: 0, {
                w: i for i, w in enumerate(self.pretrained_decoder_vocab_itos)
            })
            self.pretrained_decoder_embeddings = PretrainedDecoderLM(rnn_type=pretrained_save_dict['settings']['rnn_type'],
                                                                     ntoken=len(self.pretrained_decoder_vocab_itos),
                                                                     emsize=pretrained_save_dict['settings']['emsize'],
                                                                     nhid=pretrained_save_dict['settings']['nhid'],
                                                                     nlayers=pretrained_save_dict['settings']['nlayers'],
                                                                     dropout=0.0)
            self.pretrained_decoder_embeddings.load_state_dict(pretrained_save_dict['model'], strict=True)
            pretrained_lm_params = get_trainable_params(self.pretrained_decoder_embeddings)
            for p in pretrained_lm_params:
                p.requires_grad = False

            if self.pretrained_decoder_embeddings.nhid != args.dimension:
                self.pretrained_decoder_embedding_projection = Feedforward(self.pretrained_decoder_embeddings.nhid,
                                                                           args.dimension)
            else:
                self.pretrained_decoder_embedding_projection = None
            self.decoder_embeddings = None
        else:
            self.pretrained_decoder_vocab_itos = None
            self.pretrained_decoder_vocab_stoi = None
            self.pretrained_decoder_embeddings = None
            self.decoder_embeddings = Embedding(self.numericalizer, args.dimension,
                                                include_pretrained=args.glove_decoder,
                                                trained_dimension=args.trainable_decoder_embedding,
                                                dropout=args.dropout_ratio, project=True)

        self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
        self.dual_ptr_rnn_decoder = DualPtrRNNDecoder(args.dimension, args.dimension,
                                                      dropout=args.dropout_ratio, num_layers=args.rnn_layers)

        self.generative_vocab_size = numericalizer.generative_vocab_size
        self.out = nn.Linear(args.dimension, self.generative_vocab_size)

    def set_embeddings(self, embeddings):
        if self.decoder_embeddings is not None:
            self.decoder_embeddings.set_embeddings(embeddings)

    def forward(self, batch, self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state):
        context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
        question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens
        answer, answer_lengths, answer_limited, answer_tokens = batch.answer.value, batch.answer.length, batch.answer.limited, batch.answer.tokens
        decoder_vocab = batch.decoder_vocab

        self.map_to_full = decoder_vocab.decode

        context_indices = context_limited if context_limited is not None else context
        question_indices = question_limited if question_limited is not None else question
        answer_indices = answer_limited if answer_limited is not None else answer

        context_padding = context_indices.data == self.pad_idx
        question_padding = question_indices.data == self.pad_idx

        self.dual_ptr_rnn_decoder.applyMasks(context_padding, question_padding)

        if self.training:
            answer_padding = (answer_indices.data == self.pad_idx)[:, :-1]

            if self.args.pretrained_decoder_lm:
                # note that pretrained_decoder_embeddings is time first
                answer_pretrained_numerical = [
                    [self.pretrained_decoder_vocab_stoi[sentence[time]] for sentence in answer_tokens] for time in
                    range(len(answer_tokens[0]))
                ]
                answer_pretrained_numerical = torch.tensor(answer_pretrained_numerical, dtype=torch.long)

                with torch.no_grad():
                    answer_embedded, _ = self.pretrained_decoder_embeddings.encode(answer_pretrained_numerical)
                    answer_embedded.transpose_(0, 1)

                if self.pretrained_decoder_embedding_projection is not None:
                    answer_embedded = self.pretrained_decoder_embedding_projection(answer_embedded)
            else:
                answer_embedded = self.decoder_embeddings(answer)
            self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(),
                                                                self_attended_context, context_padding=context_padding,
                                                                answer_padding=answer_padding,
                                                                positional_encodings=True)
            decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded,
                                                        final_context, final_question, hidden=context_rnn_state)
            rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs

            probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
                               context_attention, question_attention,
                               context_indices, question_indices,
                               decoder_vocab)

            probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=self.pad_idx)
            loss = F.nll_loss(probs.log(), targets)
            return loss, None

        else:
            return None, self.greedy(self_attended_context, final_context, final_question,
                                     context_indices, question_indices,
                                     decoder_vocab, rnn_state=context_rnn_state).data

    def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
              context_attention, question_attention,
              context_indices, question_indices,
              decoder_vocab):

        size = list(outputs.size())

        size[-1] = self.generative_vocab_size
        scores = generator(outputs.view(-1, outputs.size(-1))).view(size)
        p_vocab = F.softmax(scores, dim=scores.dim() - 1)
        scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab

        effective_vocab_size = len(decoder_vocab)
        if self.generative_vocab_size < effective_vocab_size:
            size[-1] = effective_vocab_size - self.generative_vocab_size
            buff = scaled_p_vocab.new_full(size, EPSILON)
            scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim() - 1)

        # p_context_ptr
        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1,
                                    context_indices.unsqueeze(1).expand_as(context_attention),
                                    (context_question_switches * (1 - vocab_pointer_switches)).expand_as(
                                        context_attention) * context_attention)

        # p_question_ptr
        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1,
                                    question_indices.unsqueeze(1).expand_as(question_attention),
                                    ((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as(
                                        question_attention) * question_attention)

        return scaled_p_vocab

    def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab,
               rnn_state=None):
        B, TC, C = context.size()
        T = self.args.max_output_length
        outs = context.new_full((B, T), self.pad_idx, dtype=torch.long)
        hiddens = [self_attended_context[0].new_zeros((B, T, C))
                   for l in range(len(self.self_attentive_decoder.layers) + 1)]
        hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
        eos_yet = context.new_zeros((B,)).byte()

        pretrained_lm_hidden = None
        if self.args.pretrained_decoder_lm:
            pretrained_lm_hidden = self.pretrained_decoder_embeddings.init_hidden(B)
        rnn_output, context_alignment, question_alignment = None, None, None
        for t in range(T):
            if t == 0:
                if self.args.pretrained_decoder_lm:
                    init_token = self_attended_context[-1].new_full((1, B),
                                                                    self.pretrained_decoder_vocab_stoi[self.numericalizer.init_token],
                                                                    dtype=torch.long)

                    # note that pretrained_decoder_embeddings is time first
                    embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(init_token,
                                                                                                pretrained_lm_hidden)
                    embedding.transpose_(0, 1)

                    if self.pretrained_decoder_embedding_projection is not None:
                        embedding = self.pretrained_decoder_embedding_projection(embedding)
                else:
                    init_token = self_attended_context[-1].new_full((B, 1), self.init_idx,
                                                                    dtype=torch.long)
                    embedding = self.decoder_embeddings(init_token, [1] * B)
            else:
                if self.args.pretrained_decoder_lm:
                    current_token = [self.numericalizer.decode([x])[0] for x in outs[:, t - 1]]
                    current_token_id = torch.tensor([[self.pretrained_decoder_vocab_stoi[x] for x in current_token]],
                                                    dtype=torch.long, requires_grad=False)
                    embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(current_token_id,
                                                                                                pretrained_lm_hidden)

                    # note that pretrained_decoder_embeddings is time first
                    embedding.transpose_(0, 1)

                    if self.pretrained_decoder_embedding_projection is not None:
                        embedding = self.pretrained_decoder_embedding_projection(embedding)
                else:
                    current_token_id = outs[:, t - 1].unsqueeze(1)
                    embedding = self.decoder_embeddings(current_token_id, [1] * B)

            hiddens[0][:, t] = hiddens[0][:, t] + \
                (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1)
            for l in range(len(self.self_attentive_decoder.layers)):
                hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward(
                    self.self_attentive_decoder.layers[l].attention(
                        self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1],
                                                                       hiddens[l][:, :t + 1]),
                        self_attended_context[l], self_attended_context[l]))
            decoder_outputs = self.dual_ptr_rnn_decoder(hiddens[-1][:, t].unsqueeze(1),
                                                        context, question,
                                                        context_alignment=context_alignment,
                                                        question_alignment=question_alignment,
                                                        hidden=rnn_state, output=rnn_output)
            rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
            probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
                               context_attention, question_attention,
                               context_indices, question_indices,
                               decoder_vocab)
            pred_probs, preds = probs.max(-1)
            preds = preds.squeeze(1)
            eos_yet = eos_yet | (preds == self.numericalizer.eos_id).byte()
            outs[:, t] = preds.cpu().apply_(self.map_to_full)
            if eos_yet.all():
                break
        return outs


class MultitaskQuestionAnsweringNetwork(nn.Module):

    def __init__(self, numericalizer, args, devices):
        super().__init__()
        self.args = args

        self.encoder = MQANEncoder(numericalizer, args)
        self.decoder = MQANDecoder(numericalizer, args, devices)

    def set_embeddings(self, embeddings):
        self.encoder.set_embeddings(embeddings)
        self.decoder.set_embeddings(embeddings)

    def forward(self, batch, iteration):
        self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state = self.encoder(batch)

        loss, predictions = self.decoder(batch, self_attended_context, final_context, context_rnn_state,
                                         final_question, question_rnn_state)

        return loss, predictions


class DualPtrRNNDecoder(nn.Module):

    def __init__(self, d_in, d_hid, dropout=0.0, num_layers=1):
        super().__init__()
        self.d_hid = d_hid
        self.d_in = d_in
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)

        self.input_feed = True
        if self.input_feed:
            d_in += 1 * d_hid

        self.rnn = LSTMDecoder(self.num_layers, d_in, d_hid, dropout)
        self.context_attn = LSTMDecoderAttention(d_hid, dot=True)
        self.question_attn = LSTMDecoderAttention(d_hid, dot=True)

        self.vocab_pointer_switch = nn.Sequential(Feedforward(2 * self.d_hid + d_in, 1), nn.Sigmoid())
        self.context_question_switch = nn.Sequential(Feedforward(2 * self.d_hid + d_in, 1), nn.Sigmoid())

    def forward(self, input, context, question, output=None, hidden=None, context_alignment=None, question_alignment=None):
        context_output = output.squeeze(1) if output is not None else self.make_init_output(context)
        context_alignment = context_alignment if context_alignment is not None else self.make_init_output(context)
        question_alignment = question_alignment if question_alignment is not None else self.make_init_output(question)

        context_outputs, vocab_pointer_switches, context_question_switches, context_attentions, \
            question_attentions, context_alignments, question_alignments = [], [], [], [], [], [], []
        for emb_t in input.split(1, dim=1):
            emb_t = emb_t.squeeze(1)
            context_output = self.dropout(context_output)
            if self.input_feed:
                emb_t = torch.cat([emb_t, context_output], 1)
            dec_state, hidden = self.rnn(emb_t, hidden)
            context_output, context_attention, context_alignment = self.context_attn(dec_state, context)
            question_output, question_attention, question_alignment = self.question_attn(dec_state, question)
            vocab_pointer_switch = self.vocab_pointer_switch(torch.cat([dec_state, context_output, emb_t], -1))
            context_question_switch = self.context_question_switch(torch.cat([dec_state, question_output, emb_t], -1))
            context_output = self.dropout(context_output)
            context_outputs.append(context_output)
            vocab_pointer_switches.append(vocab_pointer_switch)
            context_question_switches.append(context_question_switch)
            context_attentions.append(context_attention)
            context_alignments.append(context_alignment)
            question_attentions.append(question_attention)
            question_alignments.append(question_alignment)

        context_outputs, vocab_pointer_switches, context_question_switches, context_attention, question_attention = \
            [self.package_outputs(x) for x in (context_outputs, vocab_pointer_switches, context_question_switches,
                                               context_attentions, question_attentions)]
        return context_outputs, context_attention, question_attention, context_alignment, question_alignment, \
            vocab_pointer_switches, context_question_switches, hidden

    def applyMasks(self, context_mask, question_mask):
        self.context_attn.applyMasks(context_mask)
        self.question_attn.applyMasks(question_mask)

    def make_init_output(self, context):
        batch_size = context.size(0)
        h_size = (batch_size, self.d_hid)
        return context.new_zeros(h_size)

    def package_outputs(self, outputs):
        outputs = torch.stack(outputs, dim=1)
        return outputs
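One detail of the removed pretrained-LM path worth noting: defaultdict(lambda: 0, ...) silently maps any out-of-vocabulary token to index 0, the LM's unknown-word slot. A standalone illustration (toy vocabulary, assumed for the sketch):

from collections import defaultdict

itos = ['<unk>', 'the', 'cat']
stoi = defaultdict(lambda: 0, {w: i for i, w in enumerate(itos)})
print(stoi['cat'])  # 2
print(stoi['dog'])  # 0 -- unseen words fall back to index 0 (<unk>)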
@@ -36,10 +36,9 @@ import sys
import logging
from pprint import pformat

from .util import set_seed, preprocess_examples, load_config_json, make_data_loader, log_model_size, init_devices, \
    make_numericalizer
from .util import set_seed, preprocess_examples, load_config_json, make_data_loader, log_model_size, init_devices
from .metrics import compute_metrics
from .utils.embeddings import load_embeddings
from .data.embeddings import load_embeddings
from .tasks.registry import get_tasks
from . import models

@@ -67,18 +66,16 @@ def get_all_splits(args):
    return splits


def prepare_data(args, numericalizer):
def prepare_data(args, numericalizer, embeddings):
    splits = get_all_splits(args)
    vectors = load_embeddings(args)
    logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens from training')
    new_vectors = []
    new_words = []
    for split in splits:
        new_vectors += numericalizer.grow_vocab(split, vectors)
        logger.info(f'Vocabulary has expanded to {numericalizer.num_tokens} tokens')
    if new_vectors:
        # concat the old embedding matrix and all the new vectors along the first dimension
        new_embedding_matrix = torch.cat([numericalizer.vocab.vectors.cpu()] + new_vectors, dim=0)
        numericalizer.vocab.vectors = new_embedding_matrix
        new_words += numericalizer.grow_vocab(split)
        logger.info(f'Vocabulary has expanded to {numericalizer.num_tokens} tokens')

    for emb in embeddings:
        emb.grow_for_vocab(numericalizer.vocab, new_words)

    return splits

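The removed branch above grew the vocabulary by concatenating rows for new words onto the old embedding matrix; the new code delegates that to each embedding's grow_for_vocab. A standalone sketch of the old mechanism (toy sizes assumed):

import torch

old = torch.randn(10, 4)             # 10 known words, embedding dim 4
new_vectors = [torch.zeros(2, 4)]    # 2 new words, zero-initialized rows
grown = torch.cat([old] + new_vectors, dim=0)
print(grown.shape)  # torch.Size([12, 4])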
@@ -186,17 +183,20 @@ def main(argv=sys.argv):
    devices = init_devices(args)
    save_dict = torch.load(args.best_checkpoint, map_location=devices[0])

    numericalizer = make_numericalizer(args)
    numericalizer, encoder_embeddings, decoder_embeddings = load_embeddings(args.embeddings, args.encoder_embeddings,
                                                                            args.decoder_embeddings,
                                                                            args.max_generative_vocab,
                                                                            logger)
    numericalizer.load(args.path)
    for emb in set(encoder_embeddings + decoder_embeddings):
        emb.init_for_vocab(numericalizer.vocab)

    logger.info(f'Initializing Model')
    Model = getattr(models, args.model)
    model = Model(numericalizer, args, devices)
    model = Model(numericalizer, args, encoder_embeddings, decoder_embeddings)
    model_dict = save_dict['model_state_dict']
    model.load_state_dict(model_dict)
    splits = prepare_data(args, numericalizer)
    if args.model != 'MultiLingualTranslationModel':
        model.set_embeddings(numericalizer.vocab.vectors)
    splits = prepare_data(args, numericalizer, set(encoder_embeddings + decoder_embeddings))

    run(args, numericalizer, splits, model, devices[0])

@@ -32,17 +32,15 @@
from argparse import ArgumentParser
import ujson as json
import torch
import numpy as np
import random
import asyncio
import logging
import sys
from pprint import pformat

from .data.example import Batch
from .util import set_seed, init_devices, load_config_json, log_model_size, make_numericalizer
from .util import set_seed, init_devices, load_config_json, log_model_size
from . import models
from .utils.embeddings import load_embeddings
from .data.embeddings import load_embeddings
from .tasks.registry import get_tasks
from .tasks.generic_dataset import Example

@@ -52,27 +50,24 @@ class ProcessedExample():
    pass

class Server():
    def __init__(self, args, numericalizer, model, device):
    def __init__(self, args, numericalizer, embeddings, model, device):
        self.args = args
        self.device = device
        self.numericalizer = numericalizer
        self.model = model

        logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens from training')
        self._vector_collections = load_embeddings(args)
        self._embeddings = embeddings

        self._cached_tasks = dict()

    def numericalize_example(self, ex):
        new_vectors = self.numericalizer.grow_vocab([ex], self._vector_collections)
        if new_vectors:
            # concat the old embedding matrix and all the new vectors along the first dimension
            new_embedding_matrix = torch.cat([self.numericalizer.vocab.vectors.cpu()] + new_vectors, dim=0)
            self.numericalizer.vocab.vectors = new_embedding_matrix
            self.model.set_embeddings(new_embedding_matrix)
        new_words = self.numericalizer.grow_vocab([ex])
        for emb in self._embeddings:
            emb.grow_for_vocab(self.numericalizer.vocab, new_words)

        # batch of size 1
        return Batch.from_examples([ex], self.numericalizer, self.numericalizer.decoder_vocab, device=self.device)
        return Batch.from_examples([ex], self.numericalizer, device=self.device)

    def handle_request(self, line):
        request = json.loads(line)

@@ -174,15 +169,19 @@ def main(argv=sys.argv):
    devices = init_devices(args)
    save_dict = torch.load(args.best_checkpoint, map_location=devices[0])

    numericalizer = make_numericalizer(args)
    numericalizer, encoder_embeddings, decoder_embeddings = load_embeddings(args.embeddings, args.encoder_embeddings,
                                                                            args.decoder_embeddings,
                                                                            args.max_generative_vocab)
    numericalizer.load(args.path)
    for emb in set(encoder_embeddings + decoder_embeddings):
        emb.init_for_vocab(numericalizer.vocab)

    logger.info(f'Initializing Model')
    Model = getattr(models, args.model)
    model = Model(numericalizer, args, devices)
    model = Model(numericalizer, args, encoder_embeddings, decoder_embeddings)
    model_dict = save_dict['model_state_dict']
    model.load_state_dict(model_dict)

    server = Server(args, numericalizer, model, devices[0])
    server = Server(args, numericalizer, encoder_embeddings + decoder_embeddings, model, devices[0])

    server.run()

@@ -47,9 +47,9 @@ from . import arguments
from . import models
from .validate import validate
from .util import elapsed_time, set_seed, preprocess_examples, get_trainable_params, make_data_loader, log_model_size, \
    init_devices, make_numericalizer
    init_devices
from .utils.saver import Saver
from .utils.embeddings import load_embeddings
from .data.embeddings import load_embeddings


def initialize_logger(args):

@@ -111,16 +111,21 @@ def prepare_data(args, logger):
        if args.vocab_tasks is not None and task.name in args.vocab_tasks:
            vocab_sets.extend(split)

    numericalizer = make_numericalizer(args)
    numericalizer, encoder_embeddings, decoder_embeddings = load_embeddings(args.embeddings, args.encoder_embeddings,
                                                                            args.decoder_embeddings,
                                                                            args.max_generative_vocab,
                                                                            logger)
    if args.load is not None:
        numericalizer.load(args.save)
    else:
        vectors = load_embeddings(args, logger)
        vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets
        logger.info(f'Building vocabulary')
        numericalizer.build_vocab(vectors, Example.vocab_fields, vocab_sets)
        numericalizer.build_vocab(Example.vocab_fields, vocab_sets)
        numericalizer.save(args.save)

    for vec in set(encoder_embeddings + decoder_embeddings):
        vec.init_for_vocab(numericalizer.vocab)

    logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens')
    logger.debug(f'The first 200 tokens:')
    logger.debug(numericalizer.vocab.itos[:200])

@@ -133,7 +138,7 @@ def prepare_data(args, logger):
    logger.info('Preprocessing validation data')
    preprocess_examples(args, args.val_tasks, val_sets, logger, train=args.val_filter)

    return numericalizer, train_sets, val_sets, aux_sets
    return numericalizer, encoder_embeddings, decoder_embeddings, train_sets, val_sets, aux_sets


def get_learning_rate(i, args):

@@ -392,11 +397,11 @@ def train(args, devices, model, opt, train_sets, train_iterations, numericalizer
            break


def init_model(args, numericalizer, devices, logger):
def init_model(args, numericalizer, encoder_embeddings, decoder_embeddings, devices, logger):
    model_name = args.model
    logger.info(f'Initializing {model_name}')
    Model = getattr(models, model_name)
    model = Model(numericalizer, args, devices)
    model = Model(numericalizer, args, encoder_embeddings, decoder_embeddings)
    params = get_trainable_params(model)
    log_model_size(logger, model, model_name)

@@ -432,7 +437,7 @@ def main(argv=sys.argv):
    if args.load is not None:
        logger.info(f'Loading vocab from {os.path.join(args.save, args.load)}')
        save_dict = torch.load(os.path.join(args.save, args.load))
    numericalizer, train_sets, val_sets, aux_sets = prepare_data(args, logger)
    numericalizer, encoder_embeddings, decoder_embeddings, train_sets, val_sets, aux_sets = prepare_data(args, logger)
    if (args.use_curriculum and aux_sets is None) or (not args.use_curriculum and len(aux_sets)):
        logging.error('something unpleasant is happening with curriculum')

@@ -445,7 +450,7 @@ def main(argv=sys.argv):
    else:
        writer = None

    model = init_model(args, numericalizer, devices, logger)
    model = init_model(args, numericalizer, encoder_embeddings, decoder_embeddings, devices, logger)
    opt = init_opt(args, model)
    start_iteration = 1

@@ -187,34 +187,23 @@ def load_config_json(args):
    args.almond_type_embeddings = False
    with open(os.path.join(args.path, 'config.json')) as config_file:
        config = json.load(config_file)
        retrieve = ['model', 'transformer_layers', 'rnn_layers', 'transformer_hidden', 'dimension',
                    'load', 'max_val_context_length', 'val_batch_size', 'transformer_heads', 'max_output_length',
                    'max_effective_vocab', 'max_generative_vocab', 'lower', 'glove_and_char',
                    'small_glove', 'almond_type_embeddings', 'almond_grammar',
                    'trainable_decoder_embedding', 'glove_decoder', 'pretrained_decoder_lm',
                    'retrain_encoder_embedding', 'question', 'locale', 'use_google_translate']
        retrieve = ['model', 'seq2seq_encoder', 'seq2seq_decoder', 'transformer_layers', 'rnn_layers',
                    'transformer_hidden', 'dimension', 'load', 'max_val_context_length', 'val_batch_size',
                    'transformer_heads', 'max_output_length', 'max_generative_vocab', 'lower', 'encoder_embeddings',
                    'decoder_embeddings', 'trainable_decoder_embeddings', 'train_encoder_embeddings',
                    'question', 'locale', 'use_google_translate']

        for r in retrieve:
            if r in config:
                setattr(args, r, config[r])
            elif r == 'locale':
                setattr(args, r, 'en')
            elif r in ('small_glove', 'almond_type_embeddings'):
                setattr(args, r, False)
            elif r in ('glove_decoder', 'glove_and_char'):
                setattr(args, r, True)
            elif r == 'trainable_decoder_embedding':
                setattr(args, r, 0)
            elif r == 'retrain_encoder_embedding':
            elif r == 'train_encoder_embeddings':
                setattr(args, r, False)
            else:
                setattr(args, r, None)
        args.dropout_ratio = 0.0

    args.best_checkpoint = os.path.join(args.path, args.checkpoint_name)


def make_numericalizer(args):
    return SimpleNumericalizer(max_effective_vocab=args.max_effective_vocab,
                               max_generative_vocab=args.max_generative_vocab,
                               pad_first=False)
    args.best_checkpoint = os.path.join(args.path, args.checkpoint_name)
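The retrieve loop above backfills defaults for settings that predate a saved config.json, so old checkpoints keep loading after the option rename. A standalone sketch of the same pattern with a toy config (keys and defaults assumed for illustration):

import json

config = json.loads('{"model": "Seq2Seq"}')
args = {}
for r in ['model', 'locale', 'train_encoder_embeddings']:
    if r in config:
        args[r] = config[r]          # present in the saved config
    elif r == 'locale':
        args[r] = 'en'               # historical default
    elif r == 'train_encoder_embeddings':
        args[r] = False
    else:
        args[r] = None
print(args)  # {'model': 'Seq2Seq', 'locale': 'en', 'train_encoder_embeddings': False}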
@@ -23,10 +23,12 @@ workdir=`mktemp -d $TMPDIR/decaNLP-tests-XXXXXX`
trap on_error ERR INT TERM

i=0
for hparams in "" "--use_curriculum"; do
for hparams in "--encoder_embeddings=small_glove+char --decoder_embeddings=small_glove+char" \
               "--encoder_embeddings=bert-base-uncased --decoder_embeddings= --trainable_decoder_embeddings=50" \
               "--encoder_embeddings=bert-base-uncased --decoder_embeddings= --trainable_decoder_embeddings=50 --seq2seq_encoder=Identity --dimension=768 --rnn_layers=0" ; do

    # train
    pipenv run python3 -m decanlp train --train_tasks almond --train_iterations 4 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $SRCDIR/embeddings --small_glove --no_commit
    pipenv run python3 -m decanlp train --train_tasks almond --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $SRCDIR/embeddings --no_commit

    # greedy decode
    pipenv run python3 -m decanlp predict --tasks almond --evaluate test --path $workdir/model_$i --overwrite --eval_dir $workdir/model_$i/eval_results/ --data $SRCDIR/dataset/ --embeddings $SRCDIR/embeddings