import torch.nn as nn
import numpy as np
from test_tube import HyperOptArgumentParser
import torch
from torch.autograd import Variable
from sklearn.metrics import confusion_matrix, f1_score
from torch.nn import functional as F
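
# NOTE: this template reads its hyperparameters from an `hparams` namespace
# (embedding_dim, nb_lstm_units, nb_lstm_layers, drop_prob, bidirectional,
# on_gpu, model_load_weights_path) and assumes the surrounding training
# framework attaches a `tng_dataloader` whose dataset exposes a
# `words_token_to_idx` vocabulary before the model is built.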


class BiLSTMPack(nn.Module):
    """
    Sample bidirectional LSTM tagger showing how to define a model template.
    """

    def __init__(self, hparams):
        # init superclass (nn.Module takes no constructor arguments)
        super(BiLSTMPack, self).__init__()
        self.hparams = hparams

        self.hidden = None

        # emit per-token outputs from forward(); set to False to return only
        # the final hidden state
        self.return_sequence = True

        # trigger tag building
        self.ner_tagset = {'O': 0, 'I-Bio': 1}
        self.nb_tags = len(self.ner_tagset)

        # build model
        print('building model...')
        if hparams.model_load_weights_path is None:
            self.__build_model()
            print('model built')
        else:
            # build the layers, then copy the pretrained weights into this instance
            self.__build_model()
            pretrained = BiLSTMPack.load(hparams.model_load_weights_path, hparams.on_gpu, hparams)
            self.load_state_dict(pretrained.state_dict())
            print('model loaded from: {}'.format(hparams.model_load_weights_path))

    def __build_model(self):
        """
        Lay out the model.
        """
        # size of the LSTM output fed to the tag projection
        self.output_dim = self.hparams.nb_lstm_units

        # when it's bidirectional our weights double
        if self.hparams.bidirectional:
            self.output_dim *= 2

        # total number of words in the vocabulary
        # (tng_dataloader is expected to be attached by the training framework)
        total_words = len(self.tng_dataloader.dataset.words_token_to_idx)

        # word embeddings (index 0 is reserved for padding)
        self.word_embedding = nn.Embedding(
            num_embeddings=total_words + 1,
            embedding_dim=self.hparams.embedding_dim,
            padding_idx=0
        )

        # design the LSTM
        self.lstm = nn.LSTM(
            self.hparams.embedding_dim,
            self.hparams.nb_lstm_units,
            num_layers=self.hparams.nb_lstm_layers,
            bidirectional=self.hparams.bidirectional,
            dropout=self.hparams.drop_prob,
            batch_first=True,
        )

        # map the LSTM output to tag space
        self.hidden_to_tag = nn.Linear(self.output_dim, self.nb_tags)

    def init_hidden(self, batch_size):
        # the hidden states are of shape (nb_lstm_layers * nb_directions, batch_size, nb_lstm_units)
        mult = 2 if self.hparams.bidirectional else 1
        hidden_a = torch.randn(self.hparams.nb_lstm_layers * mult, batch_size, self.hparams.nb_lstm_units)
        hidden_b = torch.randn(self.hparams.nb_lstm_layers * mult, batch_size, self.hparams.nb_lstm_units)

        if self.hparams.on_gpu:
            hidden_a = hidden_a.cuda()
            hidden_b = hidden_b.cuda()

        hidden_a = Variable(hidden_a)
        hidden_b = Variable(hidden_b)

        return (hidden_a, hidden_b)

    def forward(self, model_in):
        # layout data (expand it, etc...)
        # x = padded token-id sequences, shape (batch_size, seq_len)
        x, seq_lengths = model_in
        batch_size, seq_len = x.size()

        # reset RNN hidden state
        self.hidden = self.init_hidden(batch_size)

        # embed
        x = self.word_embedding(x)

        # run through the rnn using packed sequences
        # (pack_padded_sequence expects batches sorted by decreasing length)
        x = torch.nn.utils.rnn.pack_padded_sequence(x, seq_lengths, batch_first=True)
        x, self.hidden = self.lstm(x, self.hidden)
        x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        # if asked for only the last state, use h_n, which equals out(t=n)
        if not self.return_sequence:
            # pull out hidden states
            # h_n = (nb_lstm_layers * nb_directions, batch_size, nb_lstm_units)
            nb_directions = 2 if self.hparams.bidirectional else 1
            (h_n, _) = self.hidden

            # reshape to make indexing easier
            # forward = 0, backward = 1 (of nb_directions)
            h_n = h_n.view(self.hparams.nb_lstm_layers, nb_directions, batch_size, self.hparams.nb_lstm_units)

            # pull out the last forward hidden state
            forward_h_n = h_n[-1, 0, :, :]
            x = forward_h_n

            # if bidirectional, also pull out the last hidden state of the backward network
            if self.hparams.bidirectional:
                backward_h_n = h_n[-1, 1, :, :]
                x = torch.cat([forward_h_n, backward_h_n], dim=1)

        # project to tag space
        x = x.contiguous()
        x = x.view(-1, self.output_dim)
        x = self.hidden_to_tag(x)

        return x

    def loss(self, model_out):
        # masked cross entropy loss
        logits, y = model_out
        y, y_lens = y

        # flatten y and logits
        y = y.view(-1)
        logits = logits.view(-1, self.nb_tags)

        # convert raw scores to log-probabilities so the picked values are log-likelihoods
        logits = F.log_softmax(logits, dim=1)

        # mask out padding positions (padding labels are expected to be negative)
        mask = (y >= 0).float()

        # count how many real tokens we have
        num_tokens = int(torch.sum(mask).item())

        # pick the log-probability of the gold tag at each position and zero out padding
        logits = logits[range(logits.shape[0]), y] * mask

        # compute the cross entropy loss, averaged over real tokens
        ce_loss = -torch.sum(logits) / num_tokens

        return ce_loss

    def pull_out_last_embedding(self, x, seq_lengths, batch_size, on_gpu):
        # grab only the last activations from the non-padded output
        x_last = torch.zeros([batch_size, 1, x.size(-1)])
        for i, seq_len in enumerate(seq_lengths):
            x_last[i, :, :] = x[i, seq_len - 1, :]

        # put on gpu when requested
        if on_gpu:
            x_last = x_last.cuda()

        # turn into a torch variable
        x_last = Variable(x_last)

        return x_last
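

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the template's API).
# It fakes the pieces the training framework would normally supply: a tiny
# `hparams` namespace and a dummy `tng_dataloader` carrying a toy vocabulary.
# Every name below (_ToyDataset, _ToyLoader, the hparams values) is made up
# for the demo; it just runs forward() and loss() on a toy padded batch.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    hparams = Namespace(
        model_load_weights_path=None,
        on_gpu=False,
        bidirectional=True,
        nb_lstm_units=16,
        nb_lstm_layers=1,
        embedding_dim=8,
        drop_prob=0.0,
    )

    # stand-in for the framework-provided dataloader/vocabulary
    class _ToyDataset:
        words_token_to_idx = {'the': 1, 'cell': 2, 'divides': 3}

    class _ToyLoader:
        dataset = _ToyDataset()

    # attach at class level so __build_model can find it during __init__
    BiLSTMPack.tng_dataloader = _ToyLoader()

    model = BiLSTMPack(hparams)

    # a batch of 2 padded token-id sequences (0 = padding), sorted by decreasing length
    x = torch.tensor([[1, 2, 3, 1],
                      [2, 3, 0, 0]])
    seq_lengths = [4, 2]

    logits = model((x, seq_lengths))  # shape: (batch_size * seq_len, nb_tags)
    print('logits shape:', logits.shape)

    # gold tags, padded with -1 so the loss mask ignores those positions
    y = torch.tensor([[0, 1, 1, 0],
                      [1, 0, -1, -1]])
    loss = model.loss((logits, (y, seq_lengths)))
    print('loss:', loss.item())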