paraphrasing now defaults to stdin and stdout
parent d27323ebe6
commit eda4f26502

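Net effect: when --input_file and --output_file are omitted, the script reads prompts from stdin and writes generations to stdout, so it can sit in a shell pipeline. A minimal sketch of the pattern the diff below adopts (the helper name and the processing step are illustrative, not from this repo):

import sys

def open_input(path=None):
    # Mirror the new default: a real file when a path is given, stdin otherwise.
    if path is not None:
        return open(path), True   # caller must close it
    return sys.stdin, False       # never close stdin

input_file, needs_close = open_input()
try:
    for line in input_file:
        sys.stdout.write(line.upper())   # stand-in for paraphrase generation
finally:
    if needs_close:
        input_file.close()
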
@@ -28,6 +28,7 @@ import re
 import copy
 import numpy as np
 import os
+import sys

 # multiprocessing with CUDA
 from torch.multiprocessing import Process, set_start_method

@@ -84,7 +85,7 @@ def apply_repetition_penalty(logits, context, repetition_penalty, prompt_token_i
     m = torch.scatter(input=torch.zeros_like(logits), dim=1, index=context, value=1)
     m[:prompt_token_id] = 0
     m[:pad_token_id] = 0
-    # print('m = ', m.shape)
+    # logger.info('m = ', m.shape)
     need_change = m * logits
     need_divide = need_change > 0
     need_multiply = need_change < 0

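Aside on the context lines above: m is a per-example mask built with torch.scatter, marking every vocabulary id that occurs in context. A self-contained sketch of the CTRL-style penalty this mask feeds into, re-derived here under assumptions rather than copied from the file:

import torch

def penalize_repeats(logits, context, penalty=1.2):
    # logits: (batch, vocab); context: (batch, seq) token ids already generated.
    m = torch.scatter(input=torch.zeros_like(logits), dim=1, index=context, value=1)
    need_divide = (m * logits) > 0
    need_multiply = (m * logits) < 0
    # Dividing positive logits and multiplying negative ones both shrink them.
    return need_divide * logits / penalty + need_multiply * logits * penalty + (1 - m) * logits

logits = torch.randn(2, 10)
context = torch.tensor([[1, 3], [2, 2]])
print(penalize_repeats(logits, context).shape)  # torch.Size([2, 10])
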
@@ -144,9 +145,9 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
     segment_ids = torch.tensor(segment_ids, dtype=torch.long, device=device)
     segment_ids = segment_ids.repeat(num_samples, 1)

-    # print('context = ', context)
-    # print('position_ids = ', position_ids)
-    # print('segment_ids = ', segment_ids)
+    # logger.info('context = ', context)
+    # logger.info('position_ids = ', position_ids)
+    # logger.info('segment_ids = ', segment_ids)

     context = torch.tensor(padded_context, dtype=torch.long, device=device)
     context = context.repeat(num_samples, 1)

@@ -159,7 +160,7 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
     with torch.no_grad():
         # rep_penalty = np.random.random(length) < 0.1
         # original_rep_penalty = repetition_penalty
-        # print('rep_penalty = ', rep_penalty)
+        # logger.info('rep_penalty = ', rep_penalty)
         for _ in range(length):
             inputs = {'input_ids': generated, 'position_ids': position_ids[:, :next_index], 'token_type_ids': segment_ids[:, :next_index]}
             if is_xlnet:

@@ -204,7 +205,7 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
             # prevent stop_tokens if generated_length < min_output_length
             should_remove_stop_tokens = (generated_length < min_output_length)
             next_token_logits[:, stop_token_ids] = next_token_logits[:, stop_token_ids].masked_fill(should_remove_stop_tokens, -float('Inf'))
-            # print('after ', next_token_logits[:, stop_token_ids])
+            # logger.info('after ', next_token_logits[:, stop_token_ids])
             generated_length = generated_length + (1-m)

             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)

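The lines around this hunk enforce --min_output_length by pushing stop-token logits to -inf while a sequence is still too short. An isolated sketch of that idea (shapes and names assumed here, not this file's code):

import torch

def mask_stop_tokens(next_token_logits, generated_length, stop_token_ids, min_output_length):
    # generated_length: (batch, 1) count of tokens each sequence has emitted so far.
    too_short = generated_length < min_output_length
    next_token_logits[:, stop_token_ids] = next_token_logits[:, stop_token_ids].masked_fill(too_short, -float('Inf'))
    return next_token_logits

logits = torch.randn(2, 8)
lengths = torch.tensor([[0], [5]])
out = mask_stop_tokens(logits, lengths, stop_token_ids=[7], min_output_length=3)
print(out[:, 7])  # first sequence: -inf; second: unchanged
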
@@ -272,42 +273,53 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
     all_golds = []
     reverse_maps = []

-    number_of_lines = get_number_of_lines(file_path)
-    with open(file_path) as input_file:
-        reader = csv.reader(input_file, delimiter='\t')
-        for row in tqdm(reader, desc='Reading Input File', total=number_of_lines):
-            input_sequence = row[input_column]
-            gold = row[gold_column]
-            # print('gold = ', gold)
-            if not skip_heuristics:
-                gold, _ = input_heuristics(gold, None, is_cased, keep_special_tokens=True, keep_tokenized=True)
-                # print('gold = ', gold)
-            all_golds.append(gold)
-            # print('before text = ', input_sequence)
-            if skip_heuristics:
-                reverse_maps.append({})
-            else:
-                thingtalk = row[thingtalk_column] if thingtalk_column is not None else None
-                # print('input_sequence = ', input_sequence)
-                input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
-                # print('input_sequence = ', input_sequence)
-                reverse_maps.append(reverse_map)
-            input_sequence += prompt_token
-            prompt = '' # includes the first few tokens of the output
-            if prompt_column is not None and len(row) > prompt_column:
-                prompt = row[prompt_column]
-                if not skip_heuristics:
-                    prompt, _ = input_heuristics(prompt, thingtalk, is_cased)
-                # print('prompt = ', prompt)
-            input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=False)
-            prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
-            context_tokens = input_sequence_tokens + prompt_tokens
-            if copy > 0:
-                context_tokens.extend(context_tokens[0 : min(copy, len(context_tokens)-1)]) # -1 since we should not copy prompt_token
-            all_input_sequences.append(input_sequence)
-            all_input_sequence_lengths.append(len(input_sequence_tokens))
-            all_context_tokens.append(context_tokens)
-            all_context_lengths.append(len(context_tokens))
+    if file_path is not None:
+        number_of_lines = get_number_of_lines(file_path)
+        disable_tqdm = False
+        input_file = open(file_path)
+    else:
+        number_of_lines = 1
+        disable_tqdm = True
+        input_file = sys.stdin
+
+    for line in tqdm(input_file, desc='Reading Input File', total=number_of_lines, disable=disable_tqdm):
+        row = line.split('\t')
+        input_sequence = row[input_column]
+        gold = row[gold_column]
+        # logger.info('gold = %s', gold)
+        if not skip_heuristics:
+            gold, _ = input_heuristics(gold, None, is_cased, keep_special_tokens=True, keep_tokenized=True)
+            # logger.info('gold = %s', gold)
+        all_golds.append(gold)
+        # logger.info('before text = %s', input_sequence)
+        if skip_heuristics:
+            reverse_maps.append({})
+        else:
+            thingtalk = row[thingtalk_column] if thingtalk_column is not None else None
+            # logger.info('input_sequence = %s', input_sequence)
+            input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
+            # logger.info('input_sequence = %s', input_sequence)
+            reverse_maps.append(reverse_map)
+        input_sequence += prompt_token
+        prompt = '' # includes the first few tokens of the output
+        if prompt_column is not None and len(row) > prompt_column:
+            prompt = row[prompt_column]
+            if not skip_heuristics:
+                prompt, _ = input_heuristics(prompt, thingtalk, is_cased)
+            # logger.info('prompt = %s', prompt)
+        input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=False)
+        prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
+        context_tokens = input_sequence_tokens + prompt_tokens
+        if copy > 0:
+            context_tokens.extend(context_tokens[0 : min(copy, len(context_tokens)-1)]) # -1 since we should not copy prompt_token
+        all_input_sequences.append(input_sequence)
+        all_input_sequence_lengths.append(len(input_sequence_tokens))
+        all_context_tokens.append(context_tokens)
+        all_context_lengths.append(len(context_tokens))
+
+    if file_path is not None:
+        input_file.close()

     return all_input_sequences, all_input_sequence_lengths, all_context_tokens, all_context_lengths, all_golds, reverse_maps

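One behavioral note on this hunk: the old path parsed rows with csv.reader, while the new path splits the raw line on tabs, so the last column now keeps its trailing newline unless a caller strips it. A stricter variant would look like this (hypothetical, not what the commit does):

def parse_tsv_line(line: str):
    # rstrip('\n') so the last column does not carry the newline;
    # csv.reader used to take care of this implicitly.
    return line.rstrip('\n').split('\t')

assert parse_tsv_line('prompt\tgold\n') == ['prompt', 'gold']
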
@@ -333,7 +345,7 @@ def input_heuristics(s: str, thingtalk=None, is_cased=False, keep_special_tokens

     # Put question mark at the end whenever necessary.
     sentences = [sentence.strip() for sentence in re.split('\s+([.?!:])\s*', s) if len(sentence) > 0]
-    # print('sentences = ', sentences)
+    # logger.info('sentences = %s', sentences)
     for idx in range(len(sentences)):
         if sentences[idx] in ['.', '?' , '!', ':']:
             continue

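A note on the re.split in the context line: because the pattern contains a capturing group, the punctuation delimiters come back as elements of the result, which is why the loop below skips bare '.', '?', '!', ':'. For example:

import re

s = 'what time is it ? set an alarm .'
sentences = [p.strip() for p in re.split(r'\s+([.?!:])\s*', s) if len(p) > 0]
print(sentences)  # ['what time is it', '?', 'set an alarm', '.']
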
@@ -352,7 +364,7 @@ def input_heuristics(s: str, thingtalk=None, is_cased=False, keep_special_tokens
         # capitalize the first word and parameters
         if thingtalk:
             _, parameters = remove_thingtalk_quotes(thingtalk)
-            # print('parameters = ', parameters)
+            # logger.info('parameters = ', parameters)
             for p in parameters:
                 capitalized_p = ' '.join([t[0].upper()+t[1:] for t in p.split()])
                 sentences[idx] = sentences[idx].replace(p, capitalized_p)

@@ -373,7 +385,7 @@ def input_heuristics(s: str, thingtalk=None, is_cased=False, keep_special_tokens
         s, r = spm.forwad(s)
         reverse_map.extend(r)

-    # print('s = ', s)
+    # logger.info('s = ', s)
     return s, reverse_map

 def output_heuristics(s: str, reverse_map: list):

@@ -420,11 +432,11 @@ def compute_metrics(generations, golds, reduction='average'):
     # from matplotlib import pyplot as plt
     # import numpy as np
     # h, b = np.histogram(all_bleu, bins=list(range(0, 105, 5)))
-    # print('all_bleu = ', all_bleu)
-    # print('h = ', h)
-    # print('b = ', b)
+    # logger.info('all_bleu = ', all_bleu)
+    # logger.info('h = ', h)
+    # logger.info('b = ', b)
     # h = h / np.sum(h)
-    # print('h = ', h)
+    # logger.info('h = ', h)
     # plt.title('GPT2 (temp=0, penalty=1.0) Paraphrases for restaurants')
     # plt.xlabel('BLEU with original')
     # plt.ylim((0.0, 1.0))

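For reference, what the commented-out block would compute if revived: a histogram of BLEU scores in 5-point bins, normalized to fractions. A live, minimal version (data invented):

import numpy as np

all_bleu = [12.0, 37.5, 37.9, 88.2]
h, b = np.histogram(all_bleu, bins=list(range(0, 105, 5)))
h = h / np.sum(h)  # counts -> fractions
print(h.sum())     # 1.0
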
@@ -439,7 +451,7 @@ def parse_argv(parser):
                         help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
     parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
-    parser.add_argument("--input_file", type=str, help="The file from which we read prompts.")
+    parser.add_argument("--input_file", type=str, help="The file from which we read prompts. Defaults to stdin.")
     parser.add_argument('--input_column', type=int, required=True,
                         help='The column in the input file which contains the input sentences.')
     parser.add_argument('--prompt_column', type=int, default=None,

@@ -448,7 +460,7 @@ def parse_argv(parser):
                         help='The column in the input file which contains the gold sentences. Defaults to --input_column if no gold is available.')
     parser.add_argument('--thingtalk_column', type=int, default=None,
                         help='The column in the input file which contains the ThingTalk program.')
-    parser.add_argument("--output_file", type=str, help="When specified, generated text will be written in this file.")
+    parser.add_argument("--output_file", type=str, help="When specified, generated text will be written in this file. Defaults to stdout.")
     parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20, help='The generated sentences will have a maximum length of len(input) + arg.length')
     parser.add_argument("--min_output_length", type=int, default=1, help='Will prevent stop tokens from appearing in the first --min_length tokens of the generated sentences.')

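These two help-string updates are the user-visible half of the change: a None default now means the standard streams. A minimal sketch of the convention (argument names from the diff, the rest illustrative):

import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument('--input_file', type=str, help='The file from which we read prompts. Defaults to stdin.')
parser.add_argument('--output_file', type=str, help='When specified, generated text will be written in this file. Defaults to stdout.')
args = parser.parse_args([])

# argparse leaves both as None, which selects the standard streams.
source = open(args.input_file) if args.input_file is not None else sys.stdin
sink = open(args.output_file, 'w') if args.output_file is not None else sys.stdout
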
@@ -510,6 +522,8 @@ def main(args):
     args.model_type = args.model_type.lower()

     if args.n_gpu > 1:
+        if args.input_file is None:
+            raise ValueError('Cannot use multiple GPUs when reading from stdin. You should provide an --input_file')
         # Independent multi-GPU generation
         all_processes = []
         all_input_files = split_file_on_disk(args.input_file, args.n_gpu)

@@ -634,11 +648,11 @@ def run_generation(args):
         out_logits = out_logits[:, :].tolist()
         for i, o in enumerate(out):
             o_logits = out_logits[i]
-            # print('all output tokens: ', o)
-            # print('all output tokens detokenized: ', tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False))
+            # logger.info('all output tokens: %s', o)
+            # logger.info('all output tokens detokenized: %s', str(tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)))
             o = o[batch_input_sequence_lengths[i % batch_size]:]
-            # print('original context tokens: ', batch_context_tokens[i % batch_size])
-            # print('original input sequence: ', batch_input_sequences[i % batch_size])
+            # logger.info('original context tokens: %s', str(batch_context_tokens[i % batch_size]))
+            # logger.info('original input sequence: %s', str(batch_input_sequences[i % batch_size]))

             if args.stop_tokens is not None:
                 min_index = len(o)

@@ -672,22 +686,22 @@ def run_generation(args):
             else:
                 criterion = np.mean(o_logits)
             batch_criterion[i % batch_size].append(criterion)
-            # print('generated tokens: ', o)
-            # print('o_logits = ', o_logits)
-            # print('generated cirterion: %.2f' % criterion)
-            # print('text = ', text)
-            # print('-'*10)
+            # logger.info('generated tokens: %s', str(o))
+            # logger.info('o_logits = %s', str(o_logits))
+            # logger.info('generated cirterion: %.2f', criterion)
+            # logger.info('text = %s', text)
+            # logger.info('-'*10)


         if args.selection_criterion == 'none':
             all_outputs.extend(batch_outputs)
         else:
             for idx, example in enumerate(batch_outputs):
-                print('input sequence: ', batch_input_sequences[idx % batch_size])
+                logger.info('input sequence: %s', str(batch_input_sequences[idx % batch_size]))
                 c, example = tuple(zip(*sorted(list(zip(batch_criterion[idx], example)), reverse=True)))
-                print(example)
-                print(c)
-                print('-'*10)
+                logger.info(example)
+                logger.info(c)
+                logger.info('-'*10)
                 selection = example[0]
                 all_outputs.append([selection])

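Turning these bare print calls into logger.info calls is what makes the stdout default workable: generated text owns stdout, so diagnostics must go through the logger instead. A sketch of the convention, assuming a standard logging setup rather than this repo's:

import logging
import sys

# Route diagnostics to stderr so stdout stays clean for generated text.
logging.basicConfig(stream=sys.stderr, level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info('input sequence: %s', 'show me restaurants')  # -> stderr
print('paraphrased output')                               # -> stdout
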
@@ -703,7 +717,7 @@ def run_generation(args):
             for text in _:
                 output_file.write(text + '\n')
     else:
-        logger.info(json.dumps(all_outputs, indent=2))
+        print(json.dumps(all_outputs, indent=2))
     logger.info('Average BLEU score = %.2f', metrics['bleu'])
     logger.info('Exact match score = %.2f', metrics['em'])