combined BART and GPT2 generation into one script

Sina 2020-04-27 18:58:35 -07:00
parent e7e6e3a1c4
commit 78fb8ab2bc
3 changed files with 79 additions and 171 deletions

View File

@@ -5,21 +5,23 @@ import torch
class GPT2Seq2Seq(GPT2LMHeadModel):
def __init__(self, config):
super().__init__(config)
self.end_token = 50259
self.sep_token = 50258
self.pad_token = 50257
def set_token_ids(self, end_token_id, sep_token_id, pad_token_id):
self.end_token_id = end_token_id
self.sep_token_id = sep_token_id
self.pad_token_id = pad_token_id
def pad_to_max_length(self, input_sequences: List[List[int]]):
"""
Adds pad tokens before the sep_token
"""
max_length = len(input_sequences[0]) # input is sorted by length
max_length = max([len(s) for s in input_sequences])
copy_input_sequences = []
for i in range(len(input_sequences)):
sep_token_index = input_sequences[i].index(self.sep_token)
sep_token_index = input_sequences[i].index(self.sep_token_id)
copy_input_sequences.append(input_sequences[i][:sep_token_index] + \
[self.pad_token]*(max_length-len(input_sequences[i])) +\
[self.pad_token_id]*(max_length-len(input_sequences[i])) +\
input_sequences[i][sep_token_index:])
return copy_input_sequences
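# illustrative example with sep_token_id=50258 and pad_token_id=50257:
#   pad_to_max_length([[5, 6, 50258, 9], [7, 50258]])
#   -> [[5, 6, 50258, 9], [7, 50257, 50257, 50258]]  # pads go before the sep token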
@@ -31,8 +33,8 @@ class GPT2Seq2Seq(GPT2LMHeadModel):
if repetition_penalty == 1.0:
return lprobs
m = torch.scatter(input=torch.zeros_like(lprobs), dim=1, index=prev_output_tokens, value=1)
m[:self.sep_token] = 0
m[:self.pad_token] = 0
m[:self.sep_token_id] = 0
m[:self.pad_token_id] = 0
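# note: m[:self.sep_token_id] slices along dim 0 (the batch dimension), not the
# vocabulary dimension; zeroing the sep/pad columns would be m[:, self.sep_token_id] = 0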
# logger.info('m = ', m.shape)
need_change = m * lprobs
need_divide = need_change > 0
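# CTRL-style repetition penalty (Keskar et al., 2019): positive scores of previously
# generated tokens are divided by the penalty, negative scores are multiplied by it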
@@ -48,14 +50,28 @@ class GPT2Seq2Seq(GPT2LMHeadModel):
# else:
# lprobs[i, previous_token] *= repetition_penalty
def generate(self, **kwargs):
outputs = super().generate(**kwargs)
outputs = outputs[:, :].tolist()
for i in range(len(outputs)):
outputs[i] = [x for x in outputs[i] if x != self.pad_token_id] # remove padding
outputs[i] = outputs[i][outputs[i].index(self.sep_token_id)+1:] # only return the output (i.e. after sep_token)
return outputs
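# e.g. with pad and sep denoting self.pad_token_id and self.sep_token_id,
# an output row [w, pad, sep, x, y] becomes [x, y]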
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
sep_token_position = (input_ids==self.sep_token).to(torch.long)
assert (torch.sum(sep_token_position, dim=1)==1).all(), 'All input_ids must contain exactly one start_token. sep_token_position = %s' % str(sep_token_position)
sep_token_position = (input_ids==self.sep_token_id).to(torch.long)
# for i, s in enumerate(sep_token_position):
# if torch.sum(s) != 1:
# print(i, s)
# print(input_ids[i])
# exit()
assert (torch.sum(sep_token_position, dim=1)==1).all(), 'All input_ids must contain exactly one start_token. sep_token_position = %s\nsep_token_id = %d' % (str(sep_token_position), self.sep_token_id)
token_type_ids = torch.cumsum(sep_token_position, dim=1) - sep_token_position
attention_mask = (input_ids!=self.pad_token).to(torch.long) # 0 means mask, 1 means no mask
attention_mask = (input_ids!=self.pad_token_id).to(torch.long) # 0 means mask, 1 means no mask
position_ids = (torch.cumsum(attention_mask, dim=1)-1)*(1-token_type_ids)+(torch.cumsum(token_type_ids, dim=1)-1)*token_type_ids
token_type_ids = self.sep_token * (1-token_type_ids) + self.end_token * token_type_ids
token_type_ids = self.sep_token_id * (1-token_type_ids) + self.end_token_id * token_type_ids
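# illustrative example for one row [w, pad, sep, x] (pad/sep are the pad and sep token ids):
#   sep_token_position = [0, 0, 1, 0], segment flags = [0, 0, 0, 1]
#   attention_mask     = [1, 0, 1, 1], position_ids  = [0, 0, 1, 0]
# positions skip pad tokens and restart at 0 after sep; token_type_ids become
# [sep, sep, sep, end], so input and output segments get different type embeddings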
# print('input_ids = ', input_ids)
# print('position_ids = ', position_ids)
# print('token_type_ids = ', token_type_ids)
@@ -67,25 +83,4 @@ class GPT2Seq2Seq(GPT2LMHeadModel):
attention_mask = attention_mask[:, -1].unsqueeze(-1)
inputs = {"input_ids": input_ids, "position_ids": position_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask, "past": past}
return inputs
if __name__ == '__main__':
model = GPT2Seq2Seq.from_pretrained('workdir/models/gpt2-medium-5')
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained('workdir/models/gpt2-medium-5')
# print(tokenizer.convert_tokens_to_ids('</paraphrase>'))
# print(tokenizer.convert_tokens_to_ids('<paraphrase>'))
dct = tokenizer.batch_encode_plus(['show me restaurants around here. <paraphrase>', 'where is it? <paraphrase>'], return_tensors="pt", pad_to_max_length=True)
outputs = model.generate(input_ids=dct['input_ids'],
max_length=40,
num_beams=16,
early_stopping=True,
num_return_sequences=4,
do_sample=False,
temperature=1.0,
eos_token_id=50259,
pad_token_id=tokenizer.convert_tokens_to_ids(tokenizer.pad_token)) # do greedy decoding
print('outputs = ', outputs)
for output in outputs:
print('Generated: {}'.format(tokenizer.decode(output, skip_special_tokens=True)))

View File

@@ -1,85 +0,0 @@
import argparse
from pathlib import Path
import torch
from tqdm import tqdm
from transformers import BartForConditionalGeneration, BartTokenizer
from genienlp.paraphrase.train_bart import BartSystem
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]
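# e.g. list(chunks([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]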
def generate_summaries(
examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
):
# b = BartSystem.load_from_checkpoint('./workdir/models/bart-large-2to1/checkpointcheckpoint_ckpt_epoch_1.ckpt')
# b.model.save_pretrained('./workdir/models/bart-large-2to1/')
# b.tokenizer.save_pretrained('./workdir/models/bart-large-2to1/')
model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
model.eval()
model = model.to(device)
tokenizer = BartTokenizer.from_pretrained(model_name)
max_length = 140
min_length = 1
fout = Path(out_file).open("w")
for batch in tqdm(list(chunks(examples, batch_size))):
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
# bad = ['which', 'Which', 'restaurant', 'restaurants']
# bad = [tokenizer.encode(b, add_prefix_space=True, add_special_tokens=False) for b in bad]
summaries = model.generate(
input_ids=dct["input_ids"].to(device),
attention_mask=dct["attention_mask"].to(device),
num_beams=16,
do_sample=False,
temperature=1,
length_penalty=1,
max_length=max_length + 2, # +2 from original because we start at step=1 and stop before max_length
min_length=min_length + 1, # +1 from original because we start at step=1
no_repeat_ngram_size=3,
early_stopping=True,
decoder_start_token_id=model.config.eos_token_id,
num_return_sequences=4
# bad_words_ids=bad
)
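# note: with do_sample=False this is beam search, so num_return_sequences (4)
# must not exceed num_beams (16); the top beams per input are returned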
# print(bad)
# print(summaries)
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summaries]
for hypothesis in dec:
fout.write(hypothesis + "\n")
fout.flush()
def run_generate():
parser = argparse.ArgumentParser()
parser.add_argument(
"source_path", type=str, help="like cnn_dm/test.source",
)
parser.add_argument(
"output_path", type=str, help="where to save summaries",
)
parser.add_argument(
"model_name", type=str, default="bart-large-cnn", help="like bart-large-cnn",
)
parser.add_argument(
"--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
)
parser.add_argument(
"--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
)
args = parser.parse_args()
examples = [" " + x.rstrip() for x in open(args.source_path).readlines()]
generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
if __name__ == "__main__":
run_generate()
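# usage sketch (illustrative; the script name is a placeholder):
#   python generate_bart.py cnn_dm/test.source summaries.txt bart-large-cnn --device cuda --bs 8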

View File

@@ -54,7 +54,6 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
level = logging.INFO)
logger = logging.getLogger(__name__)
MAX_LENGTH = int(1000) # Hardcoded max length to avoid infinite loop
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, BartConfig)), ())
@@ -82,7 +81,7 @@ special_pattern_mapping = [
]
def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, sep_token,
skip_heuristics, is_cased):
skip_heuristics, is_cased, model_type):
"""
Read a tsv file (this includes a text file with one example per line) and returns input features that the model needs
Outputs:
@@ -91,7 +90,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
all_input_sequences = []
all_input_sequence_lengths = []
all_context_tokens = []
all_context_lengths = []
estimated_output_lengths = []
all_golds = []
reverse_maps = []
@@ -106,7 +105,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
for line in tqdm(input_file, desc='Reading Input File', total=number_of_lines, disable=disable_tqdm):
row = line.split('\t')
row = [r.strip() for r in line.split('\t')]
input_sequence = row[input_column]
gold = row[gold_column]
# logger.info('gold = %s', gold)
@@ -123,28 +122,28 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
# logger.info('input_sequence = %s', input_sequence)
reverse_maps.append(reverse_map)
input_sequence += sep_token
prompt = '' # includes the first few tokens of the output
input_sequence_tokens = tokenizer.encode(input_sequence,add_special_tokens=True) # add_special_tokens=True for gpt2 should have no effect, but as of transformers==2.8.0, a bug results in token_ids getting changed
prompt_tokens = [] # includes the first few tokens of the output
if prompt_column is not None and len(row) > prompt_column:
prompt = row[prompt_column]
if not skip_heuristics:
prompt, _ = input_heuristics(prompt, thingtalk, is_cased)
# logger.info('prompt = %s', prompt)
input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=False)
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
context_tokens = input_sequence_tokens + prompt_tokens
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
if copy > 0:
assert prompt == ''
context_tokens.extend(context_tokens[0 : min(copy, len(context_tokens)-1)]) # -1 since we should not copy prompt_token
assert len(prompt_tokens) == 0
prompt_tokens = context_tokens[0 : min(copy, len(context_tokens)-1)] # -1 since we should not copy sep_token
context_tokens = input_sequence_tokens + [tokenizer.convert_tokens_to_ids(sep_token)] + prompt_tokens
all_input_sequences.append(input_sequence)
all_input_sequence_lengths.append(len(input_sequence_tokens))
all_context_tokens.append(context_tokens)
all_context_lengths.append(len(context_tokens))
estimated_output_lengths.append(len(input_sequence_tokens)-len(prompt_tokens))
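# assumption behind this estimate: a paraphrase is roughly as long as its input, so input
# length minus the already-provided prompt length approximates the tokens left to generate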
if file_path is not None:
input_file.close()
return all_input_sequences, all_input_sequence_lengths, all_context_tokens, all_context_lengths, all_golds, reverse_maps
return all_input_sequences, all_input_sequence_lengths, all_context_tokens, estimated_output_lengths, all_golds, reverse_maps
def is_question(sentence: str):
question_words = ['which', 'what', 'where', 'how', 'who', 'when', 'is', 'are', 'am', \
@@ -372,52 +371,60 @@ def run_generation(args):
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
model.to(args.device)
model.eval()
if args.length < 0 and model.config.max_position_embeddings > 0:
args.length = model.config.max_position_embeddings
elif 0 < model.config.max_position_embeddings < args.length:
args.length = model.config.max_position_embeddings # No generation bigger than model size
elif args.length < 0:
args.length = MAX_LENGTH # avoid infinite loop
logger.info(args)
pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
sep_token_id = tokenizer.convert_tokens_to_ids(args.sep_token)
if pad_token_id is None:
logger.error('Your tokenizer does not have a padding token')
all_input_sequences, all_input_sequence_lengths, all_context_tokens, all_context_lengths, all_golds, reverse_maps = \
if args.model_type == 'gpt2':
model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(args.stop_tokens[0]),
sep_token_id=tokenizer.convert_tokens_to_ids(args.sep_token),
pad_token_id=pad_token_id)
logger.info(args)
all_input_sequences, all_input_sequence_lengths, all_context_tokens, estimated_output_lengths, all_golds, reverse_maps = \
create_features_from_tsv_file(file_path=args.input_file, tokenizer=tokenizer,
input_column=args.input_column, gold_column=args.gold_column, prompt_column=args.prompt_column,
copy=args.copy,
thingtalk_column=args.thingtalk_column,
sep_token=args.sep_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased)
sep_token=args.sep_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
model_type=args.model_type)
# sort contexts based on their context length so that fewer generated tokens are thrown away and generation can be done faster
all_context_lengths, all_input_sequence_lengths, all_input_sequences, all_context_tokens, original_order, reverse_maps = \
tuple(zip(*sorted(list(zip(all_context_lengths, all_input_sequence_lengths, all_input_sequences, all_context_tokens, range(len(all_context_tokens)), reverse_maps)), reverse=True)))
estimated_output_lengths, all_input_sequence_lengths, all_input_sequences, all_context_tokens, original_order, reverse_maps = \
tuple(zip(*sorted(list(zip(estimated_output_lengths, all_input_sequence_lengths, all_input_sequences, all_context_tokens, range(len(all_context_tokens)), reverse_maps)), reverse=True)))
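# e.g. estimated lengths [3, 1, 2] sort (descending) to [3, 2, 1] with original_order = (0, 2, 1);
# original_order is used at the end to restore the outputs to their input order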
all_outputs = []
stop_token_ids = [tokenizer.convert_tokens_to_ids(stop_token) for stop_token in args.stop_tokens]
for batch in trange(math.ceil(len(all_context_tokens) / args.batch_size), desc="Batch"):
for batch in tqdm(range(math.ceil(len(all_context_tokens) / args.batch_size)), desc="Batch"):
batch_slice = (batch*args.batch_size, min((batch+1)*args.batch_size, len(all_context_tokens)))
batch_size = batch_slice[1] - batch_slice[0]
batch_input_sequences = all_input_sequences[batch_slice[0]: batch_slice[1]]
batch_input_sequence_lengths = all_input_sequence_lengths[batch_slice[0]: batch_slice[1]]
batch_context_tokens = all_context_tokens[batch_slice[0]: batch_slice[1]]
batch_reverse_maps = reverse_maps[batch_slice[0]: batch_slice[1]]
# logger.info('batch_context_tokens = %s', str(batch_context_tokens))
batch_context_tensor = input_tensor = torch.tensor(model.pad_to_max_length(batch_context_tokens), dtype=torch.long, device=args.device)
if args.model_type == 'gpt2':
batch_context_tensor = torch.tensor(model.pad_to_max_length(batch_context_tokens), dtype=torch.long, device=args.device)
attention_mask = None
elif args.model_type == 'bart':
padded_batch_context_tokens = []
max_length = max([len(s) for s in batch_context_tokens])
for i in range(len(batch_context_tokens)):
padded_batch_context_tokens.append(batch_context_tokens[i]+[pad_token_id]*(max_length-len(batch_context_tokens[i])))
batch_context_tensor = torch.tensor(padded_batch_context_tokens, dtype=torch.long, device=args.device)
attention_mask = (batch_context_tensor!=pad_token_id).to(torch.long)
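# illustrative example with pad_token_id = 1: [[5, 6, 7], [8, 9]] is padded to
# [[5, 6, 7], [8, 9, 1]] with attention_mask [[1, 1, 1], [1, 1, 0]]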
# logger.info('batch_context_tensor = %s', str(batch_context_tensor))
batch_outputs = [[] for _ in range(batch_size)]
for hyperparameter_idx in range(len(args.temperature)):
out = model.generate(input_ids=batch_context_tensor,
attention_mask=attention_mask,
min_length=args.min_output_length,
max_length=batch_context_tensor.shape[1]+args.length,
num_beams=args.num_beams[hyperparameter_idx],
@@ -431,32 +438,24 @@ def run_generation(args):
eos_token_id=stop_token_ids[0],
pad_token_id=pad_token_id
)
out = out[:, :].tolist()
for i, o in enumerate(out):
# logger.info('all output tokens: %s', str(o))
# logger.info('all output tokens detokenized: %s', str(tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)))
o = [x for x in o if x!=pad_token_id][batch_input_sequence_lengths[(i//args.num_samples) % batch_size]:]
# logger.info('original context tokens: %s', str(batch_context_tokens[(i//args.num_samples) % batch_size]))
# logger.info('original input sequence: %s', str(batch_input_sequences[(i//args.num_samples) % batch_size]))
if not isinstance(out, list):
out = out[:, :].tolist()
for i, o in enumerate(out):
if args.stop_tokens is not None:
min_index = len(o)
min_index = len(o)-1
for stop_token_id in stop_token_ids:
try:
index = o.index(stop_token_id)
min_index = min(index, min_index)
except ValueError:
pass
if min_index < len(o) and o[min_index] == tokenizer.convert_tokens_to_ids('?'):
# always include the question mark
min_index = min_index + 1
if o[min_index] != stop_token_ids[0]:
min_index = min_index + 1 # include stop_token if it is not end_token
o = o[:min_index]
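# e.g. with stop_token_ids = [eos, sep] and o = [a, b, sep, c]: the earliest stop token
# is sep at index 2; since sep != stop_token_ids[0], min_index becomes 3 and the kept
# output [a, b, sep] retains the non-eos stop token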
text = tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)
text = tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=True)
# assert tokenizer.pad_token not in text
text = text.replace(tokenizer.pad_token, '')
text = re.sub('\s\s+', ' ', text) # remove duplicate white spaces
text = text.strip()
if not args.skip_heuristics:
@@ -468,17 +467,16 @@ def run_generation(args):
# sort the results back to their original order
_, all_outputs = tuple(zip(*sorted(list(zip(original_order, all_outputs)))))
metrics = compute_metrics(all_outputs, all_golds, reduction=args.metric_reduction)
if args.output_file is not None:
with open(args.output_file, 'w') as output_file:
if args.output_file is not None:
for output in all_outputs:
for text in output:
output_file.write(text + '\n')
for output in all_outputs:
for text in output:
output_file.write(text + '\n')
else:
print(json.dumps(all_outputs, indent=2))
metrics = compute_metrics(all_outputs, all_golds, reduction=args.metric_reduction)
logger.info('Average BLEU score = %.2f', metrics['bleu'])
logger.info('Exact match score = %.2f', metrics['em'])