diff --git a/decanlp/arguments.py b/decanlp/arguments.py
index 8d54f409..fa49fd29 100644
--- a/decanlp/arguments.py
+++ b/decanlp/arguments.py
@@ -82,7 +82,6 @@ def parse(argv):
     parser.add_argument('--vocab_tasks', nargs='+', type=str, help='tasks to use in the construction of the vocabulary')
     parser.add_argument('--max_output_length', default=100, type=int, help='maximum output length for generation')
-    parser.add_argument('--max_effective_vocab', default=int(1e6), type=int, help='max effective vocabulary size for pretrained embeddings')
     parser.add_argument('--max_generative_vocab', default=50000, type=int, help='max vocabulary for the generative softmax')
     parser.add_argument('--max_train_context_length', default=500, type=int, help='maximum length of the contexts during training')
     parser.add_argument('--max_val_context_length', default=500, type=int, help='maximum length of the contexts during validation')
@@ -90,19 +89,22 @@ def parse(argv):
     parser.add_argument('--subsample', default=20000000, type=int, help='subsample the datasets')
     parser.add_argument('--preserve_case', action='store_false', dest='lower', help='whether to preserve casing for all text')
-    parser.add_argument('--model', type=str, default='MultitaskQuestionAnsweringNetwork', help='which model to import')
+    parser.add_argument('--model', type=str, choices=['Seq2Seq'], default='Seq2Seq', help='which model to import')
+    parser.add_argument('--seq2seq_encoder', type=str, choices=['MQANEncoder', 'Identity'], default='MQANEncoder',
+                        help='which encoder to use for the Seq2Seq model')
+    parser.add_argument('--seq2seq_decoder', type=str, choices=['MQANDecoder'], default='MQANDecoder',
+                        help='which decoder to use for the Seq2Seq model')
     parser.add_argument('--dimension', default=200, type=int, help='output dimensions for all layers')
     parser.add_argument('--rnn_layers', default=1, type=int, help='number of layers for RNN modules')
     parser.add_argument('--transformer_layers', default=2, type=int, help='number of layers for transformer modules')
     parser.add_argument('--transformer_hidden', default=150, type=int, help='hidden size of the transformer modules')
     parser.add_argument('--transformer_heads', default=3, type=int, help='number of heads for transformer modules')
     parser.add_argument('--dropout_ratio', default=0.2, type=float, help='dropout for the model')
-    parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings')
-    parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
-    parser.add_argument('--retrain_encoder_embedding', default=False, action='store_true', help='whether to retrain encoder embeddings')
-    parser.add_argument('--trainable_decoder_embedding', default=0, type=int, help='size of trainable portion of decoder embedding (0 or omit to disable)')
-    parser.add_argument('--no_glove_decoder', action='store_false', dest='glove_decoder', help='turn off GloVe embeddings from decoder')
-    parser.add_argument('--pretrained_decoder_lm', help='pretrained language model to use as embedding layer for the decoder (omit to disable)')
+
+    parser.add_argument('--encoder_embeddings', default='glove+char', help='which word embedding to use on the encoder side; use a bert-* pretrained model for BERT; multiple embeddings can be concatenated with +')
+    parser.add_argument('--train_encoder_embeddings', action='store_true', default=False, help='back propagate into pretrained encoder embedding (recommended for BERT)')
+    parser.add_argument('--decoder_embeddings', default='glove+char', help='which pretrained word embedding to use on the decoder side')
+    parser.add_argument('--trainable_decoder_embeddings', default=0, type=int, help='size of trainable portion of decoder embedding (0 or omit to disable)')
 
     parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate')
     parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')
@@ -123,8 +125,6 @@ def parse(argv):
     parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether to use exisiting cached splits or generate new ones')
     parser.add_argument('--lr_rate', default=0.001, type=float, help='initial_learning_rate')
-    parser.add_argument('--small_glove', action='store_true', help='Use glove.6B.50d instead of glove.840B.300d')
-    parser.add_argument('--almond_type_embeddings', action='store_true', help='Add type-based word embeddings for Almond task')
     parser.add_argument('--use_curriculum', action='store_true', help='Use curriculum learning')
    parser.add_argument('--aux_dataset', default='', type=str, help='path to auxiliary dataset (ignored if curriculum is not used)')
     parser.add_argument('--curriculum_max_frac', default=1.0, type=float, help='max fraction of harder dataset to keep for curriculum')
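Reviewer note: the six removed embedding flags collapse into the two spec strings plus two knobs above. A hedged cheat-sheet of the migration (not part of the diff; spec names are the ones `_name_to_vector` in decanlp/data/embeddings.py below actually accepts, and `<file>` is a placeholder):

    # Hypothetical mapping from the deleted flags to the new embedding specs.
    OLD_TO_NEW = {
        '--no_glove_and_char': "--encoder_embeddings ''",                   # empty spec disables pretrained vectors
        '--small_glove': '--encoder_embeddings small_glove+char',
        '--almond_type_embeddings': "append '+almond_type' to the spec",
        '--retrain_encoder_embedding': '--train_encoder_embeddings',
        '--trainable_decoder_embedding': '--trainable_decoder_embeddings',
        '--pretrained_decoder_lm <file>': '--decoder_embeddings pretrained_lstm/<file>',
    }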
diff --git a/decanlp/cache_embeddings.py b/decanlp/cache_embeddings.py
index e34fe452..c1be0fa3 100644
--- a/decanlp/cache_embeddings.py
+++ b/decanlp/cache_embeddings.py
@@ -29,14 +29,12 @@
 
 from argparse import ArgumentParser
-import torch
-import numpy as np
-import random
 import logging
 import sys
 from pprint import pformat
 
-from .utils.embeddings import load_embeddings
+from .util import set_seed
+from .data.embeddings import load_embeddings
 
 logger = logging.getLogger(__name__)
@@ -44,8 +42,8 @@ logger = logging.getLogger(__name__)
 def get_args(argv):
     parser = ArgumentParser(prog=argv[0])
     parser.add_argument('--seed', default=123, type=int, help='Random seed.')
-    parser.add_argument('--embeddings', default='./decaNLP/.embeddings', type=str, help='where to save embeddings.')
-    parser.add_argument('--locale', default='en', help='locale to use for word embeddings')
+    parser.add_argument('-d', '--destdir', default='./decaNLP/.embeddings', type=str, help='where to save embeddings.')
+    parser.add_argument('--embeddings', default='glove+char', help='which embeddings to download')
     args = parser.parse_args(argv[1:])
     return args
@@ -55,9 +53,5 @@ def main(argv=sys.argv):
     args = get_args(argv)
     logger.info(f'Arguments:\n{pformat(vars(args))}')
 
-    np.random.seed(args.seed)
-    random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    torch.cuda.manual_seed(args.seed)
-
-    load_embeddings(args, load_almond_embeddings=False)
+    set_seed(args.seed)
+    load_embeddings(args.destdir, args.embeddings, '')
diff --git a/decanlp/utils/embeddings.py b/decanlp/data/almond_embeddings.py
similarity index 72%
rename from decanlp/utils/embeddings.py
rename to decanlp/data/almond_embeddings.py
index 0e1b4242..7bee3845 100644
--- a/decanlp/utils/embeddings.py
+++ b/decanlp/data/almond_embeddings.py
@@ -30,17 +30,15 @@
 
 import torch
-import logging
-
-from ..data import word_vectors
-
-_logger = logging.getLogger(__name__)
+from .word_vectors import Vectors
 
 ENTITIES = ['DATE', 'DURATION', 'EMAIL_ADDRESS', 'HASHTAG', 'LOCATION', 'NUMBER', 'PHONE_NUMBER', 'QUOTED_STRING', 'TIME', 'URL', 'USERNAME', 'PATH_NAME', 'CURRENCY']
 MAX_ARG_VALUES = 5
 
-class AlmondEmbeddings(word_vectors.Vectors):
+
+class AlmondEmbeddings(Vectors):
     def __init__(self, name=None, cache=None, **kw):
         super().__init__(name, cache, **kw)
@@ -63,27 +61,4 @@ class AlmondEmbeddings(word_vectors.Vectors):
         self.itos = itos
         self.stoi = {word: i for i, word in enumerate(itos)}
         self.vectors = torch.stack(vectors, dim=0).view(-1, dim)
-        self.dim = dim
-
-
-def load_embeddings(args, logger=_logger, load_almond_embeddings=True):
-    logger.info(f'Getting pretrained word vectors')
-
-    language = args.locale.split('-')[0]
-
-    if language == 'en':
-        char_vectors = word_vectors.CharNGram(cache=args.embeddings)
-        if args.small_glove:
-            glove_vectors = word_vectors.GloVe(cache=args.embeddings, name="6B", dim=50)
-        else:
-            glove_vectors = word_vectors.GloVe(cache=args.embeddings)
-        vectors = [char_vectors, glove_vectors]
-    # elif args.locale == 'zh':
-    #     Chinese word embeddings
-    else:
-        # default to fastText
-        vectors = [word_vectors.FastText(cache=args.embeddings, language=language)]
-
-    if load_almond_embeddings and args.almond_type_embeddings:
-        vectors.append(AlmondEmbeddings())
-    return vectors
+        self.dim = dim
\ No newline at end of file
diff --git a/decanlp/data/embeddings.py b/decanlp/data/embeddings.py
new file mode 100644
index 00000000..209cdf31
--- /dev/null
+++ b/decanlp/data/embeddings.py
@@ -0,0 +1,232 @@
+#
+# Copyright (c) 2018-2019, Salesforce, Inc.
+# The Board of Trustees of the Leland Stanford Junior University
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+import os
+from collections import defaultdict
+import logging
+from transformers import AutoTokenizer, AutoModel, BertConfig
+from typing import NamedTuple, List
+
+from .numericalizer import SimpleNumericalizer, BertNumericalizer
+from . import word_vectors
+from .almond_embeddings import AlmondEmbeddings
+from .pretrained_lstm_lm import PretrainedLSTMLM
+
+_logger = logging.getLogger(__name__)
+
+
+class EmbeddingOutput(NamedTuple):
+    all_layers: List[torch.Tensor]
+    last_layer: torch.Tensor
+
+
+class WordVectorEmbedding(torch.nn.Module):
+    def __init__(self, vec_collection):
+        super().__init__()
+        self._vec_collection = vec_collection
+        self.dim = vec_collection.dim
+        self.num_layers = 0
+        self.embedding = None
+
+    def init_for_vocab(self, vocab):
+        vectors = torch.empty(len(vocab), self.dim, device=torch.device('cpu'))
+        for ti, token in enumerate(vocab.itos):
+            vectors[ti] = self._vec_collection[token.strip()]
+
+        self.embedding = torch.nn.Embedding(len(vocab.itos), self.dim)
+        self.embedding.weight.data = vectors
+
+    def grow_for_vocab(self, vocab, new_words):
+        if not new_words:
+            return
+        new_vectors = []
+        for word in new_words:
+            new_vector = self._vec_collection[word]
+
+            # charNgram returns a [1, D] tensor, while Glove returns a [D] tensor
+            # normalize to [1, D] so we can concatenate all the new rows along the first dimension
+            new_vector = new_vector if new_vector.dim() > 1 else new_vector.unsqueeze(0)
+            new_vectors.append(new_vector)
+
+        self.embedding.weight.data = torch.cat([self.embedding.weight.data.cpu()] + new_vectors, dim=0)
+
+    def forward(self, input : torch.Tensor, padding=None):
+        last_layer = self.embedding(input.cpu()).to(input.device)
+        return EmbeddingOutput(all_layers=[last_layer], last_layer=last_layer)
+
+    def to(self, *args, **kwargs):
+        # ignore attempts to move the word embedding, which should stay on CPU
+        kwargs['device'] = torch.device('cpu')
+        return super().to(*args, **kwargs)
+
+    def cuda(self, device=None):
+        # ignore attempts to move the word embedding
+        pass
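This class keeps the embedding matrix pinned on the CPU and pairs with the reworked `grow_vocab` of the numericalizers further down, which now returns plain words instead of pre-built vectors. A minimal sketch of the intended handshake during validation (`val_examples` is illustrative; `numericalizer`, `encoder_vectors` and `decoder_vectors` come from `load_embeddings` below):

    # The numericalizer appends unseen words to vocab.itos/stoi and returns them...
    new_words = numericalizer.grow_vocab(val_examples)
    # ...and each embedding looks them up in its own vector collection, appending
    # one row per new word to its CPU-resident matrix.
    for emb in encoder_vectors + decoder_vectors:
        emb.grow_for_vocab(numericalizer.vocab, new_words)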
+
+
+class BertEmbedding(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        model.config.output_hidden_states = True
+        self.dim = model.config.hidden_size
+        self.num_layers = model.config.num_hidden_layers
+        self.model = model
+
+    def init_for_vocab(self, vocab):
+        self.model.resize_token_embeddings(len(vocab))
+
+    def grow_for_vocab(self, vocab, new_words):
+        self.model.resize_token_embeddings(len(vocab))
+
+    def forward(self, input : torch.Tensor, padding=None):
+        last_hidden_state, _pooled, hidden_states = self.model(input, attention_mask=(~padding).to(dtype=torch.float))
+
+        return EmbeddingOutput(all_layers=hidden_states, last_layer=last_hidden_state)
+
+
+class PretrainedLMEmbedding(torch.nn.Module):
+    def __init__(self, model_name, cachedir):
+        super().__init__()
+        # map to CPU first, we will be moved to the right place later
+        pretrained_save_dict = torch.load(os.path.join(cachedir, model_name), map_location=torch.device('cpu'))
+
+        self.itos = pretrained_save_dict['vocab']
+        self.stoi = defaultdict(lambda: 0, {
+            w: i for i, w in enumerate(self.itos)
+        })
+        self.dim = pretrained_save_dict['settings']['nhid']
+        self.num_layers = 1
+        self.model = PretrainedLSTMLM(rnn_type=pretrained_save_dict['settings']['rnn_type'],
+                                      ntoken=len(self.itos),
+                                      emsize=pretrained_save_dict['settings']['emsize'],
+                                      nhid=pretrained_save_dict['settings']['nhid'],
+                                      nlayers=pretrained_save_dict['settings']['nlayers'],
+                                      dropout=0.0)
+        self.model.load_state_dict(pretrained_save_dict['model'], strict=True)
+
+        self.vocab_to_pretrained = None
+
+    def init_for_vocab(self, vocab):
+        self.vocab_to_pretrained = torch.empty(len(vocab), dtype=torch.int64)
+
+        unk_id = self.stoi['<unk>']
+        for ti, token in enumerate(vocab.itos):
+            if token in self.stoi:
+                self.vocab_to_pretrained[ti] = self.stoi[token]
+            else:
+                self.vocab_to_pretrained[ti] = unk_id
+
+    def grow_for_vocab(self, vocab, new_words):
+        self.init_for_vocab(vocab)
+
+    def forward(self, input : torch.Tensor, padding=None):
+        pretrained_indices = torch.gather(self.vocab_to_pretrained, dim=0, index=input)
+        rnn_output = self.model(pretrained_indices)
+        return EmbeddingOutput(all_layers=[rnn_output], last_layer=rnn_output)
+
+
+def _is_bert(embedding_name):
+    return embedding_name.startswith('bert-')
+
+
+def _name_to_vector(emb_name, cachedir):
+    if emb_name == 'glove':
+        return WordVectorEmbedding(word_vectors.GloVe(cache=cachedir))
+    elif emb_name == 'small_glove':
+        return WordVectorEmbedding(word_vectors.GloVe(cache=cachedir, name="6B", dim=50))
+    elif emb_name == 'char':
+        return WordVectorEmbedding(word_vectors.CharNGram(cache=cachedir))
+    elif emb_name == 'almond_type':
+        return AlmondEmbeddings()
+    elif emb_name.startswith('fasttext/'):
+        # FIXME this should use the fasttext library
+        return WordVectorEmbedding(word_vectors.FastText(cache=cachedir, language=emb_name[len('fasttext/'):]))
+    elif emb_name.startswith('pretrained_lstm/'):
+        return PretrainedLMEmbedding(emb_name[len('pretrained_lstm/'):], cachedir=cachedir)
+    else:
+        raise ValueError(f'Unrecognized embedding name {emb_name}')
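For orientation, these are all the non-BERT spec names the loader understands; `bert-*` names are intercepted earlier in `load_embeddings` and never reach `_name_to_vector`. A small resolution example (cache directory illustrative):

    cachedir = './decaNLP/.embeddings'
    vectors = [_name_to_vector(name, cachedir) for name in 'glove+char'.split('+')]
    # -> [WordVectorEmbedding(GloVe), WordVectorEmbedding(CharNGram)]
    _name_to_vector('fasttext/de', cachedir)   # FastText vectors for language 'de'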
+
+
+def load_embeddings(cachedir, encoder_emb_names, decoder_emb_names, max_generative_vocab=50000, logger=_logger):
+    logger.info(f'Getting pretrained word vectors and pretrained models')
+
+    encoder_emb_names = encoder_emb_names.split('+')
+    decoder_emb_names = decoder_emb_names.split('+')
+
+    all_vectors = {}
+    encoder_vectors = []
+    decoder_vectors = []
+
+    numericalizer = None
+    for emb_name in encoder_emb_names:
+        if not emb_name:
+            continue
+        if _is_bert(emb_name):
+            if numericalizer is not None:
+                raise ValueError('Cannot specify multiple BERT embeddings')
+
+            config = BertConfig.from_pretrained(emb_name, cache_dir=cachedir)
+            config.output_hidden_states = True
+            numericalizer = BertNumericalizer(config, emb_name, max_generative_vocab=max_generative_vocab, cache=cachedir)
+
+            # load the tokenizer once to ensure all files are downloaded
+            AutoTokenizer.from_pretrained(emb_name, cache_dir=cachedir)
+
+            encoder_vectors.append(BertEmbedding(AutoModel.from_pretrained(emb_name, config=config, cache_dir=cachedir)))
+        else:
+            if numericalizer is not None:
+                logger.warning('Combining BERT embeddings with other pretrained embeddings is unlikely to work')
+
+            if emb_name in all_vectors:
+                encoder_vectors.append(all_vectors[emb_name])
+            else:
+                vec = _name_to_vector(emb_name, cachedir)
+                all_vectors[emb_name] = vec
+                encoder_vectors.append(vec)
+
+    for emb_name in decoder_emb_names:
+        if not emb_name:
+            continue
+        if _is_bert(emb_name):
+            raise ValueError('BERT embeddings cannot be specified in the decoder')
+
+        if emb_name in all_vectors:
+            decoder_vectors.append(all_vectors[emb_name])
+        else:
+            vec = _name_to_vector(emb_name, cachedir)
+            all_vectors[emb_name] = vec
+            decoder_vectors.append(vec)
+
+    if numericalizer is None:
+        numericalizer = SimpleNumericalizer(max_generative_vocab=max_generative_vocab, pad_first=False)
+
+    return numericalizer, encoder_vectors, decoder_vectors
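The loader returns the numericalizer along with both embedding lists, which is exactly what the updated `cache_embeddings.py` above calls. A hedged usage sketch (paths and values illustrative):

    from decanlp.data.embeddings import load_embeddings

    # an empty decoder spec ('') yields an empty decoder_vectors list
    numericalizer, encoder_vectors, decoder_vectors = \
        load_embeddings('./decaNLP/.embeddings', 'glove+char', '', max_generative_vocab=50000)
    # with a BERT encoder spec, e.g. load_embeddings(cachedir, 'bert-base-uncased', 'glove+char'),
    # the returned numericalizer is a BertNumericalizer instead of a SimpleNumericalizer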
diff --git a/decanlp/data/numericalizer/bert.py b/decanlp/data/numericalizer/bert.py
index aeab8e4d..acbe0472 100644
--- a/decanlp/data/numericalizer/bert.py
+++ b/decanlp/data/numericalizer/bert.py
@@ -41,7 +41,8 @@ class BertNumericalizer(object):
     Numericalizer that uses BertTokenizer from huggingface's transformers library.
     """
 
-    def __init__(self, pretrained_tokenizer, max_generative_vocab, cache=None, fix_length=None):
+    def __init__(self, config, pretrained_tokenizer, max_generative_vocab, cache=None, fix_length=None):
+        self.config = config
         self._pretrained_name = pretrained_tokenizer
         self.max_generative_vocab = max_generative_vocab
         self._cache = cache
@@ -50,13 +51,18 @@ class BertNumericalizer(object):
 
         self.fix_length = fix_length
 
+    @property
+    def vocab(self):
+        return self._tokenizer
+
     @property
     def num_tokens(self):
-        return self._tokenizer.vocab_size
+        return len(self._tokenizer)
 
     def load(self, save_dir):
-        self.config = BertConfig.from_pretrained(os.path.join(save_dir, 'bert-config.json'), cache_dir=self._cache)
         self._tokenizer = MaskedBertTokenizer.from_pretrained(save_dir, config=self.config, cache_dir=self._cache)
+        # HACK we cannot save the tokenizer without this
+        del self._tokenizer.init_kwargs['config']
 
         with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'r') as fp:
             self._decoder_words = [line.rstrip('\n') for line in fp]
@@ -64,15 +70,15 @@ class BertNumericalizer(object):
         self._init()
 
     def save(self, save_dir):
-        self.config.save_pretrained(os.path.join(save_dir, 'bert-config.json'))
-        self._tokenizer.save_pretrained(os.path.join(save_dir))
+        self._tokenizer.save_pretrained(save_dir)
         with open(os.path.join(save_dir, 'decoder-vocab.txt'), 'w') as fp:
             for word in self._decoder_words:
                 fp.write(word + '\n')
 
-    def build_vocab(self, vectors, vocab_fields, vocab_sets):
-        self.config = BertConfig.from_pretrained(self._pretrained_name, cache_dir=self._cache)
+    def build_vocab(self, vocab_fields, vocab_sets):
         self._tokenizer = MaskedBertTokenizer.from_pretrained(self._pretrained_name, config=self.config, cache_dir=self._cache)
+        # HACK we cannot save the tokenizer without this
+        del self._tokenizer.init_kwargs['config']
 
         # ensure that init, eos, unk and pad are set
         # this method has no effect if the tokens are already set according to the tokenizer class
@@ -83,19 +89,31 @@ class BertNumericalizer(object):
             'pad_token': '[PAD]'
         })
 
-        # do a pass over all the answers in the dataset, and construct a counter of wordpieces
+        # do a pass over all the data in the dataset; in this pass, we
+        # 1) tokenize everything, to ensure we account for all added tokens
+        # 2) construct a counter of wordpieces in the answers, for the decoder vocabulary
        decoder_words = collections.Counter()
         for dataset in vocab_sets:
             for example in dataset:
+                self._tokenizer.tokenize(example.context, example.context_word_mask)
+                self._tokenizer.tokenize(example.question, example.question_word_mask)
+
                 tokens = self._tokenizer.tokenize(example.answer, example.answer_word_mask)
                 decoder_words.update(tokens)
 
-        self._decoder_words = decoder_words.most_common(self.max_generative_vocab)
+        self._decoder_words = [word for word, _freq in decoder_words.most_common(self.max_generative_vocab)]
 
         self._init()
 
-    def grow_vocab(self, examples, vectors):
-        # TODO
+    def grow_vocab(self, examples):
+        # do a pass over all the data in the dataset and tokenize everything
+        # this will add any new tokens that are not to be converted into word-pieces
+        for example in examples:
+            self._tokenizer.tokenize(example.context, example.context_word_mask)
+            self._tokenizer.tokenize(example.question, example.question_word_mask)
+
+        # return no new words - BertEmbedding will resize the embedding regardless
         return []
 
     def _init(self):
@@ -128,9 +146,9 @@
             wp_tokenized.append(self._tokenizer.tokenize(tokens, mask))
 
         if self.fix_length is None:
-            max_len = max(len(x) for x in minibatch) + 2
+            max_len = max(len(x) for x in wp_tokenized)
         else:
-            max_len = self.fix_length + 2
+            max_len = self.fix_length
         padded = []
         lengths = []
         numerical = []
@@ -148,7 +166,7 @@
                 [self.pad_token] * max(0, max_len - len(wp_tokens))
 
             padded.append(padded_example)
-            lengths.append(len(padded_example) - max(0, max_len - len(wp_tokens)))
+            lengths.append(len(wp_tokens) + 2)
             numerical.append(self._tokenizer.convert_tokens_to_ids(padded_example))
             decoder_numerical.append([decoder_vocab.encode(word) for word in padded_example])
 
@@ -157,7 +175,7 @@
         numerical = torch.tensor(numerical, dtype=torch.int64, device=device)
         decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device)
 
-        return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical)
+        return SequentialField(length=length, value=numerical, limited=decoder_numerical)
 
     def decode(self, tensor):
         return self._tokenizer.convert_ids_to_tokens(tensor)
@@ -176,7 +194,10 @@
             if token in (self.init_token, self.pad_token):
                 continue
             if token.startswith('##'):
-                tokens[-1] += token[2:]
+                if len(tokens) == 0:
+                    tokens.append(token[2:])
+                else:
+                    tokens[-1] += token[2:]
             else:
                 tokens.append(token)
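The `most_common` change above also fixes a type bug: `Counter.most_common` yields `(word, count)` pairs, while `decoder-vocab.txt` and the decoder vocabulary expect bare words. A quick illustration:

    import collections

    decoder_words = collections.Counter({'the': 12, 'rain': 3})
    decoder_words.most_common(2)                            # [('the', 12), ('rain', 3)]
    [word for word, _freq in decoder_words.most_common(2)]  # ['the', 'rain']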
diff --git a/decanlp/data/numericalizer/masked_bert_tokenizer.py b/decanlp/data/numericalizer/masked_bert_tokenizer.py
index c85559c1..e14f3117 100644
--- a/decanlp/data/numericalizer/masked_bert_tokenizer.py
+++ b/decanlp/data/numericalizer/masked_bert_tokenizer.py
@@ -54,10 +54,10 @@ class MaskedWordPieceTokenizer:
     def tokenize(self, tokens, mask):
         output_tokens = []
-        for token, should_word_split in tokens, mask:
+        for token, should_word_split in zip(tokens, mask):
             if not should_word_split:
                 if token not in self.vocab and token not in self.added_tokens_encoder:
-                    token_id = len(self.added_tokens_encoder)
+                    token_id = len(self.vocab) + len(self.added_tokens_encoder)
                     self.added_tokens_encoder[token] = token_id
                     self.added_tokens_decoder[token_id] = token
                 output_tokens.append(token)
@@ -95,11 +95,60 @@ class MaskedWordPieceTokenizer:
         return output_tokens
 
 
+class IToSWrapper:
+    """Wrap the ordered dict vocabs to look like a list int -> str"""
+
+    def __init__(self, base_vocab, added_tokens):
+        self.base_vocab = base_vocab
+        self.added_tokens = added_tokens
+
+    def __getitem__(self, key):
+        if isinstance(key, slice):
+            return [self[key] for key in range(key.start or 0, key.stop or len(self), key.step or 1)]
+
+        if key < len(self.base_vocab):
+            return self.base_vocab[key]
+        else:
+            return self.added_tokens[key]
+
+    def __len__(self):
+        return len(self.base_vocab) + len(self.added_tokens)
+
+    def __iter__(self):
+        for key in range(len(self.base_vocab)):
+            yield self.base_vocab[key]
+        for key in range(len(self.base_vocab), len(self.base_vocab) + len(self.added_tokens)):
+            yield self.added_tokens[key]
+
+
+class SToIWrapper:
+    """Wrap the ordered dict vocabs to look like a single dictionary"""
+
+    def __init__(self, base_vocab, added_tokens):
+        self.base_vocab = base_vocab
+        self.added_tokens = added_tokens
+
+    def __getitem__(self, key):
+        if key in self.base_vocab:
+            return self.base_vocab[key]
+        else:
+            return self.added_tokens[key]
+
+    def __len__(self):
+        return len(self.base_vocab) + len(self.added_tokens)
+
+    def __iter__(self):
+        for key in self.base_vocab:
+            yield key
+        for key in self.added_tokens:
+            yield key
+
+
 class MaskedBertTokenizer(BertTokenizer):
     """
     A modified BertTokenizer that respects a mask deciding whether a token
     should be split or not.
     """
-    def __init__(self, *args, do_lower_case, do_basic_tokenize, **kwargs):
+    def __init__(self, *args, do_lower_case=False, do_basic_tokenize=False, **kwargs):
         # override do_lower_case and do_basic_tokenize unconditionally
         super().__init__(*args, do_lower_case=False, do_basic_tokenize=False, **kwargs)
@@ -109,14 +158,21 @@ class MaskedBertTokenizer(BertTokenizer):
                                                             added_tokens_decoder=self.added_tokens_decoder,
                                                             unk_token=self.unk_token)
 
+        self._itos = IToSWrapper(self.ids_to_tokens, self.added_tokens_decoder)
+        self._stoi = SToIWrapper(self.vocab, self.added_tokens_encoder)
+
     def tokenize(self, tokens, mask=None):
         return self.wordpiece_tokenizer.tokenize(tokens, mask)
 
-    # provide an interface that DecoderVocabulary can like
+    # provide an interface similar to Vocab
+
+    def __len__(self):
+        return len(self.vocab) + len(self.added_tokens_encoder)
+
     @property
     def stoi(self):
-        return self.vocab
+        return self._stoi
 
     @property
     def itos(self):
-        return self.ids_to_tokens
\ No newline at end of file
+        return self._itos
\ No newline at end of file
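The two wrappers are needed because `BertTokenizer` keeps added tokens in a side table whose ids start at `len(vocab)`; together with the `token_id` fix in the first hunk, lookups stay consistent across both tables. A toy illustration (vocabularies invented for the example):

    from collections import OrderedDict

    base_itos = OrderedDict([(0, '[PAD]'), (1, 'hello')])
    base_stoi = OrderedDict([('[PAD]', 0), ('hello', 1)])
    itos = IToSWrapper(base_itos, {2: '@WEATHER'})   # added token ids continue after the base vocab
    stoi = SToIWrapper(base_stoi, {'@WEATHER': 2})
    assert itos[2] == '@WEATHER' and stoi['hello'] == 1 and len(itos) == 3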
diff --git a/decanlp/data/numericalizer/sequential_field.py b/decanlp/data/numericalizer/sequential_field.py
index bf81219b..7083f6f5 100644
--- a/decanlp/data/numericalizer/sequential_field.py
+++ b/decanlp/data/numericalizer/sequential_field.py
@@ -34,5 +34,4 @@ from typing import NamedTuple, List
 class SequentialField(NamedTuple):
     value : torch.tensor
     length : torch.tensor
-    limited : torch.tensor
-    tokens : List[List[str]]
\ No newline at end of file
+    limited : torch.tensor
\ No newline at end of file
diff --git a/decanlp/data/numericalizer/simple.py b/decanlp/data/numericalizer/simple.py
index 44b41d0a..032daf22 100644
--- a/decanlp/data/numericalizer/simple.py
+++ b/decanlp/data/numericalizer/simple.py
@@ -35,8 +35,7 @@ from .sequential_field import SequentialField
 from .decoder_vocab import DecoderVocabulary
 
 class SimpleNumericalizer(object):
-    def __init__(self, max_effective_vocab, max_generative_vocab, fix_length=None, pad_first=False):
-        self.max_effective_vocab = max_effective_vocab
+    def __init__(self, max_generative_vocab, fix_length=None, pad_first=False):
         self.max_generative_vocab = max_generative_vocab
 
         self.init_token = '<init>'
@@ -58,17 +57,15 @@ class SimpleNumericalizer(object):
     def save(self, save_dir):
         torch.save(self.vocab, os.path.join(save_dir, 'vocab.pth'))
 
-    def build_vocab(self, vectors, vocab_fields, vocab_sets):
+    def build_vocab(self, vocab_fields, vocab_sets):
         self.vocab = Vocab.build_from_data(vocab_fields, *vocab_sets,
                                            unk_token=self.unk_token,
                                            init_token=self.init_token,
                                            eos_token=self.eos_token,
-                                           pad_token=self.pad_token,
-                                           max_size=self.max_effective_vocab,
-                                           vectors=vectors)
+                                           pad_token=self.pad_token)
         self._init_vocab()
 
-    def _grow_vocab_one(self, sentence, vectors, new_vectors):
+    def _grow_vocab_one(self, sentence, new_words):
         assert isinstance(sentence, list)
 
         # check if all the words are in the vocabulary, and if not
@@ -77,22 +74,15 @@ class SimpleNumericalizer(object):
             if word not in self.vocab.stoi:
                 self.vocab.stoi[word] = len(self.vocab.itos)
                 self.vocab.itos.append(word)
+                new_words.append(word)
 
-                new_vector = [vec[word] for vec in vectors]
-
-                # charNgram returns a [1, D] tensor, while Glove returns a [D] tensor
-                # normalize to [1, D] so we can concat along the second dimension
-                # and later concat all vectors along the first
-                new_vector = [vec if vec.dim() > 1 else vec.unsqueeze(0) for vec in new_vector]
-                new_vectors.append(torch.cat(new_vector, dim=1))
-
-    def grow_vocab(self, examples, vectors):
-        new_vectors = []
+    def grow_vocab(self, examples):
+        new_words = []
         for ex in examples:
-            self._grow_vocab_one(ex.context, vectors, new_vectors)
-            self._grow_vocab_one(ex.question, vectors, new_vectors)
-            self._grow_vocab_one(ex.answer, vectors, new_vectors)
-        return new_vectors
+            self._grow_vocab_one(ex.context, new_words)
+            self._grow_vocab_one(ex.question, new_words)
+            self._grow_vocab_one(ex.answer, new_words)
+        return new_words
 
     def _init_vocab(self):
         self.init_id = self.vocab.stoi[self.init_token]
@@ -113,7 +103,7 @@ class SimpleNumericalizer(object):
         if self.fix_length is None:
             max_len = max(len(x[0]) for x in minibatch)
         else:
-            max_len = self.fix_length + 2
+            max_len = self.fix_length
         padded = []
         lengths = []
         numerical = []
@@ -140,7 +130,7 @@ class SimpleNumericalizer(object):
         numerical = torch.tensor(numerical, dtype=torch.int64, device=device)
         decoder_numerical = torch.tensor(decoder_numerical, dtype=torch.int64, device=device)
 
-        return SequentialField(tokens=padded, length=length, value=numerical, limited=decoder_numerical)
+        return SequentialField(length=length, value=numerical, limited=decoder_numerical)
 
     def decode(self, tensor):
         return [self.vocab.itos[idx] for idx in tensor]
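`grow_vocab` now has an embedding-agnostic contract: mutate the vocabulary in place and report only the new surface forms. A toy sketch of `_grow_vocab_one`'s effect (plain lists standing in for the Vocab object):

    vocab_itos = ['<init>', '<eos>', 'turn']
    vocab_stoi = {w: i for i, w in enumerate(vocab_itos)}

    new_words = []
    for word in ['turn', 'on', 'lights']:
        if word not in vocab_stoi:
            vocab_stoi[word] = len(vocab_itos)
            vocab_itos.append(word)
            new_words.append(word)
    assert new_words == ['on', 'lights']   # consumed by grow_for_vocab on the embedding side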
diff --git a/decanlp/data/numericalizer/vocab.py b/decanlp/data/numericalizer/vocab.py
index 273a7ef4..6e1616fb 100644
--- a/decanlp/data/numericalizer/vocab.py
+++ b/decanlp/data/numericalizer/vocab.py
@@ -20,8 +20,7 @@ class Vocab(object):
             numerical identifiers.
         itos: A list of token strings indexed by their numerical identifiers.
     """
-    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>',),
-                 vectors=None, cat_vectors=True):
+    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>',)):
         """Create a Vocab object from a collections.Counter.
 
         Arguments:
@@ -60,10 +59,6 @@ class Vocab(object):
                 self.itos.append(word)
                 self.stoi[word] = len(self.itos) - 1
 
-        self.vectors = None
-        if vectors is not None:
-            self.load_vectors(vectors, cat=cat_vectors)
-
     def __eq__(self, other):
         if self.freqs != other.freqs:
             return False
@@ -71,8 +66,6 @@ class Vocab(object):
             return False
         if self.itos != other.itos:
             return False
-        if self.vectors != other.vectors:
-            return False
         return True
 
     def __len__(self):
@@ -85,54 +78,6 @@ class Vocab(object):
             self.itos.append(w)
             self.stoi[w] = len(self.itos) - 1
 
-    def load_vectors(self, vectors, cat=True):
-        """
-        Arguments:
-            vectors: one of or a list containing instantiations of the
-                GloVe, CharNGram, or Vectors classes.
-        """
-        if not isinstance(vectors, list):
-            vectors = [vectors]
-        if cat:
-            tot_dim = sum(v.dim for v in vectors)
-            self.vectors = torch.Tensor(len(self), tot_dim)
-            for ti, token in enumerate(self.itos):
-                start_dim = 0
-                for v in vectors:
-                    end_dim = start_dim + v.dim
-                    self.vectors[ti][start_dim:end_dim] = v[token.strip()]
-                    start_dim = end_dim
-                assert(start_dim == tot_dim)
-        else:
-            self.vectors = [torch.Tensor(len(self), v.dim) for v in vectors]
-            for ti, t in enumerate(self.itos):
-                for vi, v in enumerate(vectors):
-                    self.vectors[vi][ti] = v[t.strip()]
-
-    def set_vectors(self, stoi, vectors, dim, unk_init=torch.Tensor.zero_):
-        """
-        Set the vectors for the Vocab instance from a collection of Tensors.
-
-        Arguments:
-            stoi: A dictionary of string to the index of the associated vector
-                in the `vectors` input argument.
-            vectors: An indexed iterable (or other structure supporting __getitem__) that
-                given an input index, returns a FloatTensor representing the vector
-                for the token associated with the index. For example,
-                vector[stoi["string"]] should return the vector for "string".
-            dim: The dimensionality of the vectors.
-            unk_init (callback): by default, initialize out-of-vocabulary word vectors
-                to zero vectors; can be any function that takes in a Tensor and
-                returns a Tensor of the same size. Default: torch.Tensor.zero_
-        """
-        self.vectors = torch.Tensor(len(self), dim)
-        for i, token in enumerate(self.itos):
-            wv_index = stoi.get(token, None)
-            if wv_index is not None:
-                self.vectors[i] = vectors[wv_index]
-            else:
-                self.vectors[i] = unk_init(self.vectors[i])
-
     @staticmethod
     def build_from_data(field_names, *args, unk_token=None, pad_token=None, init_token=None, eos_token=None, **kwargs):
         """Construct the Vocab object for this field from one or more datasets.
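After this change `Vocab` is purely a string/index bimap; vectors live in the embedding modules of decanlp/data/embeddings.py. A hedged sketch of the call `SimpleNumericalizer.build_vocab` now makes (field names and datasets assumed to exist):

    vocab = Vocab.build_from_data(vocab_fields, *vocab_sets,
                                  unk_token='<unk>', init_token='<init>',
                                  eos_token='<eos>', pad_token='<pad>')
    # no max_size= or vectors= anymore; embedding rows are materialized separately
    # via WordVectorEmbedding.init_for_vocab(vocab)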
diff --git a/decanlp/data/pretrained_lstm_lm.py b/decanlp/data/pretrained_lstm_lm.py
new file mode 100644
index 00000000..b21e6da9
--- /dev/null
+++ b/decanlp/data/pretrained_lstm_lm.py
@@ -0,0 +1,93 @@
+# The following code was copied and adapted from github.com/floyhub/world-language-model
+#
+# BSD 3-Clause License
+#
+# Copyright (c) 2017,
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from torch import nn
+
+class PretrainedLSTMLM(nn.Module):
+    """Container module with an encoder, a recurrent module, and a decoder."""
+
+    def __init__(self, rnn_type, ntoken, emsize, nhid, nlayers, dropout=0.5, tie_weights=False):
+        super(PretrainedLSTMLM, self).__init__()
+        self.drop = nn.Dropout(dropout)
+        self.encoder = nn.Embedding(ntoken, emsize)  # Token2Embeddings
+        if rnn_type in ['LSTM', 'GRU']:
+            self.rnn = getattr(nn, rnn_type)(emsize, nhid, nlayers, dropout=dropout)
+        else:
+            try:
+                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
+            except KeyError:
+                raise ValueError( """An invalid option for `--model` was supplied,
+                                     options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
+            self.rnn = nn.RNN(emsize, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
+        self.decoder = nn.Linear(nhid, ntoken)
+
+        # Optionally tie weights as in:
+        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
+        # https://arxiv.org/abs/1608.05859
+        # and
+        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
+        # https://arxiv.org/abs/1611.01462
+        if tie_weights:
+            if nhid != emsize:
+                raise ValueError('When using the tied flag, nhid must be equal to emsize')
+            self.decoder.weight = self.encoder.weight
+
+        self.init_weights()
+
+        self.rnn_type = rnn_type
+        self.nhid = nhid
+        self.nlayers = nlayers
+
+    def init_weights(self):
+        initrange = 0.1
+        self.encoder.weight.data.uniform_(-initrange, initrange)
+        self.decoder.bias.data.fill_(0)
+        self.decoder.weight.data.uniform_(-initrange, initrange)
+
+    def encode(self, input, hidden=None):
+        emb = self.drop(self.encoder(input))
+        output, hidden = self.rnn(emb, hidden)
+        output = self.drop(output)
+        return output, hidden
+
+    def forward(self, input, hidden=None):
+        encoded, hidden = self.encode(input, hidden)
+        decoded = self.decoder(encoded.view(encoded.size(0)*encoded.size(1), encoded.size(2)))
+        return decoded.view(encoded.size(0), encoded.size(1), decoded.size(1)), hidden
+
+    def init_hidden(self, bsz):
+        weight = next(self.parameters()).data
+        if self.rnn_type == 'LSTM':
+            return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
+                    weight.new(self.nlayers, bsz, self.nhid).zero_())
+        else:
+            return weight.new(self.nlayers, bsz, self.nhid).zero_()
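For orientation, the copied model is a standard word-level LM: time-major token ids in, per-token logits out. A hedged shape sketch (hyperparameters illustrative, matching the 'settings' dict PretrainedLMEmbedding reads):

    import torch

    lm = PretrainedLSTMLM(rnn_type='LSTM', ntoken=10000, emsize=400, nhid=400, nlayers=2)
    tokens = torch.randint(0, 10000, (35, 8))   # seq_len x batch
    hidden = lm.init_hidden(8)                  # LSTM: (h, c), each nlayers x batch x nhid
    logits, hidden = lm(tokens, hidden)         # seq_len x batch x ntoken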
diff --git a/decanlp/models/__init__.py b/decanlp/models/__init__.py
index b7bbcc43..4b5df30f 100644
--- a/decanlp/models/__init__.py
+++ b/decanlp/models/__init__.py
@@ -27,4 +27,4 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from .multitask_question_answering_network import MultitaskQuestionAnsweringNetwork
\ No newline at end of file
+from .general_seq2seq import Seq2Seq
\ No newline at end of file
diff --git a/decanlp/models/common.py b/decanlp/models/common.py
index 1da63ecc..56b9d489 100644
--- a/decanlp/models/common.py
+++ b/decanlp/models/common.py
@@ -37,18 +37,24 @@ import os
 import sys
 import numpy as np
 import torch.nn as nn
+from typing import NamedTuple, List
 from torch.nn.utils.rnn import pad_packed_sequence as unpack
 from torch.nn.utils.rnn import pack_padded_sequence as pack
 
+
+class EmbeddingOutput(NamedTuple):
+    all_layers: List[torch.Tensor]
+    last_layer: torch.Tensor
+
+
 INF = 1e10
 EPSILON = 1e-10
 
-class LSTMDecoder(nn.Module):
+class MultiLSTMCell(nn.Module):
     def __init__(self, num_layers, input_size, rnn_size, dropout):
-        super(LSTMDecoder, self).__init__()
+        super(MultiLSTMCell, self).__init__()
         self.dropout = nn.Dropout(dropout)
         self.num_layers = num_layers
         self.layers = nn.ModuleList()
@@ -304,6 +310,7 @@ class LinearFeedforward(nn.Module):
     def forward(self, x):
         return self.dropout(self.linear(self.feedforward(x)))
 
+
 class PackedLSTM(nn.Module):
 
     def __init__(self, d_in, d_out, bidirectional=False, num_layers=1,
@@ -354,32 +361,23 @@ class Feedforward(nn.Module):
         return self.activation(self.linear(self.dropout(x)))
 
 
-class Embedding(nn.Module):
+class CombinedEmbedding(nn.Module):
 
-    def __init__(self, numericalizer, output_dimension, include_pretrained=True, trained_dimension=0, dropout=0.0, project=True, requires_grad=False):
+    def __init__(self, numericalizer, pretrained_embeddings,
+                 output_dimension,
+                 finetune_pretrained=False,
+                 trained_dimension=0,
+                 project=True):
         super().__init__()
         self.project = project
-        self.requires_grad = requires_grad
+        self.finetune_pretrained = finetune_pretrained
+        self.pretrained_embeddings = tuple(pretrained_embeddings)
+
+        dimension = 0
-        pretrained_dimension = numericalizer.vocab.vectors.size(-1)
+        for idx, embedding in enumerate(self.pretrained_embeddings):
+            dimension += embedding.dim
+            self.add_module('pretrained_' + str(idx), embedding)
 
-        if include_pretrained:
-            # NOTE: this must be a list so that pytorch will not iterate into the module when
-            # traversing this module
-            # in turn, this means that moving this Embedding() to the GPU will not move the
-            # actual embedding, which will stay on CPU; this is necessary because a) we call
-            # set_embeddings() sometimes with CPU-only tensors, and b) the embedding tensor
-            # is too big for the GPU anyway
-            self.pretrained_embeddings = [nn.Embedding(numericalizer.num_tokens, pretrained_dimension)]
-            self.pretrained_embeddings[0].weight.data = numericalizer.vocab.vectors
-            self.pretrained_embeddings[0].weight.requires_grad = self.requires_grad
-            dimension += pretrained_dimension
-        else:
-            self.pretrained_embeddings = None
-
-        # OTOH, if we have a trained embedding, we move it around together with the module
-        # (ie, potentially on GPU), because the saving when applying gradient outweights
-        # the cost, and hopefully the embedding is small enough to fit in GPU memory
         if trained_dimension > 0:
             self.trained_embeddings = nn.Embedding(numericalizer.num_tokens, trained_dimension)
             dimension += trained_dimension
@@ -387,34 +385,53 @@ class CombinedEmbedding(nn.Module):
         else:
             self.trained_embeddings = None
         if self.project:
             self.projection = Feedforward(dimension, output_dimension)
-        self.dropout = nn.Dropout(dropout)
+        else:
+            assert dimension == output_dimension
         self.dimension = output_dimension
 
-    def forward(self, x, lengths=None, device=-1):
+    def _combine_embeddings(self, embeddings):
+        if len(embeddings) == 1:
+            all_layers = embeddings[0].all_layers
+            last_layer = embeddings[0].last_layer
+            if self.project:
+                last_layer = self.projection(last_layer)
+            return EmbeddingOutput(all_layers=all_layers, last_layer=last_layer)
+
+        all_layers = None
+        last_layer = []
+        for emb in embeddings:
+            if all_layers is None:
+                all_layers = [[layer] for layer in emb.all_layers]
+            elif len(all_layers) != len(emb.all_layers):
+                raise ValueError('Cannot combine embeddings that use different numbers of layers')
+            else:
+                for layer_list, layer in zip(all_layers, emb.all_layers):
+                    layer_list.append(layer)
+            last_layer.append(emb.last_layer)
+
+        all_layers = [torch.cat(layer, dim=2) for layer in all_layers]
+        last_layer = torch.cat(last_layer, dim=2)
+        if self.project:
+            last_layer = self.projection(last_layer)
+        return EmbeddingOutput(all_layers=all_layers, last_layer=last_layer)
+
+    def forward(self, x, padding=None):
+        embedded = []
         if self.pretrained_embeddings is not None:
-            pretrained_embeddings = self.pretrained_embeddings[0](x.cpu()).to(x.device).detach()
-        else:
-            pretrained_embeddings = None
+            if self.finetune_pretrained:
+                embedded += [emb(x, padding=padding) for emb in self.pretrained_embeddings]
+            else:
+                with torch.no_grad():
+                    embedded += [emb(x, padding=padding) for emb in self.pretrained_embeddings]
+
         if self.trained_embeddings is not None:
             trained_vocabulary_size = self.trained_embeddings.weight.size()[0]
             valid_x = torch.lt(x, trained_vocabulary_size)
             masked_x = torch.where(valid_x, x, torch.zeros_like(x))
-            trained_embeddings = self.trained_embeddings(masked_x)
-        else:
-            trained_embeddings = None
-        if pretrained_embeddings is not None and trained_embeddings is not None:
-            embeddings = torch.cat((pretrained_embeddings, trained_embeddings), dim=2)
-        elif pretrained_embeddings is not None:
-            embeddings = pretrained_embeddings
-        else:
-            embeddings = trained_embeddings
+            output = self.trained_embeddings(masked_x)
+            embedded.append(EmbeddingOutput(all_layers=[output], last_layer=output))
 
-        return self.projection(embeddings) if self.project else embeddings
-
-    def set_embeddings(self, w):
-        if self.pretrained_embeddings is not None:
-            self.pretrained_embeddings[0].weight.data = w
-            self.pretrained_embeddings[0].weight.requires_grad = self.requires_grad
+        return self._combine_embeddings(embedded)
 
 
 class SemanticFusionUnit(nn.Module):
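`CombinedEmbedding` now composes any mix of the embedding modules from decanlp/data/embeddings.py, concatenating along the feature dimension and optionally projecting down to `args.dimension`. A hedged construction sketch (`token_ids` is an illustrative batch of indices; GloVe is 300-d and CharNGram 100-d in the standard collections):

    numericalizer, encoder_vectors, _ = load_embeddings(cachedir, 'glove+char', '')
    embedding = CombinedEmbedding(numericalizer, encoder_vectors,
                                  output_dimension=200,       # args.dimension
                                  finetune_pretrained=False,  # args.train_encoder_embeddings
                                  trained_dimension=0,
                                  project=True)
    out = embedding(token_ids, padding=(token_ids == numericalizer.pad_id))
    final = out.last_layer   # batch x time x 200, projected from the 400-d concatenation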
@@ -443,23 +460,38 @@ class LSTMDecoderAttention(nn.Module):
         self.dot = dot
 
     def applyMasks(self, context_mask):
-        self.context_mask = context_mask
+        # context_mask is batch x encoder_time, convert it to batch x 1 x encoder_time
+        self.context_mask = context_mask.unsqueeze(1)
+
+    def forward(self, input : torch.Tensor, context : torch.Tensor):
+        # input is batch x decoder_time x dim
+        # context is batch x encoder_time x dim
+        # output will be batch x decoder_time x dim
+        # context_attention will be batch x decoder_time x encoder_time
 
-    def forward(self, input, context):
         if not self.dot:
-            targetT = self.linear_in(input).unsqueeze(2)  # batch x dim x 1
+            targetT = self.linear_in(input)  # batch x decoder_time x dim
         else:
-            targetT = input.unsqueeze(2)
+            targetT = input
 
-        context_scores = torch.bmm(context, targetT).squeeze(2)
+        transposed_context = torch.transpose(context, 2, 1)
+        context_scores = torch.matmul(targetT, transposed_context)
         context_scores.masked_fill_(self.context_mask, -float('inf'))
         context_attention = F.softmax(context_scores, dim=-1) + EPSILON
 
-        context_alignment = torch.bmm(context_attention.unsqueeze(1), context).squeeze(1)
-        combined_representation = torch.cat([input, context_alignment], 1)
+        # convert context_attention to batch x decoder_time x 1 x encoder_time
+        # convert context to batch x 1 x encoder_time x dim
+        # context_alignment will be batch x decoder_time x 1 x dim
+        context_alignment = torch.matmul(context_attention.unsqueeze(2), context.unsqueeze(1))
+        # squeeze out the extra dimension
+        context_alignment = context_alignment.squeeze(2)
+
+        combined_representation = torch.cat([input, context_alignment], 2)
         output = self.tanh(self.linear_out(combined_representation))
-        return output, context_attention, context_alignment
+        return output, context_attention
 
 
 class CoattentiveLayer(nn.Module):
@@ -471,8 +503,8 @@ class CoattentiveLayer(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, context, question, context_padding, question_padding):
-        context_padding = torch.cat([context.new_zeros((context.size(0), 1), dtype=torch.long)==1, context_padding], 1)
-        question_padding = torch.cat([question.new_zeros((question.size(0), 1), dtype=torch.long)==1, question_padding], 1)
+        context_padding = torch.cat([context.new_zeros((context.size(0), 1), dtype=torch.bool), context_padding], 1)
+        question_padding = torch.cat([question.new_zeros((question.size(0), 1), dtype=torch.bool), question_padding], 1)
 
         context_sentinel = self.embed_sentinel(context.new_zeros((context.size(0), 1), dtype=torch.long))
         context = torch.cat([context_sentinel, self.dropout(context)], 1) # batch_size x (context_length + 1) x features
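The switch from `bmm` to broadcast `matmul` lets one call attend over all decoder time steps at once, instead of one step per call. A shape walk-through with toy sizes:

    import torch

    B, T_dec, T_enc, D = 2, 5, 7, 4
    input = torch.randn(B, T_dec, D)    # decoder states
    context = torch.randn(B, T_enc, D)  # encoder states

    scores = torch.matmul(input, context.transpose(2, 1))   # B x T_dec x T_enc
    attn = torch.softmax(scores, dim=-1)
    aligned = torch.matmul(attn.unsqueeze(2), context.unsqueeze(1)).squeeze(2)  # B x T_dec x D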
@@ -503,94 +535,3 @@
         return F.softmax(raw_scores, dim=1)
 
-
-# The following code was copied and adapted from github.com/floyhub/world-language-model
-#
-# BSD 3-Clause License
-#
-# Copyright (c) 2017,
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# * Redistributions of source code must retain the above copyright notice, this
-#   list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-#   this list of conditions and the following disclaimer in the documentation
-#   and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-#   contributors may be used to endorse or promote products derived from
-#   this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-class PretrainedDecoderLM(nn.Module):
-    """Container module with an encoder, a recurrent module, and a decoder."""
-
-    def __init__(self, rnn_type, ntoken, emsize, nhid, nlayers, dropout=0.5, tie_weights=False):
-        super(PretrainedDecoderLM, self).__init__()
-        self.drop = nn.Dropout(dropout)
-        self.encoder = nn.Embedding(ntoken, emsize)  # Token2Embeddings
-        if rnn_type in ['LSTM', 'GRU']:
-            self.rnn = getattr(nn, rnn_type)(emsize, nhid, nlayers, dropout=dropout)
-        else:
-            try:
-                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
-            except KeyError:
-                raise ValueError( """An invalid option for `--model` was supplied,
-                                     options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
-            self.rnn = nn.RNN(emsize, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
-        self.decoder = nn.Linear(nhid, ntoken)
-
-        # Optionally tie weights as in:
-        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
-        # https://arxiv.org/abs/1608.05859
-        # and
-        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
-        # https://arxiv.org/abs/1611.01462
-        if tie_weights:
-            if nhid != emsize:
-                raise ValueError('When using the tied flag, nhid must be equal to emsize')
-            self.decoder.weight = self.encoder.weight
-
-        self.init_weights()
-
-        self.rnn_type = rnn_type
-        self.nhid = nhid
-        self.nlayers = nlayers
-
-    def init_weights(self):
-        initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.fill_(0)
-        self.decoder.weight.data.uniform_(-initrange, initrange)
-
-    def encode(self, input, hidden=None):
-        emb = self.drop(self.encoder(input))
-        output, hidden = self.rnn(emb, hidden)
-        output = self.drop(output)
-        return output, hidden
-
-    def forward(self, input, hidden=None):
-        encoded, hidden = self.encode(input, hidden)
-        decoded = self.decoder(encoded.view(encoded.size(0)*encoded.size(1), encoded.size(2)))
-        return decoded.view(encoded.size(0), encoded.size(1), decoded.size(1)), hidden
-
-    def init_hidden(self, bsz):
-        weight = next(self.parameters()).data
-        if self.rnn_type == 'LSTM':
-            return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
-                    weight.new(self.nlayers, bsz, self.nhid).zero_())
-        else:
-            return weight.new(self.nlayers, bsz, self.nhid).zero_()
diff --git a/decanlp/models/general_seq2seq.py b/decanlp/models/general_seq2seq.py
new file mode 100644
index 00000000..171d5217
--- /dev/null
+++ b/decanlp/models/general_seq2seq.py
@@ -0,0 +1,59 @@
+#
+# Copyright (c) 2018, Salesforce, Inc.
+# The Board of Trustees of the Leland Stanford Junior University
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from torch import nn
+
+from .mqan_encoder import MQANEncoder
+from .identity_encoder import IdentityEncoder
+from .mqan_decoder import MQANDecoder
+
+ENCODERS = {
+    'MQANEncoder': MQANEncoder,
+    'Identity': IdentityEncoder
+}
+DECODERS = {
+    'MQANDecoder': MQANDecoder
+}
+
+class Seq2Seq(nn.Module):
+    def __init__(self, numericalizer, args, encoder_embeddings, decoder_embeddings):
+        super().__init__()
+        self.args = args
+
+        self.encoder = ENCODERS[args.seq2seq_encoder](numericalizer, args, encoder_embeddings)
+        self.decoder = DECODERS[args.seq2seq_decoder](numericalizer, args, decoder_embeddings)
+
+    def forward(self, batch, iteration):
+        self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state = self.encoder(batch)
+
+        loss, predictions = self.decoder(batch, self_attended_context, final_context, context_rnn_state,
+                                         final_question, question_rnn_state)
+
+        return loss, predictions
\ No newline at end of file
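The registries make the encoder and decoder pluggable behind the new `--seq2seq_encoder`/`--seq2seq_decoder` flags. A hedged sketch of how the training code is presumably expected to wire this up (the cache-directory argument name is assumed):

    numericalizer, encoder_vectors, decoder_vectors = \
        load_embeddings(args.embeddings, args.encoder_embeddings, args.decoder_embeddings,
                        max_generative_vocab=args.max_generative_vocab)
    model = Seq2Seq(numericalizer, args, encoder_vectors, decoder_vectors)
    loss, predictions = model(batch, iteration)   # loss when training, predictions when evaluating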
diff --git a/decanlp/models/identity_encoder.py b/decanlp/models/identity_encoder.py
new file mode 100644
index 00000000..69eb2642
--- /dev/null
+++ b/decanlp/models/identity_encoder.py
@@ -0,0 +1,68 @@
+#
+# Copyright (c) 2018, The Board of Trustees of the Leland Stanford Junior University
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+from torch import nn
+
+from .common import CombinedEmbedding
+
+class IdentityEncoder(nn.Module):
+    def __init__(self, numericalizer, args, encoder_embeddings):
+        super().__init__()
+        self.args = args
+        self.pad_idx = numericalizer.pad_id
+
+        if sum(emb.dim for emb in encoder_embeddings) != args.dimension:
+            raise ValueError('Hidden dimension must be equal to the sum of the embedding sizes to use IdentityEncoder')
+        if args.rnn_layers > 0:
+            raise ValueError('Cannot have RNN layers with IdentityEncoder')
+
+        self.encoder_embeddings = CombinedEmbedding(numericalizer, encoder_embeddings, args.dimension,
+                                                    trained_dimension=0,
+                                                    project=False,
+                                                    finetune_pretrained=args.train_encoder_embeddings)
+
+    def forward(self, batch):
+        context, context_lengths = batch.context.value, batch.context.length
+        question, question_lengths = batch.question.value, batch.question.length
+
+        context_padding = context.data == self.pad_idx
+        question_padding = question.data == self.pad_idx
+
+        context_embedded = self.encoder_embeddings(context, padding=context_padding)
+        question_embedded = self.encoder_embeddings(question, padding=question_padding)
+
+        # pick the top-most N transformer layers to pass to the decoder for cross-attention
+        # (add 1 to account for the embedding layer - the decoder will drop it later)
+        self_attended_context = context_embedded.all_layers[-(self.args.transformer_layers+1):]
+        final_context = context_embedded.last_layer
+        final_question = question_embedded.last_layer
+        context_rnn_state = None
+        question_rnn_state = None
+
+        return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state
\ No newline at end of file
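Since `IdentityEncoder` does no projection (`project=False`), the flags must line up exactly with the chosen embeddings. A hedged example of a configuration that passes the checks above:

    # bert-base-* models have hidden_size 768, so (flag values illustrative):
    #   --model Seq2Seq --seq2seq_encoder Identity --dimension 768 --rnn_layers 0 \
    #   --encoder_embeddings bert-base-uncased --train_encoder_embeddings
    assert sum(emb.dim for emb in encoder_embeddings) == args.dimension   # 768 == 768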
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .common import * + + +class MQANDecoder(nn.Module): + def __init__(self, numericalizer, args, decoder_embeddings): + super().__init__() + self.numericalizer = numericalizer + self.pad_idx = numericalizer.pad_id + self.init_idx = numericalizer.init_id + self.args = args + + self.decoder_embeddings = CombinedEmbedding(numericalizer, decoder_embeddings, args.dimension, + trained_dimension=args.trainable_decoder_embeddings, + project=True, + finetune_pretrained=False) + + self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads, + args.transformer_hidden, + args.transformer_layers, + args.dropout_ratio) + + if args.rnn_layers > 0: + self.rnn_decoder = LSTMDecoder(args.dimension, args.dimension, + dropout=args.dropout_ratio, num_layers=args.rnn_layers) + switch_input_len = 3 * args.dimension + else: + self.context_attn = LSTMDecoderAttention(args.dimension, dot=True) + self.question_attn = LSTMDecoderAttention(args.dimension, dot=True) + self.dropout = nn.Dropout(args.dropout_ratio) + switch_input_len = 2 * args.dimension + self.vocab_pointer_switch = nn.Sequential(Feedforward(switch_input_len, 1), nn.Sigmoid()) + self.context_question_switch = nn.Sequential(Feedforward(switch_input_len, 1), nn.Sigmoid()) + + self.generative_vocab_size = numericalizer.generative_vocab_size + self.out = nn.Linear(args.dimension, self.generative_vocab_size) + + def set_embeddings(self, embeddings): + if self.decoder_embeddings is not None: + self.decoder_embeddings.set_embeddings(embeddings) + + def forward(self, batch, self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state): + context, context_lengths, context_limited = batch.context.value, batch.context.length, batch.context.limited + question, question_lengths, question_limited = batch.question.value, batch.question.length, batch.question.limited + answer, answer_lengths, answer_limited = batch.answer.value, batch.answer.length, batch.answer.limited + decoder_vocab = batch.decoder_vocab + + self.map_to_full = decoder_vocab.decode + + context_indices = context_limited if context_limited is not None else context + question_indices = question_limited if question_limited is not None else question + answer_indices = answer_limited if answer_limited is not None else answer + + context_padding = context_indices.data == self.pad_idx + question_padding = question_indices.data == self.pad_idx + + if self.args.rnn_layers > 0: + self.rnn_decoder.applyMasks(context_padding, question_padding) + else: + self.context_attn.applyMasks(context_padding) + self.question_attn.applyMasks(question_padding) + + if self.training: + answer_padding = (answer_indices.data == self.pad_idx)[:, :-1] + + answer_embedded = self.decoder_embeddings(answer[:, :-1], padding=answer_padding).last_layer + self_attended_decoded = self.self_attentive_decoder(answer_embedded, + self_attended_context, + context_padding=context_padding, + answer_padding=answer_padding, + positional_encodings=True) + + if 
self.args.rnn_layers > 0: + rnn_decoder_outputs = self.rnn_decoder(self_attended_decoded, final_context, final_question, + hidden=context_rnn_state) + decoder_output, vocab_pointer_switch_input, context_question_switch_input, context_attention, \ + question_attention, rnn_state = rnn_decoder_outputs + else: + context_decoder_output, context_attention = self.context_attn(self_attended_decoded, final_context) + question_decoder_output, question_attention = self.question_attn(self_attended_decoded, final_question) + + vocab_pointer_switch_input = torch.cat((context_decoder_output, self_attended_decoded), dim=-1) + context_question_switch_input = torch.cat((question_decoder_output, self_attended_decoded), dim=-1) + + decoder_output = self.dropout(context_decoder_output) + + vocab_pointer_switch = self.vocab_pointer_switch(vocab_pointer_switch_input) + context_question_switch = self.context_question_switch(context_question_switch_input) + + probs = self.probs(self.out, decoder_output, vocab_pointer_switch, context_question_switch, + context_attention, question_attention, + context_indices, question_indices, + decoder_vocab) + + probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=self.pad_idx) + loss = F.nll_loss(probs.log(), targets) + return loss, None + + else: + return None, self.greedy(self_attended_context, final_context, final_question, + context_indices, question_indices, + decoder_vocab, rnn_state=context_rnn_state).data + + def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches, + context_attention, question_attention, + context_indices, question_indices, + decoder_vocab): + + size = list(outputs.size()) + + size[-1] = self.generative_vocab_size + scores = generator(outputs.view(-1, outputs.size(-1))).view(size) + p_vocab = F.softmax(scores, dim=scores.dim() - 1) + scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab + + effective_vocab_size = len(decoder_vocab) + if self.generative_vocab_size < effective_vocab_size: + size[-1] = effective_vocab_size - self.generative_vocab_size + buff = scaled_p_vocab.new_full(size, EPSILON) + scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim() - 1) + + # p_context_ptr + scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1, context_indices.unsqueeze(1).expand_as(context_attention), + (context_question_switches * (1 - vocab_pointer_switches)).expand_as( + context_attention) * context_attention) + + # p_question_ptr + scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1, + question_indices.unsqueeze(1).expand_as(question_attention), + ((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as( + question_attention) * question_attention) + + return scaled_p_vocab + + def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab, + rnn_state=None): + B, TC, C = context.size() + T = self.args.max_output_length + outs = context.new_full((B, T), self.pad_idx, dtype=torch.long) + hiddens = [self_attended_context[0].new_zeros((B, T, C)) + for l in range(len(self.self_attentive_decoder.layers) + 1)] + hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0]) + eos_yet = context.new_zeros((B,)).byte() + + decoder_output = None + for t in range(T): + if t == 0: + init_token = self_attended_context[-1].new_full((B, 1), self.init_idx, + dtype=torch.long) + embedding = self.decoder_embeddings(init_token).last_layer + else: + current_token_id = outs[:, t - 1].unsqueeze(1) + embedding = 
self.decoder_embeddings(current_token_id).last_layer + + hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze( + 1) + for l in range(len(self.self_attentive_decoder.layers)): + hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward( + self.self_attentive_decoder.layers[l].attention( + self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1], + hiddens[l][:, :t + 1]) + , self_attended_context[l], self_attended_context[l])) + + self_attended_decoded = hiddens[-1][:, t].unsqueeze(1) + if self.args.rnn_layers > 0: + rnn_decoder_outputs = self.rnn_decoder(self_attended_decoded, context, question, + hidden=rnn_state, output=decoder_output) + decoder_output, vocab_pointer_switch_input, context_question_switch_input, context_attention, \ + question_attention, rnn_state = rnn_decoder_outputs + else: + context_decoder_output, context_attention = self.context_attn(self_attended_decoded, context) + question_decoder_output, question_attention = self.question_attn(self_attended_decoded, question) + + vocab_pointer_switch_input = torch.cat((context_decoder_output, self_attended_decoded), dim=-1) + context_question_switch_input = torch.cat((question_decoder_output, self_attended_decoded), dim=-1) + + decoder_output = self.dropout(context_decoder_output) + + vocab_pointer_switch = self.vocab_pointer_switch(vocab_pointer_switch_input) + context_question_switch = self.context_question_switch(context_question_switch_input) + + probs = self.probs(self.out, decoder_output, vocab_pointer_switch, context_question_switch, + context_attention, question_attention, + context_indices, question_indices, decoder_vocab) + pred_probs, preds = probs.max(-1) + preds = preds.squeeze(1) + eos_yet = eos_yet | (preds == self.numericalizer.eos_id).byte() + outs[:, t] = preds.cpu().apply_(self.map_to_full) + if eos_yet.all(): + break + return outs + + +class LSTMDecoder(nn.Module): + def __init__(self, d_in, d_hid, dropout=0.0, num_layers=1): + super().__init__() + self.d_hid = d_hid + self.d_in = d_in + self.num_layers = num_layers + self.dropout = nn.Dropout(dropout) + + self.input_feed = True + if self.input_feed: + d_in += 1 * d_hid + + self.rnn = MultiLSTMCell(self.num_layers, d_in, d_hid, dropout) + self.context_attn = LSTMDecoderAttention(d_hid, dot=True) + self.question_attn = LSTMDecoderAttention(d_hid, dot=True) + + def applyMasks(self, context_mask, question_mask): + self.context_attn.applyMasks(context_mask) + self.question_attn.applyMasks(question_mask) + + def forward(self, input : torch.Tensor, context, question, output=None, hidden=None): + context_output = output if output is not None else self.make_init_output(context) + + context_outputs, vocab_pointer_switch_inputs, context_question_switch_inputs, context_attentions, question_attentions = [], [], [], [], [] + for decoder_input in input.split(1, dim=1): + context_output = self.dropout(context_output) + if self.input_feed: + rnn_input = torch.cat([decoder_input, context_output], 2) + else: + rnn_input = decoder_input + + rnn_input = rnn_input.squeeze(1) + dec_state, hidden = self.rnn(rnn_input, hidden) + dec_state = dec_state.unsqueeze(1) + + context_output, context_attention = self.context_attn(dec_state, context) + question_output, question_attention = self.question_attn(dec_state, question) + vocab_pointer_switch_inputs.append(torch.cat([dec_state, context_output, decoder_input], -1)) + context_question_switch_inputs.append(torch.cat([dec_state, 
question_output, decoder_input], -1)) + + context_output = self.dropout(context_output) + context_outputs.append(context_output) + context_attentions.append(context_attention) + question_attentions.append(question_attention) + + return [torch.cat(x, dim=1) for x in (context_outputs, + vocab_pointer_switch_inputs, + context_question_switch_inputs, + context_attentions, + question_attentions)] + [hidden] + + def make_init_output(self, context): + batch_size = context.size(0) + h_size = (batch_size, 1, self.d_hid) + return context.new_zeros(h_size) diff --git a/decanlp/models/mqan_encoder.py b/decanlp/models/mqan_encoder.py new file mode 100644 index 00000000..73a8bca5 --- /dev/null +++ b/decanlp/models/mqan_encoder.py @@ -0,0 +1,108 @@ +# +# Copyright (c) 2018, Salesforce, Inc. +# The Board of Trustees of the Leland Stanford Junior University +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .common import * + + +class MQANEncoder(nn.Module): + def __init__(self, numericalizer, args, encoder_embeddings): + super().__init__() + self.args = args + self.pad_idx = numericalizer.pad_id + + self.encoder_embeddings = CombinedEmbedding(numericalizer, encoder_embeddings, args.dimension, + trained_dimension=0, + project=True, + finetune_pretrained=args.train_encoder_embeddings) + + def dp(args): + return args.dropout_ratio if args.rnn_layers > 1 else 0. 
+ + self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension, + batch_first=True, bidirectional=True, num_layers=1, dropout=0) + self.coattention = CoattentiveLayer(args.dimension, dropout=0.3) + dim = 2 * args.dimension + args.dimension + args.dimension + + self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension, + batch_first=True, dropout=dp(args), bidirectional=True, + num_layers=args.rnn_layers) + self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads, + args.transformer_hidden, args.transformer_layers, + args.dropout_ratio) + self.bilstm_context = PackedLSTM(args.dimension, args.dimension, + batch_first=True, dropout=dp(args), bidirectional=True, + num_layers=args.rnn_layers) + + self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension, + batch_first=True, dropout=dp(args), bidirectional=True, + num_layers=args.rnn_layers) + self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads, + args.transformer_hidden, args.transformer_layers, + args.dropout_ratio) + self.bilstm_question = PackedLSTM(args.dimension, args.dimension, + batch_first=True, dropout=dp(args), bidirectional=True, + num_layers=args.rnn_layers) + + def forward(self, batch): + context, context_lengths = batch.context.value, batch.context.length + question, question_lengths = batch.question.value, batch.question.length + + context_padding = context.data == self.pad_idx + question_padding = question.data == self.pad_idx + + context_embedded = self.encoder_embeddings(context, padding=context_padding).last_layer + question_embedded = self.encoder_embeddings(question, padding=question_padding).last_layer + + context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0] + question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0] + + coattended_context, coattended_question = self.coattention(context_encoded, question_encoded, + context_padding, question_padding) + + context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1) + condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths) + self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding) + final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1], + context_lengths) + context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)] + + question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1) + condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths) + self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding) + final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1], + question_lengths) + question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)] + + return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state + + def reshape_rnn_state(self, h): + return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \ + .transpose(1, 2).contiguous() \ + .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous() \ No newline at end of file diff --git a/decanlp/models/multitask_question_answering_network.py b/decanlp/models/multitask_question_answering_network.py deleted file mode 100644 index 
70549039..00000000 --- a/decanlp/models/multitask_question_answering_network.py +++ /dev/null @@ -1,420 +0,0 @@ -# -# Copyright (c) 2018, Salesforce, Inc. -# The Board of Trustees of the Leland Stanford Junior University -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from collections import defaultdict - -from ..util import get_trainable_params - -from .common import * - -class MQANEncoder(nn.Module): - def __init__(self, numericalizer, args): - super().__init__() - self.args = args - self.pad_idx = numericalizer.pad_id - - if self.args.glove_and_char: - self.encoder_embeddings = Embedding(numericalizer, args.dimension, - trained_dimension=0, - dropout=args.dropout_ratio, - project=True, - requires_grad=args.retrain_encoder_embedding) - - def dp(args): - return args.dropout_ratio if args.rnn_layers > 1 else 0. 
- - self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension, - batch_first=True, bidirectional=True, num_layers=1, dropout=0) - self.coattention = CoattentiveLayer(args.dimension, dropout=0.3) - dim = 2 * args.dimension + args.dimension + args.dimension - - self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension, - batch_first=True, dropout=dp(args), bidirectional=True, - num_layers=args.rnn_layers) - self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads, - args.transformer_hidden, args.transformer_layers, - args.dropout_ratio) - self.bilstm_context = PackedLSTM(args.dimension, args.dimension, - batch_first=True, dropout=dp(args), bidirectional=True, - num_layers=args.rnn_layers) - - self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension, - batch_first=True, dropout=dp(args), bidirectional=True, - num_layers=args.rnn_layers) - self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads, - args.transformer_hidden, args.transformer_layers, - args.dropout_ratio) - self.bilstm_question = PackedLSTM(args.dimension, args.dimension, - batch_first=True, dropout=dp(args), bidirectional=True, - num_layers=args.rnn_layers) - - def set_embeddings(self, embeddings): - self.encoder_embeddings.set_embeddings(embeddings) - - def forward(self, batch): - context, context_lengths = batch.context.value, batch.context.length - question, question_lengths = batch.question.value, batch.question.length - - context_embedded = self.encoder_embeddings(context) - question_embedded = self.encoder_embeddings(question) - - context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0] - question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0] - - context_padding = context.data == self.pad_idx - question_padding = question.data == self.pad_idx - - coattended_context, coattended_question = self.coattention(context_encoded, question_encoded, - context_padding, question_padding) - - context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1) - condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths) - self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding) - final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1], - context_lengths) - context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)] - - question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1) - condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths) - self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding) - final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1], - question_lengths) - question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)] - - return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state - - def reshape_rnn_state(self, h): - return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \ - .transpose(1, 2).contiguous() \ - .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous() - - -class MQANDecoder(nn.Module): - def __init__(self, numericalizer, args, devices): - super().__init__() - self.numericalizer = numericalizer - self.pad_idx = 
numericalizer.pad_id - self.init_idx = numericalizer.init_id - self.args = args - self.devices = devices - - if args.pretrained_decoder_lm: - pretrained_save_dict = torch.load(os.path.join(args.embeddings, args.pretrained_decoder_lm), map_location=devices[0]) - - self.pretrained_decoder_vocab_itos = pretrained_save_dict['vocab'] - self.pretrained_decoder_vocab_stoi = defaultdict(lambda: 0, { - w: i for i, w in enumerate(self.pretrained_decoder_vocab_itos) - }) - self.pretrained_decoder_embeddings = PretrainedDecoderLM(rnn_type=pretrained_save_dict['settings']['rnn_type'], - ntoken=len(self.pretrained_decoder_vocab_itos), - emsize=pretrained_save_dict['settings']['emsize'], - nhid=pretrained_save_dict['settings']['nhid'], - nlayers=pretrained_save_dict['settings']['nlayers'], - dropout=0.0) - self.pretrained_decoder_embeddings.load_state_dict(pretrained_save_dict['model'], strict=True) - pretrained_lm_params = get_trainable_params(self.pretrained_decoder_embeddings) - for p in pretrained_lm_params: - p.requires_grad = False - - if self.pretrained_decoder_embeddings.nhid != args.dimension: - self.pretrained_decoder_embedding_projection = Feedforward(self.pretrained_decoder_embeddings.nhid, - args.dimension) - else: - self.pretrained_decoder_embedding_projection = None - self.decoder_embeddings = None - else: - self.pretrained_decoder_vocab_itos = None - self.pretrained_decoder_vocab_stoi = None - self.pretrained_decoder_embeddings = None - self.decoder_embeddings = Embedding(self.numericalizer, args.dimension, - include_pretrained=args.glove_decoder, - trained_dimension=args.trainable_decoder_embedding, - dropout=args.dropout_ratio, project=True) - - self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio) - self.dual_ptr_rnn_decoder = DualPtrRNNDecoder(args.dimension, args.dimension, - dropout=args.dropout_ratio, num_layers=args.rnn_layers) - - self.generative_vocab_size = numericalizer.generative_vocab_size - self.out = nn.Linear(args.dimension, self.generative_vocab_size) - - def set_embeddings(self, embeddings): - if self.decoder_embeddings is not None: - self.decoder_embeddings.set_embeddings(embeddings) - - def forward(self, batch, self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state): - context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens - question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens - answer, answer_lengths, answer_limited, answer_tokens = batch.answer.value, batch.answer.length, batch.answer.limited, batch.answer.tokens - decoder_vocab = batch.decoder_vocab - - self.map_to_full = decoder_vocab.decode - - context_indices = context_limited if context_limited is not None else context - question_indices = question_limited if question_limited is not None else question - answer_indices = answer_limited if answer_limited is not None else answer - - context_padding = context_indices.data == self.pad_idx - question_padding = question_indices.data == self.pad_idx - - self.dual_ptr_rnn_decoder.applyMasks(context_padding, question_padding) - - if self.training: - answer_padding = (answer_indices.data == self.pad_idx)[:, :-1] - - if self.args.pretrained_decoder_lm: - # note that pretrained_decoder_embeddings is time first - answer_pretrained_numerical = [ 
- [self.pretrained_decoder_vocab_stoi[sentence[time]] for sentence in answer_tokens] for time in - range(len(answer_tokens[0])) - ] - answer_pretrained_numerical = torch.tensor(answer_pretrained_numerical, dtype=torch.long) - - with torch.no_grad(): - answer_embedded, _ = self.pretrained_decoder_embeddings.encode(answer_pretrained_numerical) - answer_embedded.transpose_(0, 1) - - if self.pretrained_decoder_embedding_projection is not None: - answer_embedded = self.pretrained_decoder_embedding_projection(answer_embedded) - else: - answer_embedded = self.decoder_embeddings(answer) - self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(), - self_attended_context, context_padding=context_padding, - answer_padding=answer_padding, - positional_encodings=True) - decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded, - final_context, final_question, hidden=context_rnn_state) - rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs - - probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch, - context_attention, question_attention, - context_indices, question_indices, - decoder_vocab) - - probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=self.pad_idx) - loss = F.nll_loss(probs.log(), targets) - return loss, None - - else: - return None, self.greedy(self_attended_context, final_context, final_question, - context_indices, question_indices, - decoder_vocab, rnn_state=context_rnn_state).data - - def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches, - context_attention, question_attention, - context_indices, question_indices, - decoder_vocab): - - size = list(outputs.size()) - - size[-1] = self.generative_vocab_size - scores = generator(outputs.view(-1, outputs.size(-1))).view(size) - p_vocab = F.softmax(scores, dim=scores.dim() - 1) - scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab - - effective_vocab_size = len(decoder_vocab) - if self.generative_vocab_size < effective_vocab_size: - size[-1] = effective_vocab_size - self.generative_vocab_size - buff = scaled_p_vocab.new_full(size, EPSILON) - scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim() - 1) - - # p_context_ptr - scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1, context_indices.unsqueeze(1).expand_as(context_attention), - (context_question_switches * (1 - vocab_pointer_switches)).expand_as( - context_attention) * context_attention) - - # p_question_ptr - scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1, - question_indices.unsqueeze(1).expand_as(question_attention), - ((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as( - question_attention) * question_attention) - - return scaled_p_vocab - - def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab, - rnn_state=None): - B, TC, C = context.size() - T = self.args.max_output_length - outs = context.new_full((B, T), self.pad_idx, dtype=torch.long) - hiddens = [self_attended_context[0].new_zeros((B, T, C)) - for l in range(len(self.self_attentive_decoder.layers) + 1)] - hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0]) - eos_yet = context.new_zeros((B,)).byte() - - pretrained_lm_hidden = None - if self.args.pretrained_decoder_lm: - pretrained_lm_hidden = self.pretrained_decoder_embeddings.init_hidden(B) - rnn_output, 
context_alignment, question_alignment = None, None, None - for t in range(T): - if t == 0: - if self.args.pretrained_decoder_lm: - init_token = self_attended_context[-1].new_full((1, B), - self.pretrained_decoder_vocab_stoi[self.numericalizer.init_token], - dtype=torch.long) - - # note that pretrained_decoder_embeddings is time first - embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(init_token, - pretrained_lm_hidden) - embedding.transpose_(0, 1) - - if self.pretrained_decoder_embedding_projection is not None: - embedding = self.pretrained_decoder_embedding_projection(embedding) - else: - init_token = self_attended_context[-1].new_full((B, 1), self.init_idx, - dtype=torch.long) - embedding = self.decoder_embeddings(init_token, [1] * B) - else: - if self.args.pretrained_decoder_lm: - current_token = [self.numericalizer.decode([x])[0] for x in outs[:, t - 1]] - current_token_id = torch.tensor([[self.pretrained_decoder_vocab_stoi[x] for x in current_token]], - dtype=torch.long, requires_grad=False) - embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(current_token_id, - pretrained_lm_hidden) - - # note that pretrained_decoder_embeddings is time first - embedding.transpose_(0, 1) - - if self.pretrained_decoder_embedding_projection is not None: - embedding = self.pretrained_decoder_embedding_projection(embedding) - else: - current_token_id = outs[:, t - 1].unsqueeze(1) - embedding = self.decoder_embeddings(current_token_id, [1] * B) - - hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze( - 1) - for l in range(len(self.self_attentive_decoder.layers)): - hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward( - self.self_attentive_decoder.layers[l].attention( - self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1], - hiddens[l][:, :t + 1]) - , self_attended_context[l], self_attended_context[l])) - decoder_outputs = self.dual_ptr_rnn_decoder(hiddens[-1][:, t].unsqueeze(1), - context, question, - context_alignment=context_alignment, - question_alignment=question_alignment, - hidden=rnn_state, output=rnn_output) - rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs - probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch, - context_attention, question_attention, - context_indices, question_indices, - decoder_vocab) - pred_probs, preds = probs.max(-1) - preds = preds.squeeze(1) - eos_yet = eos_yet | (preds == self.numericalizer.eos_id).byte() - outs[:, t] = preds.cpu().apply_(self.map_to_full) - if eos_yet.all(): - break - return outs - - -class MultitaskQuestionAnsweringNetwork(nn.Module): - - def __init__(self, numericalizer, args, devices): - super().__init__() - self.args = args - - self.encoder = MQANEncoder(numericalizer, args) - self.decoder = MQANDecoder(numericalizer, args, devices) - - - def set_embeddings(self, embeddings): - self.encoder.set_embeddings(embeddings) - self.decoder.set_embeddings(embeddings) - - def forward(self, batch, iteration): - self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state = self.encoder(batch) - - loss, predictions = self.decoder(batch, self_attended_context, final_context, context_rnn_state, - final_question, question_rnn_state) - - return loss, predictions - - -class DualPtrRNNDecoder(nn.Module): - - def __init__(self, 
d_in, d_hid, dropout=0.0, num_layers=1): - super().__init__() - self.d_hid = d_hid - self.d_in = d_in - self.num_layers = num_layers - self.dropout = nn.Dropout(dropout) - - self.input_feed = True - if self.input_feed: - d_in += 1 * d_hid - - self.rnn = LSTMDecoder(self.num_layers, d_in, d_hid, dropout) - self.context_attn = LSTMDecoderAttention(d_hid, dot=True) - self.question_attn = LSTMDecoderAttention(d_hid, dot=True) - - self.vocab_pointer_switch = nn.Sequential(Feedforward(2 * self.d_hid + d_in, 1), nn.Sigmoid()) - self.context_question_switch = nn.Sequential(Feedforward(2 * self.d_hid + d_in, 1), nn.Sigmoid()) - - def forward(self, input, context, question, output=None, hidden=None, context_alignment=None, question_alignment=None): - context_output = output.squeeze(1) if output is not None else self.make_init_output(context) - context_alignment = context_alignment if context_alignment is not None else self.make_init_output(context) - question_alignment = question_alignment if question_alignment is not None else self.make_init_output(question) - - context_outputs, vocab_pointer_switches, context_question_switches, context_attentions, question_attentions, context_alignments, question_alignments = [], [], [], [], [], [], [] - for emb_t in input.split(1, dim=1): - emb_t = emb_t.squeeze(1) - context_output = self.dropout(context_output) - if self.input_feed: - emb_t = torch.cat([emb_t, context_output], 1) - dec_state, hidden = self.rnn(emb_t, hidden) - context_output, context_attention, context_alignment = self.context_attn(dec_state, context) - question_output, question_attention, question_alignment = self.question_attn(dec_state, question) - vocab_pointer_switch = self.vocab_pointer_switch(torch.cat([dec_state, context_output, emb_t], -1)) - context_question_switch = self.context_question_switch(torch.cat([dec_state, question_output, emb_t], -1)) - context_output = self.dropout(context_output) - context_outputs.append(context_output) - vocab_pointer_switches.append(vocab_pointer_switch) - context_question_switches.append(context_question_switch) - context_attentions.append(context_attention) - context_alignments.append(context_alignment) - question_attentions.append(question_attention) - question_alignments.append(question_alignment) - - context_outputs, vocab_pointer_switches, context_question_switches, context_attention, question_attention = [self.package_outputs(x) for x in [context_outputs, vocab_pointer_switches, context_question_switches, context_attentions, question_attentions]] - return context_outputs, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switches, context_question_switches, hidden - - - def applyMasks(self, context_mask, question_mask): - self.context_attn.applyMasks(context_mask) - self.question_attn.applyMasks(question_mask) - - def make_init_output(self, context): - batch_size = context.size(0) - h_size = (batch_size, self.d_hid) - return context.new_zeros(h_size) - - def package_outputs(self, outputs): - outputs = torch.stack(outputs, dim=1) - return outputs diff --git a/decanlp/predict.py b/decanlp/predict.py index db475574..8b548299 100644 --- a/decanlp/predict.py +++ b/decanlp/predict.py @@ -36,10 +36,9 @@ import sys import logging from pprint import pformat -from .util import set_seed, preprocess_examples, load_config_json, make_data_loader, log_model_size, init_devices, \ - make_numericalizer +from .util import set_seed, preprocess_examples, load_config_json, make_data_loader, log_model_size, init_devices from 
.metrics import compute_metrics -from .utils.embeddings import load_embeddings +from .data.embeddings import load_embeddings from .tasks.registry import get_tasks from . import models @@ -67,18 +66,16 @@ def get_all_splits(args): return splits -def prepare_data(args, numericalizer): +def prepare_data(args, numericalizer, embeddings): splits = get_all_splits(args) - vectors = load_embeddings(args) logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens from training') - new_vectors = [] + new_words = [] for split in splits: - new_vectors += numericalizer.grow_vocab(split, vectors) - logger.info(f'Vocabulary has expanded to {numericalizer.num_tokens} tokens') - if new_vectors: - # concat the old embedding matrix and all the new vector along the first dimension - new_embedding_matrix = torch.cat([numericalizer.vocab.vectors.cpu()] + new_vectors, dim=0) - numericalizer.vocab.vectors = new_embedding_matrix + new_words += numericalizer.grow_vocab(split) + logger.info(f'Vocabulary has expanded to {numericalizer.num_tokens} tokens') + + for emb in embeddings: + emb.grow_for_vocab(numericalizer.vocab, new_words) return splits @@ -186,17 +183,20 @@ def main(argv=sys.argv): devices = init_devices(args) save_dict = torch.load(args.best_checkpoint, map_location=devices[0]) - numericalizer = make_numericalizer(args) + numericalizer, encoder_embeddings, decoder_embeddings = load_embeddings(args.embeddings, args.encoder_embeddings, + args.decoder_embeddings, + args.max_generative_vocab, + logger) numericalizer.load(args.path) + for emb in set(encoder_embeddings + decoder_embeddings): + emb.init_for_vocab(numericalizer.vocab) logger.info(f'Initializing Model') Model = getattr(models, args.model) - model = Model(numericalizer, args, devices) + model = Model(numericalizer, args, encoder_embeddings, decoder_embeddings) model_dict = save_dict['model_state_dict'] model.load_state_dict(model_dict) - splits = prepare_data(args, numericalizer) - if args.model != 'MultiLingualTranslationModel': - model.set_embeddings(numericalizer.vocab.vectors) + splits = prepare_data(args, numericalizer, set(encoder_embeddings + decoder_embeddings)) run(args, numericalizer, splits, model, devices[0]) diff --git a/decanlp/server.py b/decanlp/server.py index 05b98377..08f62869 100644 --- a/decanlp/server.py +++ b/decanlp/server.py @@ -32,17 +32,15 @@ from argparse import ArgumentParser import ujson as json import torch -import numpy as np -import random import asyncio import logging import sys from pprint import pformat from .data.example import Batch -from .util import set_seed, init_devices, load_config_json, log_model_size, make_numericalizer +from .util import set_seed, init_devices, load_config_json, log_model_size from . 
import models -from .utils.embeddings import load_embeddings +from .data.embeddings import load_embeddings from .tasks.registry import get_tasks from .tasks.generic_dataset import Example @@ -52,27 +50,24 @@ class ProcessedExample(): pass class Server(): - def __init__(self, args, numericalizer, model, device): + def __init__(self, args, numericalizer, embeddings, model, device): self.args = args self.device = device self.numericalizer = numericalizer self.model = model logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens from training') - self._vector_collections = load_embeddings(args) + self._embeddings = embeddings self._cached_tasks = dict() def numericalize_example(self, ex): - new_vectors = self.numericalizer.grow_vocab([ex], self._vector_collections) - if new_vectors: - # concat the old embedding matrix and all the new vector along the first dimension - new_embedding_matrix = torch.cat([self.numericalizer.vocab.vectors.cpu()] + new_vectors, dim=0) - self.numericalizer.vocab.vectors = new_embedding_matrix - self.model.set_embeddings(new_embedding_matrix) + new_words = self.numericalizer.grow_vocab([ex]) + for emb in self._embeddings: + emb.grow_for_vocab(self.numericalizer.vocab, new_words) # batch of size 1 - return Batch.from_examples([ex], self.numericalizer, self.numericalizer.decoder_vocab, device=self.device) + return Batch.from_examples([ex], self.numericalizer, device=self.device) def handle_request(self, line): request = json.loads(line) @@ -174,15 +169,19 @@ def main(argv=sys.argv): devices = init_devices(args) save_dict = torch.load(args.best_checkpoint, map_location=devices[0]) - numericalizer = make_numericalizer(args) + numericalizer, encoder_embeddings, decoder_embeddings = load_embeddings(args.embeddings, args.encoder_embeddings, + args.decoder_embeddings, + args.max_generative_vocab) numericalizer.load(args.path) + for emb in set(encoder_embeddings + decoder_embeddings): + emb.init_for_vocab(numericalizer.vocab) logger.info(f'Initializing Model') Model = getattr(models, args.model) - model = Model(numericalizer, args, devices) + model = Model(numericalizer, args, encoder_embeddings, decoder_embeddings) model_dict = save_dict['model_state_dict'] model.load_state_dict(model_dict) - server = Server(args, numericalizer, model, devices[0]) + server = Server(args, numericalizer, encoder_embeddings + decoder_embeddings, model, devices[0]) server.run() diff --git a/decanlp/train.py b/decanlp/train.py index 82fbec4f..c82f48c0 100644 --- a/decanlp/train.py +++ b/decanlp/train.py @@ -47,9 +47,9 @@ from . import arguments from . 
import models from .validate import validate from .util import elapsed_time, set_seed, preprocess_examples, get_trainable_params, make_data_loader, log_model_size, \ - init_devices, make_numericalizer + init_devices from .utils.saver import Saver -from .utils.embeddings import load_embeddings +from .data.embeddings import load_embeddings def initialize_logger(args): @@ -111,16 +111,21 @@ def prepare_data(args, logger): if args.vocab_tasks is not None and task.name in args.vocab_tasks: vocab_sets.extend(split) - numericalizer = make_numericalizer(args) + numericalizer, encoder_embeddings, decoder_embeddings = load_embeddings(args.embeddings, args.encoder_embeddings, + args.decoder_embeddings, + args.max_generative_vocab, + logger) if args.load is not None: numericalizer.load(args.save) else: - vectors = load_embeddings(args, logger) vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets logger.info(f'Building vocabulary') - numericalizer.build_vocab(vectors, Example.vocab_fields, vocab_sets) + numericalizer.build_vocab(Example.vocab_fields, vocab_sets) numericalizer.save(args.save) + for vec in set(encoder_embeddings + decoder_embeddings): + vec.init_for_vocab(numericalizer.vocab) + logger.info(f'Vocabulary has {numericalizer.num_tokens} tokens') logger.debug(f'The first 200 tokens:') logger.debug(numericalizer.vocab.itos[:200]) @@ -133,7 +138,7 @@ def prepare_data(args, logger): logger.info('Preprocessing validation data') preprocess_examples(args, args.val_tasks, val_sets, logger, train=args.val_filter) - return numericalizer, train_sets, val_sets, aux_sets + return numericalizer, encoder_embeddings, decoder_embeddings, train_sets, val_sets, aux_sets def get_learning_rate(i, args): @@ -392,11 +397,11 @@ def train(args, devices, model, opt, train_sets, train_iterations, numericalizer break -def init_model(args, numericalizer, devices, logger): +def init_model(args, numericalizer, encoder_embeddings, decoder_embeddings, devices, logger): model_name = args.model logger.info(f'Initializing {model_name}') Model = getattr(models, model_name) - model = Model(numericalizer, args, devices) + model = Model(numericalizer, args, encoder_embeddings, decoder_embeddings) params = get_trainable_params(model) log_model_size(logger, model, model_name) @@ -432,7 +437,7 @@ def main(argv=sys.argv): if args.load is not None: logger.info(f'Loading vocab from {os.path.join(args.save, args.load)}') save_dict = torch.load(os.path.join(args.save, args.load)) - numericalizer, train_sets, val_sets, aux_sets = prepare_data(args, logger) + numericalizer, encoder_embeddings, decoder_embeddings, train_sets, val_sets, aux_sets = prepare_data(args, logger) if (args.use_curriculum and aux_sets is None) or (not args.use_curriculum and len(aux_sets)): logging.error('sth unpleasant is happening with curriculum') @@ -445,7 +450,7 @@ def main(argv=sys.argv): else: writer = None - model = init_model(args, numericalizer, devices, logger) + model = init_model(args, numericalizer, encoder_embeddings, decoder_embeddings, devices, logger) opt = init_opt(args, model) start_iteration = 1 diff --git a/decanlp/util.py b/decanlp/util.py index 8bd67b58..c521c835 100644 --- a/decanlp/util.py +++ b/decanlp/util.py @@ -187,34 +187,23 @@ def load_config_json(args): args.almond_type_embeddings = False with open(os.path.join(args.path, 'config.json')) as config_file: config = json.load(config_file) - retrieve = ['model', 'transformer_layers', 'rnn_layers', 'transformer_hidden', 'dimension', - 'load', 
'max_val_context_length', 'val_batch_size', 'transformer_heads', 'max_output_length', - 'max_effective_vocab', 'max_generative_vocab', 'lower', 'glove_and_char', - 'small_glove', 'almond_type_embeddings', 'almond_grammar', - 'trainable_decoder_embedding', 'glove_decoder', 'pretrained_decoder_lm', - 'retrain_encoder_embedding', 'question', 'locale', 'use_google_translate'] + retrieve = ['model', 'seq2seq_encoder', 'seq2seq_decoder', 'transformer_layers', 'rnn_layers', + 'transformer_hidden', 'dimension', 'load', 'max_val_context_length', 'val_batch_size', + 'transformer_heads', 'max_output_length', 'max_generative_vocab', 'lower', 'encoder_embeddings', + 'decoder_embeddings', 'trainable_decoder_embeddings', 'train_encoder_embeddings', + 'question', 'locale', 'use_google_translate'] for r in retrieve: if r in config: setattr(args, r, config[r]) elif r == 'locale': setattr(args, r, 'en') - elif r in ('small_glove', 'almond_type_embbedings'): - setattr(args, r, False) - elif r in ('glove_decoder', 'glove_and_char'): - setattr(args, r, True) - elif r == 'trainable_decoder_embedding': + elif r == 'trainable_decoder_embeddings': setattr(args, r, 0) - elif r == 'retrain_encoder_embedding': + elif r == 'train_encoder_embeddings': setattr(args, r, False) else: setattr(args, r, None) args.dropout_ratio = 0.0 - args.best_checkpoint = os.path.join(args.path, args.checkpoint_name) - - -def make_numericalizer(args): - return SimpleNumericalizer(max_effective_vocab=args.max_effective_vocab, - max_generative_vocab=args.max_generative_vocab, - pad_first=False) \ No newline at end of file + args.best_checkpoint = os.path.join(args.path, args.checkpoint_name) \ No newline at end of file diff --git a/tests/test.sh b/tests/test.sh index b6dc6dc3..5f7d9368 100755 --- a/tests/test.sh +++ b/tests/test.sh @@ -23,10 +23,12 @@ workdir=`mktemp -d $TMPDIR/decaNLP-tests-XXXXXX` trap on_error ERR INT TERM i=0 -for hparams in "" "--use_curriculum"; do +for hparams in "--encoder_embeddings=small_glove+char --decoder_embeddings=small_glove+char" \ + "--encoder_embeddings=bert-base-uncased --decoder_embeddings= --trainable_decoder_embeddings=50" \ + "--encoder_embeddings=bert-base-uncased --decoder_embeddings= --trainable_decoder_embeddings=50 --seq2seq_encoder=Identity --dimension=768 --rnn_layers=0" ; do # train - pipenv run python3 -m decanlp train --train_tasks almond --train_iterations 4 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $SRCDIR/embeddings --small_glove --no_commit + pipenv run python3 -m decanlp train --train_tasks almond --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --root "" --embeddings $SRCDIR/embeddings --no_commit # greedy decode pipenv run python3 -m decanlp predict --tasks almond --evaluate test --path $workdir/model_$i --overwrite --eval_dir $workdir/model_$i/eval_results/ --data $SRCDIR/dataset/ --embeddings $SRCDIR/embeddings
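
For reference, a minimal sketch of how the new embedding flags compose into a standalone training command, mirroring the third test.sh configuration above; the ./dataset, ./model and ./embeddings paths are illustrative placeholders, not part of this patch. Note that IdentityEncoder requires --dimension to equal the summed embedding size (768 for bert-base-uncased) and --rnn_layers=0, per the ValueError checks in identity_encoder.py:

# hypothetical invocation; paths are placeholders
pipenv run python3 -m decanlp train --train_tasks almond \
    --encoder_embeddings=bert-base-uncased --train_encoder_embeddings \
    --decoder_embeddings= --trainable_decoder_embeddings=50 \
    --seq2seq_encoder=Identity --dimension=768 --rnn_layers=0 \
    --data ./dataset --save ./model --embeddings ./embeddings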