Added GPT2seq2seq model

This commit is contained in:
Sina 2020-04-26 00:38:02 -07:00
parent 37388d82b7
commit af8d097485
2 changed files with 52 additions and 1 deletions

51
genienlp/GPT2seq2seq.py Normal file
View File

@ -0,0 +1,51 @@
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from torch.nn import CrossEntropyLoss
import torch
class GPT2Seq2Seq(GPT2LMHeadModel):
def __init__(self, config):
super().__init__(config)
self.sep_token = 50258
self.end_token = 50259
self.pad_token = 50257
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
sep_token_position = (input_ids==self.sep_token).to(torch.long)
assert (torch.sum(sep_token_position, dim=1)==1).all(), 'All input_ids must contain exactly one start_token'
token_type_ids = torch.cumsum(sep_token_position, dim=1) - sep_token_position
attention_mask = (input_ids!=self.pad_token).to(torch.long) # 0 means mask, 1 means no mask
position_ids = (torch.cumsum(attention_mask, dim=1)-1)*(1-token_type_ids)+(torch.cumsum(token_type_ids, dim=1)-1)*token_type_ids
token_type_ids = self.sep_token * (1-token_type_ids) + self.end_token * token_type_ids
# print('input_ids = ', input_ids)
# print('position_ids = ', position_ids)
# print('token_type_ids = ', token_type_ids)
# print('attention_mask = ', attention_mask)
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
position_ids = position_ids[:, -1].unsqueeze(-1)
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = attention_mask[:, -1].unsqueeze(-1)
inputs = {"input_ids": input_ids, "position_ids": position_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask, "past": past}
return inputs
if __name__ == '__main__':
model = GPT2Seq2Seq.from_pretrained('workdir/models/gpt2-medium-5')
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained('workdir/models/gpt2-medium-5')
# print(tokenizer.convert_tokens_to_ids('</paraphrase>'))
# print(tokenizer.convert_tokens_to_ids('<paraphrase>'))
dct = tokenizer.batch_encode_plus(['show me restaurants around here. <paraphrase>', 'where is it? <paraphrase>'], return_tensors="pt", pad_to_max_length=True)
outputs = model.generate(input_ids=dct['input_ids'],
max_length=40,
num_beams=16,
early_stopping=True,
num_return_sequences=4,
do_sample=False,
temperature=1.0,
eos_token_id=50259,
pad_token_id=tokenizer.convert_tokens_to_ids(tokenizer.pad_token)) # do greedy decoding
print('outputs = ', outputs)
for output in outputs:
print('Generated: {}'.format(tokenizer.decode(output, skip_special_tokens=True)))

View File

@ -119,7 +119,7 @@ def main():
for idx in range(args.num_new_queries):
copy_row = row.copy()
copy_row[args.utterance_column] = new_queries[new_query_count]
copy_row[args.id_column] += ('-' + str(idx))
copy_row[args.id_column] = 'A' + copy_row[args.id_column] + '-' + str(idx) # add 'A' for auto-paraphrasing
output_rows.append(copy_row)
new_query_count += 1
else: