genienlp/decanlp/models/common.py

617 lines
24 KiB
Python

#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import math
import os
import sys
import numpy as np
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack
INF = 1e10
EPSILON = 1e-10
class LSTMDecoder(nn.Module):
    """A stack of ``nn.LSTMCell`` layers advanced by one time step per call.

    Layer 0 consumes ``input_size`` features; every later layer consumes the
    previous layer's ``rnn_size`` hidden state. Dropout is applied to the input
    of every layer.
    """

    def __init__(self, num_layers, input_size, rnn_size, dropout):
        super(LSTMDecoder, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        in_sizes = [input_size if depth == 0 else rnn_size for depth in range(num_layers)]
        for in_size in in_sizes:
            self.layers.append(nn.LSTMCell(in_size, rnn_size))

    def forward(self, input, hidden):
        """Run one step through every layer.

        Args:
            input: input to the first cell.
            hidden: ``(h_0, c_0)``; ``h_0[i]`` / ``c_0[i]`` are the previous
                states of layer ``i``.

        Returns:
            ``(top_h, (h_1, c_1))`` where ``top_h`` is the last layer's new hidden
            state and ``h_1`` / ``c_1`` stack the new states of all layers along
            dim 0.
        """
        prev_h, prev_c = hidden
        next_h, next_c = [], []
        layer_input = input
        for depth, cell in enumerate(self.layers):
            h_new, c_new = cell(self.dropout(layer_input), (prev_h[depth], prev_c[depth]))
            next_h.append(h_new)
            next_c.append(c_new)
            layer_input = h_new
        return layer_input, (torch.stack(next_h), torch.stack(next_c))
def max_margin_loss(probs, targets, pad_idx=1):
    """Summed structured-hinge loss over all non-padding positions.

    For every position the loss is ``max_k(probs[k] + 1 - [k == gold]) - probs[gold]``,
    i.e. how far the best wrong class (with a margin of 1) exceeds the gold score.

    Args:
        probs: scores of shape ``(batch, max_length, depth)`` (must be viewable,
            i.e. contiguous, as it is reshaped with ``.view``).
        targets: integer gold indices of shape ``(batch, max_length)``.
        pad_idx: target value marking padding positions, which contribute 0.

    Returns:
        A scalar tensor: the summed margin plus ``1e-8``.
    """
    batch_size, max_length, depth = probs.size()
    flat_size = batch_size * max_length
    flat_mask = (targets != pad_idx).float().view(flat_size)
    flat_preds = probs.view(flat_size, depth)
    # scatter_ is in place, so this is the one-hot encoding of the gold indices.
    one_hot_gold = torch.zeros_like(probs).scatter_(2, targets.unsqueeze(2), 1)
    marginal_scores = (probs - one_hot_gold + 1).view(flat_size, depth)
    max_margin = torch.max(marginal_scores, dim=1)[0]
    # masked_select requires a bool mask; uint8 (.byte()) masks are rejected by
    # current PyTorch. Exactly one entry per row is selected (the gold score).
    gold_mask = one_hot_gold.view(flat_size, depth).bool()
    gold_score = torch.masked_select(flat_preds, gold_mask)
    margin = max_margin - gold_score
    return torch.sum(margin * flat_mask) + 1e-8
def positional_encodings_like(x, t=None):
    """Build a sinusoidal positional-encoding table matching ``x``'s trailing shape.

    Even channel ``c`` holds ``sin(pos / 10000 ** (c / d))`` and odd channel ``c``
    holds ``cos(pos / 10000 ** ((c - 1) / d))``, where ``d = x.size(2)``.

    Args:
        x: tensor whose dims ``1:`` give the table shape; its dim 1 gives the
            default number of positions. Indexing ``x.size(2)`` means x is
            expected to be 3-D.
        t: optional tensor of positions used instead of ``0 .. x.size(1) - 1``.

    Returns:
        A ``Variable``-wrapped table of shape ``x.size()[1:]``, placed on x's
        CUDA device when x is on CUDA.
    """
    if t is not None:
        positions = t
    else:
        positions = torch.arange(0., x.size(1))
        if x.is_cuda:
            positions = positions.cuda(x.get_device())
    table = torch.zeros(*x.size()[1:])
    if x.is_cuda:
        table = table.cuda(x.get_device())
    for ch in range(x.size(-1)):
        # Channels come in (sin, cos) pairs sharing the even channel's frequency.
        exponent = (ch - ch % 2) / x.size(2)
        wave = torch.sin if ch % 2 == 0 else torch.cos
        table[:, ch] = wave(positions / 10000 ** exponent)
    return Variable(table)
# torch.matmul can't do (4, 3, 2) @ (4, 2) -> (4, 3)
def matmul(x, y):
    """Batched matrix product that also accepts a batch of vectors on either side.

    * same rank: plain ``x @ y``.
    * ``x`` one rank lower (batch of row vectors), e.g. ``(4, 2) @ (4, 2, 3) -> (4, 3)``:
      ``x`` is treated as ``(..., 1, k)``.
    * otherwise (``y`` is the batch of column vectors), e.g.
      ``(4, 3, 2) @ (4, 2) -> (4, 3)``: ``y`` is treated as ``(..., k, 1)``.
    """
    if x.dim() == y.dim():
        return x @ y
    if x.dim() == y.dim() - 1:
        return (x.unsqueeze(-2) @ y).squeeze(-2)
    # The vector must become a column (k x 1); unsqueezing at -2 made it a row
    # (1 x k), which mismatched x's inner dimension for the advertised case.
    return (x @ y.unsqueeze(-1)).squeeze(-1)
def pad_to_match(x, y):
    """Right-pad the shorter of two ``(batch, length)`` tensors along dim 1 with ones.

    The filler is created from ``x`` (``x.new_ones``) and has ``x.size(0)`` rows.
    If the lengths already match, both inputs are returned unchanged.
    """
    len_x, len_y = x.size(1), y.size(1)
    if len_x == len_y:
        return x, y
    filler = x.new_ones((x.size(0), abs(len_y - len_x)))
    if len_x > len_y:
        return x, torch.cat((y, filler), 1)
    return torch.cat((x, filler), 1), y
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable gain and bias.

    Computes ``gamma * (x - mean) / (std + eps) + beta``. Note that ``std`` is
    ``Tensor.std`` (its default, unbiased estimator) and ``eps`` is added to the
    standard deviation rather than the variance.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        denom = x.std(-1, keepdim=True) + self.eps
        return self.gamma * centered / denom + self.beta
class ResidualBlock(nn.Module):
    """Post-norm residual wrapper: ``LayerNorm(x[0] + Dropout(layer(*x, padding=padding)))``.

    All positional arguments are forwarded to the wrapped layer; the first one
    is also the residual branch.
    """

    def __init__(self, layer, d_model, dropout_ratio):
        super().__init__()
        self.layer = layer
        self.dropout = nn.Dropout(dropout_ratio)
        self.layernorm = LayerNorm(d_model)

    def forward(self, *x, padding=None):
        residual = x[0]
        sublayer_out = self.layer(*x, padding=padding)
        return self.layernorm(residual + self.dropout(sublayer_out))
class Attention(nn.Module):
    """Scaled dot-product attention with optional causal and padding masks.

    Scores are ``query . key / sqrt(d_key)``. When ``causal`` is set and the query
    is 3-D, keys after each query position are suppressed by subtracting ``INF``.
    Positions where ``padding`` is true are filled with ``-INF`` before softmax.
    """

    def __init__(self, d_key, dropout_ratio, causal):
        super().__init__()
        self.scale = math.sqrt(d_key)
        self.dropout = nn.Dropout(dropout_ratio)
        self.causal = causal

    def forward(self, query, key, value, padding=None):
        """Attend from ``query`` over ``key``/``value``.

        ``key`` is transposed on dims 1 and 2, so it must be 3-D. ``query`` may be
        3-D or one rank lower (handled by ``matmul``). ``padding`` is
        unsqueezed at dim 1 and expanded to the score shape; presumably a
        ``(batch, key_len)`` bool mask — TODO confirm against callers.
        """
        scores = matmul(query, key.transpose(1, 2))
        if self.causal and query.dim() == 3:
            n_keys = key.size(1)
            # Strictly-upper-triangular penalty: position i may not see j > i.
            future = key.new_ones((n_keys, n_keys)).triu(1) * INF
            scores.sub_(future.unsqueeze(0))
        if padding is not None:
            scores.masked_fill_(padding.unsqueeze(1).expand_as(scores), -INF)
        weights = F.softmax(scores / self.scale, dim=-1)
        return matmul(self.dropout(weights), value)
class MultiHead(nn.Module):
    """Multi-head attention.

    Queries, keys and values are projected with bias-free ``Linear`` layers,
    split into ``n_heads`` chunks along the last dimension, attended per head
    with a shared ``Attention`` module, and concatenated back. There is no
    output projection.
    """

    def __init__(self, d_key, d_value, n_heads, dropout_ratio, causal=False):
        super().__init__()
        self.attention = Attention(d_key, dropout_ratio, causal=causal)
        self.wq = Linear(d_key, d_key, bias=False)
        self.wk = Linear(d_key, d_key, bias=False)
        self.wv = Linear(d_value, d_value, bias=False)
        self.n_heads = n_heads

    def forward(self, query, key, value, padding=None):
        q_heads = self.wq(query).chunk(self.n_heads, -1)
        k_heads = self.wk(key).chunk(self.n_heads, -1)
        v_heads = self.wv(value).chunk(self.n_heads, -1)
        head_outputs = []
        for q, k, v in zip(q_heads, k_heads, v_heads):
            head_outputs.append(self.attention(q, k, v, padding=padding))
        return torch.cat(head_outputs, -1)
class LinearReLU(nn.Module):
    """Position-wise feed-forward sublayer: ``Feedforward(relu)`` then ``Linear`` back to ``d_model``."""

    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.feedforward = Feedforward(d_model, d_hidden, activation='relu')
        self.linear = Linear(d_hidden, d_model)

    def forward(self, x, padding=None):
        # `padding` is accepted so ResidualBlock can pass it uniformly; it is unused here.
        hidden = self.feedforward(x)
        return self.linear(hidden)
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: residual self-attention followed by a residual feed-forward sublayer."""

    def __init__(self, dimension, n_heads, hidden, dropout):
        super().__init__()
        self_attention = MultiHead(dimension, dimension, n_heads, dropout)
        self.selfattn = ResidualBlock(self_attention, dimension, dropout)
        self.feedforward = ResidualBlock(LinearReLU(dimension, hidden), dimension, dropout)

    def forward(self, x, padding=None):
        attended = self.selfattn(x, x, x, padding=padding)
        return self.feedforward(attended)
class TransformerEncoder(nn.Module):
    """Stack of ``TransformerEncoderLayer``s that returns every intermediate state.

    ``forward`` returns a list whose first entry is the dropped-out input and
    whose ``i + 1``-th entry is the output of layer ``i``.
    """

    def __init__(self, dimension, n_heads, hidden, num_layers, dropout):
        super().__init__()
        self.layers = nn.ModuleList(
            TransformerEncoderLayer(dimension, n_heads, hidden, dropout) for _ in range(num_layers))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, padding=None):
        state = self.dropout(x)
        states = [state]
        for layer in self.layers:
            state = layer(state, padding=padding)
            states.append(state)
        return states
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: (causal) self-attention, encoder attention, then feed-forward.

    Each sublayer is wrapped in a ``ResidualBlock``. ``answer_padding`` masks the
    self-attention keys and ``context_padding`` masks the encoder keys.
    """

    def __init__(self, dimension, n_heads, hidden, dropout, causal=True):
        super().__init__()
        self.selfattn = ResidualBlock(
            MultiHead(dimension, dimension, n_heads, dropout, causal=causal),
            dimension, dropout)
        self.attention = ResidualBlock(
            MultiHead(dimension, dimension, n_heads, dropout),
            dimension, dropout)
        self.feedforward = ResidualBlock(LinearReLU(dimension, hidden), dimension, dropout)

    def forward(self, x, encoding, context_padding=None, answer_padding=None):
        state = self.selfattn(x, x, x, padding=answer_padding)
        state = self.attention(state, encoding, encoding, padding=context_padding)
        return self.feedforward(state)
class TransformerDecoder(nn.Module):
    """Stack of ``TransformerDecoderLayer``s.

    Layer ``i`` attends to ``encoding[i + 1]``; ``encoding[0]`` is skipped. This
    matches the list shape produced by ``TransformerEncoder`` (input first, then
    one entry per layer). ``zip`` stops at the shorter of the two sequences.
    """

    def __init__(self, dimension, n_heads, hidden, num_layers, dropout, causal=True):
        super().__init__()
        self.layers = nn.ModuleList(
            TransformerDecoderLayer(dimension, n_heads, hidden, dropout, causal=causal)
            for _ in range(num_layers))
        self.dropout = nn.Dropout(dropout)
        self.d_model = dimension

    def forward(self, x, encoding, context_padding=None, positional_encodings=True, answer_padding=None):
        if positional_encodings:
            x = x + positional_encodings_like(x)
        state = self.dropout(x)
        for layer, layer_encoding in zip(self.layers, encoding[1:]):
            state = layer(state, layer_encoding,
                          context_padding=context_padding, answer_padding=answer_padding)
        return state
def mask(targets, out, squash=True, pad_idx=1):
    """Drop (or zero out) the positions of ``out`` whose target equals ``pad_idx``.

    Args:
        targets: index tensor; positions equal to ``pad_idx`` are padding.
        out: tensor with the shape of ``targets`` plus one trailing feature dim.
        squash: if true, return only the kept rows flattened to
            ``(n_kept, out.size(-1))``; otherwise return ``out`` with padded rows
            multiplied by zero.
        pad_idx: padding value in ``targets``.

    Returns:
        ``(out_after, targets_after)`` where ``targets_after`` is the 1-D tensor of
        non-padding targets.
    """
    keep = targets != pad_idx
    keep_out = keep.unsqueeze(-1).expand_as(out).contiguous()
    targets_after = targets[keep]
    if not squash:
        return out * keep_out.float(), targets_after
    return out[keep_out].contiguous().view(-1, out.size(-1)), targets_after
class Highway(torch.nn.Module):
    """Highway network: ``y = g * x + (1 - g) * act(H(x))`` with ``g = sigmoid(G(x))``.

    Each layer is a single ``Linear(d_in, 2 * d_in)`` whose first half is the
    transform ``H`` and whose second half is the gate ``G``.

    Args:
        d_in: feature size (last dimension) of the input and output.
        activation: name of a ``torch.nn.functional`` function applied to ``H(x)``.
        n_layers: number of stacked highway layers.
    """

    def __init__(self, d_in, activation='relu', n_layers=1):
        super(Highway, self).__init__()
        self.d_in = d_in
        self._layers = torch.nn.ModuleList([Linear(d_in, 2 * d_in) for _ in range(n_layers)])
        # Initialise the gate bias to 1 so the gate starts biased towards carrying
        # the input through. The bias is a leaf parameter that requires grad, so an
        # in-place fill on its slice must happen outside autograd or it raises.
        with torch.no_grad():
            for layer in self._layers:
                layer.bias[d_in:].fill_(1)
        self.activation = getattr(F, activation)

    def forward(self, inputs):
        current_input = inputs
        for layer in self._layers:
            projected_input = layer(current_input)
            linear_part = current_input
            # `...` slicing selects along the feature (last) axis for any input rank.
            nonlinear_part = self.activation(projected_input[..., :self.d_in])
            gate = torch.sigmoid(projected_input[..., self.d_in:(2 * self.d_in)])
            current_input = gate * linear_part + (1 - gate) * nonlinear_part
        return current_input
class LinearFeedforward(nn.Module):
    """Two-layer MLP: ``Dropout(0.2)(Linear(Feedforward(x)))``.

    ``Feedforward`` maps ``d_in -> d_hid`` with ``activation`` (and its own input
    dropout); the final ``Linear`` maps ``d_hid -> d_out``. The output dropout
    rate is fixed at 0.2.
    """

    def __init__(self, d_in, d_hid, d_out, activation='relu'):
        super().__init__()
        self.feedforward = Feedforward(d_in, d_hid, activation=activation)
        self.linear = Linear(d_hid, d_out)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        hidden = self.feedforward(x)
        projected = self.linear(hidden)
        return self.dropout(projected)
class PackedLSTM(nn.Module):
    """LSTM over variable-length sequences given in arbitrary (unsorted) order.

    Sequences are sorted by length, packed, run through ``nn.LSTM``, unpacked and
    restored to the caller's order. With ``bidirectional=True`` each direction
    gets ``d_out // 2`` units, so the concatenated output has about ``d_out``
    features.
    """

    def __init__(self, d_in, d_out, bidirectional=False, num_layers=1,
                 dropout=0.0, batch_first=True):
        """A wrapper class that packs input sequences and unpacks output sequences"""
        super().__init__()
        if bidirectional:
            d_out = d_out // 2
        self.rnn = nn.LSTM(d_in, d_out,
                           num_layers=num_layers,
                           dropout=dropout,
                           bidirectional=bidirectional,
                           batch_first=batch_first)
        self.batch_first = batch_first

    def forward(self, inputs, lengths, hidden=None):
        """Run the LSTM on ``inputs`` whose valid lengths are ``lengths``.

        Args:
            inputs: padded batch, batch-major when ``batch_first`` is set.
            lengths: per-sequence lengths (list or tensor), in ``inputs`` order.
            hidden: optional ``(h_0, c_0)`` in ``inputs`` order.

        Returns:
            ``(outputs, (h, c))``, all in the caller's original batch order.
        """
        lens = torch.as_tensor(lengths, dtype=torch.long, device=inputs.device)
        lens, indices = torch.sort(lens, 0, True)
        inputs = inputs[indices] if self.batch_first else inputs[:, indices]
        if hidden is not None:
            # The initial states must follow their sequences into sorted order.
            # nn.LSTM keeps the batch on dim 1 of h/c even when batch_first is set.
            hidden = tuple(state[:, indices] for state in hidden)
        packed = pack(inputs, lens.tolist(), batch_first=self.batch_first)
        outputs, (h, c) = self.rnn(packed, hidden)
        outputs = unpack(outputs, batch_first=self.batch_first)[0]
        # argsort of the permutation is its inverse: restores the caller's order.
        _, _indices = torch.sort(indices, 0)
        outputs = outputs[_indices] if self.batch_first else outputs[:, _indices]
        h, c = h[:, _indices, :], c[:, _indices, :]
        return outputs, (h, c)
class Linear(nn.Linear):
    """``nn.Linear`` applied to the last dimension of an input of any rank.

    The input is made contiguous, flattened to 2-D, projected, and reshaped back
    to its original leading dimensions.
    """

    def forward(self, x):
        *leading, features = x.size()
        flat = x.contiguous().view(-1, features)
        return super().forward(flat).view(*leading, -1)
class Feedforward(nn.Module):
    """``activation(Linear(Dropout(x)))``: dropout is applied to the input.

    Args:
        activation: name of a function on the ``torch`` module (e.g. ``'relu'``,
            ``'tanh'``); ``None`` means identity.
        bias: whether the linear layer has a bias.
        dropout: input dropout rate.
    """

    def __init__(self, d_in, d_out, activation=None, bias=True, dropout=0.2):
        super().__init__()
        self.activation = (lambda x: x) if activation is None else getattr(torch, activation)
        self.linear = Linear(d_in, d_out, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        dropped = self.dropout(x)
        return self.activation(self.linear(dropped))
class Embedding(nn.Module):
    """Token embedding built from pretrained vectors and/or a trainable table.

    Args:
        field: object exposing ``field.vocab`` (supports ``len()``). When
            ``include_pretrained`` is true, ``field.vocab.vectors`` must be a
            tensor whose ``size(-1)`` is the pretrained embedding size; it is
            installed as the lookup weight.
        output_dimension: output size of the projection; also stored as
            ``self.dimension``. NOTE(review): ``self.dimension`` is set to this
            even when ``project=False``, where the real output size is the sum
            of the embedding sizes.
        include_pretrained: look tokens up in ``field.vocab.vectors``.
        trained_dimension: size of an extra trainable table (disabled when 0).
        dropout: stored as ``self.dropout``. NOTE(review): it is never applied
            in ``forward``.
        project: pass the concatenated embeddings through ``Feedforward``.
        requires_grad: ``requires_grad`` flag for the pretrained weight. Note that
            ``forward`` detaches the pretrained lookup.
    """

    def __init__(self, field, output_dimension, include_pretrained=True, trained_dimension=0, dropout=0.0, project=True, requires_grad=False):
        super().__init__()
        self.field = field
        self.project = project
        self.requires_grad = requires_grad
        dimension = 0
        if include_pretrained:
            # Only touch vocab.vectors when pretrained embeddings are requested:
            # a vocab built without vectors must still work with
            # include_pretrained=False.
            pretrained_dimension = field.vocab.vectors.size(-1)
            # NOTE: this must be a list so that pytorch will not iterate into the module when
            # traversing this module
            # in turn, this means that moving this Embedding() to the GPU will not move the
            # actual embedding, which will stay on CPU; this is necessary because a) we call
            # set_embeddings() sometimes with CPU-only tensors, and b) the embedding tensor
            # is too big for the GPU anyway
            self.pretrained_embeddings = [nn.Embedding(len(field.vocab), pretrained_dimension)]
            self.pretrained_embeddings[0].weight.data = field.vocab.vectors
            self.pretrained_embeddings[0].weight.requires_grad = self.requires_grad
            dimension += pretrained_dimension
        else:
            self.pretrained_embeddings = None
        # OTOH, if we have a trained embedding, we move it around together with the module
        # (ie, potentially on GPU), because the saving when applying gradient outweighs
        # the cost, and hopefully the embedding is small enough to fit in GPU memory
        if trained_dimension > 0:
            self.trained_embeddings = nn.Embedding(len(field.vocab), trained_dimension)
            dimension += trained_dimension
        else:
            self.trained_embeddings = None
        if self.project:
            self.projection = Feedforward(dimension, output_dimension)
        self.dropout = nn.Dropout(dropout)
        self.dimension = output_dimension

    def forward(self, x, lengths=None, device=-1):
        """Embed token ids ``x``; ``lengths`` and ``device`` are unused.

        Pretrained and trained embeddings are concatenated along dim 2
        (presumably ``x`` is ``(batch, seq)`` — TODO confirm), then optionally
        projected.
        """
        if self.pretrained_embeddings is not None:
            # Lookup happens on CPU (see __init__ NOTE) and is detached, so no
            # gradient flows into the pretrained table from here.
            pretrained_embeddings = self.pretrained_embeddings[0](x.cpu()).to(x.device).detach()
        else:
            pretrained_embeddings = None
        if self.trained_embeddings is not None:
            # Ids beyond the trained table (e.g. after set_embeddings grew the
            # vocabulary) are mapped to row 0 instead of indexing out of range.
            trained_vocabulary_size = self.trained_embeddings.weight.size()[0]
            valid_x = torch.lt(x, trained_vocabulary_size)
            masked_x = torch.where(valid_x, x, torch.zeros_like(x))
            trained_embeddings = self.trained_embeddings(masked_x)
        else:
            trained_embeddings = None
        if pretrained_embeddings is not None and trained_embeddings is not None:
            embeddings = torch.cat((pretrained_embeddings, trained_embeddings), dim=2)
        elif pretrained_embeddings is not None:
            embeddings = pretrained_embeddings
        else:
            embeddings = trained_embeddings
        return self.projection(embeddings) if self.project else embeddings

    def set_embeddings(self, w):
        """Replace the pretrained lookup weight with ``w`` (no-op without pretrained embeddings)."""
        if self.pretrained_embeddings is not None:
            self.pretrained_embeddings[0].weight.data = w
            self.pretrained_embeddings[0].weight.requires_grad = self.requires_grad
class SemanticFusionUnit(nn.Module):
    """Gated fusion of ``l`` feature tensors of width ``d`` into one of width ``d``.

    The inputs are concatenated on the last axis (then dropout). A tanh
    candidate and a sigmoid gate are computed from the result, and the output is
    ``gate * candidate + (1 - gate) * x[0]``.
    """

    def __init__(self, d, l):
        super().__init__()
        self.r_hat = Feedforward(d*l, d, 'tanh')
        self.g = Feedforward(d*l, d, 'sigmoid')
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        fused = self.dropout(torch.cat(x, -1))
        candidate = self.r_hat(fused)
        gate = self.g(fused)
        return gate * candidate + (1 - gate) * x[0]
class LSTMDecoderAttention(nn.Module):
    """Attention of one decoder state over a context sequence.

    Scores are ``context . (W_in input)``, or ``context . input`` when ``dot`` is
    set. The output is ``tanh(W_out [input; alignment])``. ``applyMasks`` must be
    called before ``forward`` to set ``context_mask`` (true entries are excluded).
    """

    def __init__(self, dim, dot=False):
        super().__init__()
        self.linear_in = nn.Linear(dim, dim, bias=False)
        self.linear_out = nn.Linear(2 * dim, dim, bias=False)
        self.tanh = nn.Tanh()
        self.mask = None
        self.dot = dot

    def applyMasks(self, context_mask):
        self.context_mask = context_mask

    def forward(self, input, context):
        """Attend with ``input`` (batch x dim) over ``context`` (3-D, used with bmm).

        Returns ``(output, context_attention, context_alignment)``.
        ``context_attention`` is the softmax plus ``EPSILON``.
        """
        query = input if self.dot else self.linear_in(input)
        scores = torch.bmm(context, query.unsqueeze(2)).squeeze(2)
        scores.masked_fill_(self.context_mask, -float('inf'))
        attention = F.softmax(scores, dim=-1) + EPSILON
        alignment = torch.bmm(attention.unsqueeze(1), context).squeeze(1)
        combined = torch.cat([input, alignment], 1)
        output = self.tanh(self.linear_out(combined))
        return output, attention, alignment
class CoattentiveLayer(nn.Module):
    """Coattention between a context and a question sequence.

    A learned sentinel vector (embedding row 0 for the context, row 1 for the
    question) is prepended to each sequence so attention can fall on "nothing";
    the sentinel position is removed from both outputs.
    """

    def __init__(self, d, dropout=0.2):
        super().__init__()
        self.proj = Feedforward(d, d, dropout=0.0)
        self.embed_sentinel = nn.Embedding(2, d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, context, question, context_padding, question_padding):
        """Coattend ``context`` (B x Lc x D) and ``question`` (B x Lq x D).

        ``*_padding`` are masks whose true entries are excluded (filled with
        ``-INF`` before the softmax).

        Returns:
            ``(context_out, question_out)`` of shapes B x Lc x 2D and B x Lq x 2D.
        """
        # The prepended sentinel position is never padding.
        context_padding = torch.cat([context.new_zeros((context.size(0), 1), dtype=torch.bool), context_padding], 1)
        question_padding = torch.cat([question.new_zeros((question.size(0), 1), dtype=torch.bool), question_padding], 1)
        context_sentinel = self.embed_sentinel(context.new_zeros((context.size(0), 1), dtype=torch.long))
        context = torch.cat([context_sentinel, self.dropout(context)], 1)  # B x (Lc + 1) x D
        question_sentinel = self.embed_sentinel(question.new_ones((question.size(0), 1), dtype=torch.long))
        question = torch.cat([question_sentinel, question], 1)  # B x (Lq + 1) x D
        question = torch.tanh(self.proj(question))  # B x (Lq + 1) x D
        affinity = context.bmm(question.transpose(1, 2))  # B x (Lc + 1) x (Lq + 1)
        attn_over_context = self.normalize(affinity, context_padding)  # B x (Lc + 1) x (Lq + 1), softmax over context
        attn_over_question = self.normalize(affinity.transpose(1, 2), question_padding)  # B x (Lq + 1) x (Lc + 1), softmax over question
        sum_of_context = self.attn(attn_over_context, context)  # B x (Lq + 1) x D
        sum_of_question = self.attn(attn_over_question, question)  # B x (Lc + 1) x D
        coattn_context = self.attn(attn_over_question, sum_of_context)  # B x (Lc + 1) x D
        coattn_question = self.attn(attn_over_context, sum_of_question)  # B x (Lq + 1) x D
        return (torch.cat([coattn_context, sum_of_question], 2)[:, 1:],
                torch.cat([coattn_question, sum_of_context], 2)[:, 1:])

    @staticmethod
    def attn(weights, candidates):
        """Weighted sums over dim 1: ``out[b, j] = sum_i weights[b, i, j] * candidates[b, i]``.

        Equivalent to the previous expand/mul/sum formulation. It avoids
        materialising a B x L x W x D intermediate, and it no longer squeezes
        away the output's dim 1 when ``weights.size(2) == 1``.
        """
        return weights.transpose(1, 2) @ candidates

    @staticmethod
    def normalize(original, padding):
        """Softmax over dim 1 after filling rows where ``padding`` is true with ``-INF``."""
        raw_scores = original.clone()
        raw_scores.masked_fill_(padding.unsqueeze(-1).expand_as(raw_scores), -INF)
        return F.softmax(raw_scores, dim=1)
# The following code was copied and adapted from github.com/floydhub/word-language-model
#
# BSD 3-Clause License
#
# Copyright (c) 2017,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
class PretrainedDecoderLM(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    ``encoder`` is a token embedding, ``rnn`` is an ``nn.LSTM``/``nn.GRU``/``nn.RNN``
    (built without ``batch_first``, so inputs are presumably ``(seq_len, batch)``),
    and ``decoder`` projects hidden states to ``ntoken`` logits.
    """

    def __init__(self, rnn_type, ntoken, emsize, nhid, nlayers, dropout=0.5, tie_weights=False):
        super(PretrainedDecoderLM, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, emsize)  # token ids -> embeddings
        if rnn_type in ('LSTM', 'GRU'):
            rnn_class = getattr(nn, rnn_type)
            self.rnn = rnn_class(emsize, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type]
            except KeyError:
                raise ValueError("An invalid option for `--model` was supplied,\n"
                                 "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']")
            self.rnn = nn.RNN(emsize, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        # Optional weight tying between the input embedding and output projection:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016),
        # https://arxiv.org/abs/1608.05859, and "Tying Word Vectors and Word
        # Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016),
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != emsize:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight
        self.init_weights()
        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        """Uniform(-0.1, 0.1) init for embedding and decoder weights; zero decoder bias."""
        bound = 0.1
        self.encoder.weight.data.uniform_(-bound, bound)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-bound, bound)

    def encode(self, input, hidden=None):
        """Embed, run the RNN, and apply dropout; returns ``(output, hidden)``."""
        embedded = self.drop(self.encoder(input))
        rnn_out, hidden = self.rnn(embedded, hidden)
        return self.drop(rnn_out), hidden

    def forward(self, input, hidden=None):
        """Return ``(logits, hidden)`` with logits shaped like the RNN output but ``ntoken`` wide."""
        encoded, hidden = self.encode(input, hidden)
        steps, batch, width = encoded.size(0), encoded.size(1), encoded.size(2)
        logits = self.decoder(encoded.view(steps * batch, width))
        return logits.view(steps, batch, logits.size(1)), hidden

    def init_hidden(self, bsz):
        """Zero initial state on the parameters' dtype/device; an ``(h, c)`` pair for LSTM."""
        reference = next(self.parameters()).data
        shape = (self.nlayers, bsz, self.nhid)
        if self.rnn_type != 'LSTM':
            return reference.new_zeros(shape)
        return (reference.new_zeros(shape), reference.new_zeros(shape))