From eda4f265020391d01edfef491c0a355b3e638d67 Mon Sep 17 00:00:00 2001
From: Sina
Date: Thu, 23 Apr 2020 21:29:04 -0700
Subject: [PATCH] paraphrasing now defaults to stdin and stdout

---
 genienlp/run_generation.py | 142 ++++++++++++++++++++-----------------
 1 file changed, 78 insertions(+), 64 deletions(-)

diff --git a/genienlp/run_generation.py b/genienlp/run_generation.py
index d5963ae0..565d0c04 100644
--- a/genienlp/run_generation.py
+++ b/genienlp/run_generation.py
@@ -28,6 +28,7 @@ import re
 import copy
 import numpy as np
 import os
+import sys
 
 # multiprocessing with CUDA
 from torch.multiprocessing import Process, set_start_method
@@ -84,7 +85,7 @@ def apply_repetition_penalty(logits, context, repetition_penalty, prompt_token_i
     m = torch.scatter(input=torch.zeros_like(logits), dim=1, index=context, value=1)
     m[:prompt_token_id] = 0
     m[:pad_token_id] = 0
-    # print('m = ', m.shape)
+    # logger.info('m = %s', m.shape)
     need_change = m * logits
     need_divide = need_change > 0
     need_multiply = need_change < 0
@@ -144,9 +145,9 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
     segment_ids = torch.tensor(segment_ids, dtype=torch.long, device=device)
     segment_ids = segment_ids.repeat(num_samples, 1)
 
-    # print('context = ', context)
-    # print('position_ids = ', position_ids)
-    # print('segment_ids = ', segment_ids)
+    # logger.info('context = %s', context)
+    # logger.info('position_ids = %s', position_ids)
+    # logger.info('segment_ids = %s', segment_ids)
 
     context = torch.tensor(padded_context, dtype=torch.long, device=device)
     context = context.repeat(num_samples, 1)
@@ -159,7 +160,7 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
     with torch.no_grad():
         # rep_penalty = np.random.random(length) < 0.1
         # original_rep_penalty = repetition_penalty
-        # print('rep_penalty = ', rep_penalty)
+        # logger.info('rep_penalty = %s', rep_penalty)
         for _ in range(length):
             inputs = {'input_ids': generated, 'position_ids': position_ids[:, :next_index], 'token_type_ids': segment_ids[:, :next_index]}
             if is_xlnet:
@@ -204,7 +205,7 @@ def sample_sequence(model, length, min_output_length, context, num_samples,
             # prevent stop_tokens if generated_length < min_output_length
             should_remove_stop_tokens = (generated_length < min_output_length)
             next_token_logits[:, stop_token_ids] = next_token_logits[:, stop_token_ids].masked_fill(should_remove_stop_tokens, -float('Inf'))
-            # print('after ', next_token_logits[:, stop_token_ids])
+            # logger.info('after %s', next_token_logits[:, stop_token_ids])
             generated_length = generated_length + (1-m)
 
             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
@@ -272,42 +273,53 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
     all_golds = []
     reverse_maps = []
 
-    number_of_lines = get_number_of_lines(file_path)
-    with open(file_path) as input_file:
-        reader = csv.reader(input_file, delimiter='\t')
-        for row in tqdm(reader, desc='Reading Input File', total=number_of_lines):
-            input_sequence = row[input_column]
-            gold = row[gold_column]
-            # print('gold = ', gold)
+    if file_path is not None:
+        number_of_lines = get_number_of_lines(file_path)
+        disable_tqdm = False
+        input_file = open(file_path)
+    else:
+        number_of_lines = 1
+        disable_tqdm = True
+        input_file = sys.stdin
+
+
+    for line in tqdm(input_file, desc='Reading Input File', total=number_of_lines, disable=disable_tqdm):
+        row = line.rstrip('\n').split('\t')
+        input_sequence = row[input_column]
+        gold = row[gold_column]
+        # logger.info('gold = %s', gold)
+        if not skip_heuristics:
+            gold, _ = input_heuristics(gold, None, is_cased, keep_special_tokens=True, keep_tokenized=True)
+            # logger.info('gold = %s', gold)
+        all_golds.append(gold)
+        # logger.info('before text = %s', input_sequence)
+        if skip_heuristics:
+            reverse_maps.append({})
+        else:
+            thingtalk = row[thingtalk_column] if thingtalk_column is not None else None
+            # logger.info('input_sequence = %s', input_sequence)
+            input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
+            # logger.info('input_sequence = %s', input_sequence)
+            reverse_maps.append(reverse_map)
+        input_sequence += prompt_token
+        prompt = '' # includes the first few tokens of the output
+        if prompt_column is not None and len(row) > prompt_column:
+            prompt = row[prompt_column]
             if not skip_heuristics:
-                gold, _ = input_heuristics(gold, None, is_cased, keep_special_tokens=True, keep_tokenized=True)
-                # print('gold = ', gold)
-            all_golds.append(gold)
-            # print('before text = ', input_sequence)
-            if skip_heuristics:
-                reverse_maps.append({})
-            else:
-                thingtalk = row[thingtalk_column] if thingtalk_column is not None else None
-                # print('input_sequence = ', input_sequence)
-                input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
-                # print('input_sequence = ', input_sequence)
-                reverse_maps.append(reverse_map)
-            input_sequence += prompt_token
-            prompt = '' # includes the first few tokens of the output
-            if prompt_column is not None and len(row) > prompt_column:
-                prompt = row[prompt_column]
-                if not skip_heuristics:
-                    prompt, _ = input_heuristics(prompt, thingtalk, is_cased)
-                    # print('prompt = ', prompt)
-            input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=False)
-            prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
-            context_tokens = input_sequence_tokens + prompt_tokens
-            if copy > 0:
-                context_tokens.extend(context_tokens[0 : min(copy, len(context_tokens)-1)]) # -1 since we should not copy prompt_token
-            all_input_sequences.append(input_sequence)
-            all_input_sequence_lengths.append(len(input_sequence_tokens))
-            all_context_tokens.append(context_tokens)
-            all_context_lengths.append(len(context_tokens))
+                prompt, _ = input_heuristics(prompt, thingtalk, is_cased)
+                # logger.info('prompt = %s', prompt)
+        input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=False)
+        prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
+        context_tokens = input_sequence_tokens + prompt_tokens
+        if copy > 0:
+            context_tokens.extend(context_tokens[0 : min(copy, len(context_tokens)-1)]) # -1 since we should not copy prompt_token
+        all_input_sequences.append(input_sequence)
+        all_input_sequence_lengths.append(len(input_sequence_tokens))
+        all_context_tokens.append(context_tokens)
+        all_context_lengths.append(len(context_tokens))
+
+    if file_path is not None:
+        input_file.close()
 
     return all_input_sequences, all_input_sequence_lengths, all_context_tokens, all_context_lengths, all_golds, reverse_maps
 
@@ -333,7 +345,7 @@ def input_heuristics(s: str, thingtalk=None, is_cased=False, keep_special_tokens
 
     # Put question mark at the end whenever necessary.
     sentences = [sentence.strip() for sentence in re.split('\s+([.?!:])\s*', s) if len(sentence) > 0]
-    # print('sentences = ', sentences)
+    # logger.info('sentences = %s', sentences)
     for idx in range(len(sentences)):
         if sentences[idx] in ['.', '?', '!', ':']:
             continue
@@ -352,7 +364,7 @@ def input_heuristics(s: str, thingtalk=None, is_cased=False, keep_special_tokens
             # capitalize the first word and parameters
             if thingtalk:
                 _, parameters = remove_thingtalk_quotes(thingtalk)
-                # print('parameters = ', parameters)
+                # logger.info('parameters = %s', parameters)
                 for p in parameters:
                     capitalized_p = ' '.join([t[0].upper()+t[1:] for t in p.split()])
                     sentences[idx] = sentences[idx].replace(p, capitalized_p)
@@ -373,7 +385,7 @@ def input_heuristics(s: str, thingtalk=None, is_cased=False, keep_special_tokens
         s, r = spm.forwad(s)
         reverse_map.extend(r)
 
-    # print('s = ', s)
+    # logger.info('s = %s', s)
     return s, reverse_map
 
 def output_heuristics(s: str, reverse_map: list):
@@ -420,11 +432,11 @@ def compute_metrics(generations, golds, reduction='average'):
     # from matplotlib import pyplot as plt
     # import numpy as np
     # h, b = np.histogram(all_bleu, bins=list(range(0, 105, 5)))
-    # print('all_bleu = ', all_bleu)
-    # print('h = ', h)
-    # print('b = ', b)
+    # logger.info('all_bleu = %s', all_bleu)
+    # logger.info('h = %s', h)
+    # logger.info('b = %s', b)
     # h = h / np.sum(h)
-    # print('h = ', h)
+    # logger.info('h = %s', h)
     # plt.title('GPT2 (temp=0, penalty=1.0) Paraphrases for restaurants')
     # plt.xlabel('BLEU with original')
     # plt.ylim((0.0, 1.0))
@@ -439,7 +451,7 @@ def parse_argv(parser):
                         help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
     parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
-    parser.add_argument("--input_file", type=str, help="The file from which we read prompts.")
+    parser.add_argument("--input_file", type=str, help="The file from which we read prompts. Defaults to stdin.")
     parser.add_argument('--input_column', type=int, required=True,
                         help='The column in the input file which contains the input sentences.')
     parser.add_argument('--prompt_column', type=int, default=None,
@@ -448,7 +460,7 @@ def parse_argv(parser):
                         help='The column in the input file which contains the gold sentences. Defaults to --input_column if no gold is available.')
     parser.add_argument('--thingtalk_column', type=int, default=None,
                         help='The column in the input file which contains the ThingTalk program.')
-    parser.add_argument("--output_file", type=str, help="When specified, generated text will be written in this file.")
+    parser.add_argument("--output_file", type=str, help="When specified, generated text will be written in this file. Defaults to stdout.")
     parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20, help='The generated sentences will have a maximum length of len(input) + arg.length')
     parser.add_argument("--min_output_length", type=int, default=1, help='Will prevent stop tokens from appearing in the first --min_length tokens of the generated sentences.')
@@ -510,6 +522,8 @@ def main(args):
     args.model_type = args.model_type.lower()
 
     if args.n_gpu > 1:
+        if args.input_file is None:
+            raise ValueError('Cannot use multiple GPUs when reading from stdin. You should provide an --input_file')
         # Independent multi-GPU generation
         all_processes = []
         all_input_files = split_file_on_disk(args.input_file, args.n_gpu)
@@ -634,11 +648,11 @@ def run_generation(args):
         out_logits = out_logits[:, :].tolist()
         for i, o in enumerate(out):
             o_logits = out_logits[i]
-            # print('all output tokens: ', o)
-            # print('all output tokens detokenized: ', tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False))
+            # logger.info('all output tokens: %s', o)
+            # logger.info('all output tokens detokenized: %s', str(tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)))
             o = o[batch_input_sequence_lengths[i % batch_size]:]
-            # print('original context tokens: ', batch_context_tokens[i % batch_size])
-            # print('original input sequence: ', batch_input_sequences[i % batch_size])
+            # logger.info('original context tokens: %s', str(batch_context_tokens[i % batch_size]))
+            # logger.info('original input sequence: %s', str(batch_input_sequences[i % batch_size]))
 
             if args.stop_tokens is not None:
                 min_index = len(o)
@@ -672,22 +686,22 @@ def run_generation(args):
             else:
                 criterion = np.mean(o_logits)
             batch_criterion[i % batch_size].append(criterion)
-            # print('generated tokens: ', o)
-            # print('o_logits = ', o_logits)
-            # print('generated cirterion: %.2f' % criterion)
-            # print('text = ', text)
-            # print('-'*10)
+            # logger.info('generated tokens: %s', str(o))
+            # logger.info('o_logits = %s', str(o_logits))
+            # logger.info('generated criterion: %.2f', criterion)
+            # logger.info('text = %s', text)
+            # logger.info('-'*10)
 
 
         if args.selection_criterion == 'none':
             all_outputs.extend(batch_outputs)
         else:
             for idx, example in enumerate(batch_outputs):
-                print('input sequence: ', batch_input_sequences[idx % batch_size])
+                logger.info('input sequence: %s', str(batch_input_sequences[idx % batch_size]))
                 c, example = tuple(zip(*sorted(list(zip(batch_criterion[idx], example)), reverse=True)))
-                print(example)
-                print(c)
-                print('-'*10)
+                logger.info(example)
+                logger.info(c)
+                logger.info('-'*10)
                 selection = example[0]
                 all_outputs.append([selection])
 
@@ -703,7 +717,7 @@ def run_generation(args):
                 for text in _:
                     output_file.write(text + '\n')
     else:
-        logger.info(json.dumps(all_outputs, indent=2))
+        print(json.dumps(all_outputs, indent=2))
 
     logger.info('Average BLEU score = %.2f', metrics['bleu'])
     logger.info('Exact match score = %.2f', metrics['em'])
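
The stdin fallback this patch builds into create_features_from_tsv_file (open the file when --input_file is given, otherwise iterate over sys.stdin, and close only handles the function itself opened) condenses to the minimal sketch below. It is illustrative only: read_tsv_rows is a hypothetical name, not a function in genienlp, and the only behavior carried over from the patch is the branching and the tab-separated row format.

    import sys

    def read_tsv_rows(file_path=None):
        # Open the named file when a path is given; otherwise fall back to
        # stdin, mirroring the branching the patch adds above.
        input_file = open(file_path) if file_path is not None else sys.stdin
        try:
            for line in input_file:
                # Drop the trailing newline so the last column stays clean.
                yield line.rstrip('\n').split('\t')
        finally:
            # Close only what we opened; leave sys.stdin untouched.
            if file_path is not None:
                input_file.close()

The output side is symmetric: when --output_file is omitted, run_generation now prints json.dumps(all_outputs, indent=2) to stdout instead of routing it through the logger, so the script can sit at either end of a shell pipeline.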