paraphrasing now defaults to stdin and stdout

Sina 2020-04-23 21:29:04 -07:00
parent d27323ebe6
commit eda4f26502
1 changed file with 78 additions and 64 deletions


@@ -28,6 +28,7 @@ import re
import copy
import numpy as np
import os
+import sys
# multiprocessing with CUDA
from torch.multiprocessing import Process, set_start_method
@@ -84,7 +85,7 @@ def apply_repetition_penalty(logits, context, repetition_penalty, prompt_token_i
m = torch.scatter(input=torch.zeros_like(logits), dim=1, index=context, value=1)
m[:prompt_token_id] = 0
m[:pad_token_id] = 0
-# print('m = ', m.shape)
+# logger.info('m = ', m.shape)
need_change = m * logits
need_divide = need_change > 0
need_multiply = need_change < 0
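
The masked divide/multiply in this hunk is the CTRL-style repetition penalty: logits of tokens already present in the context are pushed toward lower probability. A minimal self-contained sketch of the same idea, with a made-up name and simplified handling compared to the repo's apply_repetition_penalty:

import torch

def repetition_penalty_sketch(logits, context, penalty):
    # logits: (batch, vocab); context: (batch, seq) of already-generated token ids.
    mask = torch.zeros_like(logits).scatter(1, context, 1.0)
    need_change = mask * logits
    # Dividing positive logits and multiplying negative ones both lower the
    # probability of a repeated token, for any penalty > 1.
    logits = torch.where(need_change > 0, logits / penalty, logits)
    logits = torch.where(need_change < 0, logits * penalty, logits)
    return logits
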
@@ -144,9 +145,9 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
segment_ids = torch.tensor(segment_ids, dtype=torch.long, device=device)
segment_ids = segment_ids.repeat(num_samples, 1)
-# print('context = ', context)
-# print('position_ids = ', position_ids)
-# print('segment_ids = ', segment_ids)
+# logger.info('context = ', context)
+# logger.info('position_ids = ', position_ids)
+# logger.info('segment_ids = ', segment_ids)
context = torch.tensor(padded_context, dtype=torch.long, device=device)
context = context.repeat(num_samples, 1)
@@ -159,7 +160,7 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
with torch.no_grad():
# rep_penalty = np.random.random(length) < 0.1
# original_rep_penalty = repetition_penalty
-# print('rep_penalty = ', rep_penalty)
+# logger.info('rep_penalty = ', rep_penalty)
for _ in range(length):
inputs = {'input_ids': generated, 'position_ids': position_ids[:, :next_index], 'token_type_ids': segment_ids[:, :next_index]}
if is_xlnet:
@@ -204,7 +205,7 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
# prevent stop_tokens if generated_length < min_output_length
should_remove_stop_tokens = (generated_length < min_output_length)
next_token_logits[:, stop_token_ids] = next_token_logits[:, stop_token_ids].masked_fill(should_remove_stop_tokens, -float('Inf'))
-# print('after ', next_token_logits[:, stop_token_ids])
+# logger.info('after ', next_token_logits[:, stop_token_ids])
generated_length = generated_length + (1-m)
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
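
top_k_top_p_filtering, called above, is not part of this diff; it is conventionally the implementation from the HuggingFace run_generation example. A sketch under that assumption (the repo's own copy may differ):

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # logits: (batch, vocab). Keep only the top_k highest-scoring tokens.
    if top_k > 0:
        threshold = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < threshold] = filter_value
    # Nucleus sampling: drop tokens outside the smallest set whose cumulative
    # probability exceeds top_p, always keeping the single best token.
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits
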
@@ -272,42 +273,53 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
all_golds = []
reverse_maps = []
-number_of_lines = get_number_of_lines(file_path)
-with open(file_path) as input_file:
-reader = csv.reader(input_file, delimiter='\t')
-for row in tqdm(reader, desc='Reading Input File', total=number_of_lines):
-input_sequence = row[input_column]
-gold = row[gold_column]
-# print('gold = ', gold)
+if file_path is not None:
+number_of_lines = get_number_of_lines(file_path)
+disable_tqdm = False
+input_file = open(file_path)
+else:
+number_of_lines = 1
+disable_tqdm = True
+input_file = sys.stdin
+for line in tqdm(input_file, desc='Reading Input File', total=number_of_lines, disable=disable_tqdm):
+row = line.split('\t')
+input_sequence = row[input_column]
+gold = row[gold_column]
+# logger.info('gold = %s', gold)
+if not skip_heuristics:
+gold, _ = input_heuristics(gold, None, is_cased, keep_special_tokens=True, keep_tokenized=True)
+# logger.info('gold = %s', gold)
+all_golds.append(gold)
+# logger.info('before text = %s', input_sequence)
+if skip_heuristics:
+reverse_maps.append({})
+else:
+thingtalk = row[thingtalk_column] if thingtalk_column is not None else None
+# logger.info('input_sequence = %s', input_sequence)
+input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
+# logger.info('input_sequence = %s', input_sequence)
+reverse_maps.append(reverse_map)
+input_sequence += prompt_token
+prompt = '' # includes the first few tokens of the output
+if prompt_column is not None and len(row) > prompt_column:
+prompt = row[prompt_column]
-if not skip_heuristics:
-gold, _ = input_heuristics(gold, None, is_cased, keep_special_tokens=True, keep_tokenized=True)
-# print('gold = ', gold)
-all_golds.append(gold)
-# print('before text = ', input_sequence)
-if skip_heuristics:
-reverse_maps.append({})
-else:
-thingtalk = row[thingtalk_column] if thingtalk_column is not None else None
-# print('input_sequence = ', input_sequence)
-input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
-# print('input_sequence = ', input_sequence)
-reverse_maps.append(reverse_map)
-input_sequence += prompt_token
-prompt = '' # includes the first few tokens of the output
-if prompt_column is not None and len(row) > prompt_column:
-prompt = row[prompt_column]
-if not skip_heuristics:
-prompt, _ = input_heuristics(prompt, thingtalk, is_cased)
-# print('prompt = ', prompt)
-input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=False)
-prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
-context_tokens = input_sequence_tokens + prompt_tokens
-if copy > 0:
-context_tokens.extend(context_tokens[0 : min(copy, len(context_tokens)-1)]) # -1 since we should not copy prompt_token
-all_input_sequences.append(input_sequence)
-all_input_sequence_lengths.append(len(input_sequence_tokens))
-all_context_tokens.append(context_tokens)
-all_context_lengths.append(len(context_tokens))
+if not skip_heuristics:
+prompt, _ = input_heuristics(prompt, thingtalk, is_cased)
+# logger.info('prompt = %s', prompt)
+input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=False)
+prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
+context_tokens = input_sequence_tokens + prompt_tokens
+if copy > 0:
+context_tokens.extend(context_tokens[0 : min(copy, len(context_tokens)-1)]) # -1 since we should not copy prompt_token
+all_input_sequences.append(input_sequence)
+all_input_sequence_lengths.append(len(input_sequence_tokens))
+all_context_tokens.append(context_tokens)
+all_context_lengths.append(len(context_tokens))
+if file_path is not None:
+input_file.close()
return all_input_sequences, all_input_sequence_lengths, all_context_tokens, all_context_lengths, all_golds, reverse_maps
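
The open/close bookkeeping introduced above (open a file when a path is given, otherwise read stdin, and only close what was opened) can be captured in a small context manager. A sketch of the pattern; open_or_stdin is a hypothetical helper, not part of this commit:

import contextlib
import sys

@contextlib.contextmanager
def open_or_stdin(path=None):
    # Fall back to stdin when no path is given; only close handles we opened.
    f = open(path) if path is not None else sys.stdin
    try:
        yield f
    finally:
        if path is not None:
            f.close()

With such a helper, the loop in create_features_from_tsv_file could live under a single 'with open_or_stdin(file_path) as input_file:' block instead of the explicit if/close pair.
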
@@ -333,7 +345,7 @@ def input_heuristics(s: str, thingtalk=None, is_cased=False, keep_special_tokens
# Put question mark at the end whenever necessary.
sentences = [sentence.strip() for sentence in re.split('\s+([.?!:])\s*', s) if len(sentence) > 0]
-# print('sentences = ', sentences)
+# logger.info('sentences = %s', sentences)
for idx in range(len(sentences)):
if sentences[idx] in ['.', '?' , '!', ':']:
continue
@@ -352,7 +364,7 @@ def input_heuristics(s: str, thingtalk=None, is_cased=False, keep_special_tokens
# capitalize the first word and parameters
if thingtalk:
_, parameters = remove_thingtalk_quotes(thingtalk)
-# print('parameters = ', parameters)
+# logger.info('parameters = ', parameters)
for p in parameters:
capitalized_p = ' '.join([t[0].upper()+t[1:] for t in p.split()])
sentences[idx] = sentences[idx].replace(p, capitalized_p)
@@ -373,7 +385,7 @@ def input_heuristics(s: str, thingtalk=None, is_cased=False, keep_special_tokens
s, r = spm.forwad(s)
reverse_map.extend(r)
-# print('s = ', s)
+# logger.info('s = ', s)
return s, reverse_map
def output_heuristics(s: str, reverse_map: list):
@@ -420,11 +432,11 @@ def compute_metrics(generations, golds, reduction='average'):
# from matplotlib import pyplot as plt
# import numpy as np
# h, b = np.histogram(all_bleu, bins=list(range(0, 105, 5)))
-# print('all_bleu = ', all_bleu)
-# print('h = ', h)
-# print('b = ', b)
+# logger.info('all_bleu = ', all_bleu)
+# logger.info('h = ', h)
+# logger.info('b = ', b)
# h = h / np.sum(h)
-# print('h = ', h)
+# logger.info('h = ', h)
# plt.title('GPT2 (temp=0, penalty=1.0) Paraphrases for restaurants')
# plt.xlabel('BLEU with original')
# plt.ylim((0.0, 1.0))
@@ -439,7 +451,7 @@ def parse_argv(parser):
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--input_file", type=str, help="The file from which we read prompts.")
parser.add_argument("--input_file", type=str, help="The file from which we read prompts. Defaults to stdin.")
parser.add_argument('--input_column', type=int, required=True,
help='The column in the input file which contains the input sentences.')
parser.add_argument('--prompt_column', type=int, default=None,
@@ -448,7 +460,7 @@ def parse_argv(parser):
help='The column in the input file which contains the gold sentences. Defaults to --input_column if no gold is available.')
parser.add_argument('--thingtalk_column', type=int, default=None,
help='The column in the input file which contains the ThingTalk program.')
parser.add_argument("--output_file", type=str, help="When specified, generated text will be written in this file.")
parser.add_argument("--output_file", type=str, help="When specified, generated text will be written in this file. Defaults to stdout.")
parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
parser.add_argument("--length", type=int, default=20, help='The generated sentences will have a maximum length of len(input) + arg.length')
parser.add_argument("--min_output_length", type=int, default=1, help='Will prevent stop tokens from appearing in the first --min_length tokens of the generated sentences.')
@@ -510,6 +522,8 @@ def main(args):
args.model_type = args.model_type.lower()
if args.n_gpu > 1:
+if args.input_file is None:
+raise ValueError('Cannot use multiple GPUs when reading from stdin. You should provide an --input_file')
# Independent multi-GPU generation
all_processes = []
all_input_files = split_file_on_disk(args.input_file, args.n_gpu)
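
split_file_on_disk is defined elsewhere in the repo and not shown in this diff. One plausible shape for it, purely as an assumption (both the part-file naming and the contiguous-chunk strategy are guesses), which also illustrates why it cannot operate on stdin:

def split_file_on_disk_sketch(file_path, num_splits):
    # Split the input into num_splits contiguous chunks, one per GPU process.
    # Splitting needs the whole input available on disk, hence the ValueError
    # above when --input_file is missing and input would come from stdin.
    with open(file_path) as f:
        lines = f.readlines()
    chunk_size = (len(lines) + num_splits - 1) // num_splits
    output_paths = []
    for i in range(num_splits):
        path = '{}_part_{}'.format(file_path, i)
        with open(path, 'w') as g:
            g.writelines(lines[i * chunk_size : (i + 1) * chunk_size])
        output_paths.append(path)
    return output_paths
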
@@ -634,11 +648,11 @@ def run_generation(args):
out_logits = out_logits[:, :].tolist()
for i, o in enumerate(out):
o_logits = out_logits[i]
-# print('all output tokens: ', o)
-# print('all output tokens detokenized: ', tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False))
+# logger.info('all output tokens: %s', o)
+# logger.info('all output tokens detokenized: %s', str(tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)))
o = o[batch_input_sequence_lengths[i % batch_size]:]
-# print('original context tokens: ', batch_context_tokens[i % batch_size])
-# print('original input sequence: ', batch_input_sequences[i % batch_size])
+# logger.info('original context tokens: %s', str(batch_context_tokens[i % batch_size]))
+# logger.info('original input sequence: %s', str(batch_input_sequences[i % batch_size]))
if args.stop_tokens is not None:
min_index = len(o)
@@ -672,22 +686,22 @@ def run_generation(args):
else:
criterion = np.mean(o_logits)
batch_criterion[i % batch_size].append(criterion)
-# print('generated tokens: ', o)
-# print('o_logits = ', o_logits)
-# print('generated cirterion: %.2f' % criterion)
-# print('text = ', text)
-# print('-'*10)
+# logger.info('generated tokens: %s', str(o))
+# logger.info('o_logits = %s', str(o_logits))
+# logger.info('generated cirterion: %.2f', criterion)
+# logger.info('text = %s', text)
+# logger.info('-'*10)
if args.selection_criterion == 'none':
all_outputs.extend(batch_outputs)
else:
for idx, example in enumerate(batch_outputs):
-print('input sequence: ', batch_input_sequences[idx % batch_size])
+logger.info('input sequence: %s', str(batch_input_sequences[idx % batch_size]))
c, example = tuple(zip(*sorted(list(zip(batch_criterion[idx], example)), reverse=True)))
-print(example)
-print(c)
-print('-'*10)
+logger.info(example)
+logger.info(c)
+logger.info('-'*10)
selection = example[0]
all_outputs.append([selection])
@@ -703,7 +717,7 @@ def run_generation(args):
for text in _:
output_file.write(text + '\n')
else:
-logger.info(json.dumps(all_outputs, indent=2))
+print(json.dumps(all_outputs, indent=2))
logger.info('Average BLEU score = %.2f', metrics['bleu'])
logger.info('Exact match score = %.2f', metrics['em'])
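
A note on the print/logger swaps throughout this commit: diagnostic messages move from print to logger (which writes to stderr under the usual logging setup), while the final JSON payload moves from logger to print, so stdout carries only machine-readable output and the script can sit in a shell pipeline. A minimal sketch of that separation, with illustrative values; the logging configuration is assumed, not shown in the diff:

import json
import logging
import sys

# Diagnostics go to stderr; only the generated paraphrases touch stdout.
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
logger = logging.getLogger(__name__)

all_outputs = [['an example paraphrase']]      # illustrative value
logger.info('Average BLEU score = %.2f', 0.0)  # diagnostics -> stderr
print(json.dumps(all_outputs, indent=2))       # results -> stdout
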