diff --git a/genienlp/GPT2seq2seq.py b/genienlp/GPT2Seq2Seq.py
similarity index 57%
rename from genienlp/GPT2seq2seq.py
rename to genienlp/GPT2Seq2Seq.py
index d87c85b2..b1cdc393 100644
--- a/genienlp/GPT2seq2seq.py
+++ b/genienlp/GPT2Seq2Seq.py
@@ -1,16 +1,57 @@
+from typing import List
 from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
 import torch
 class GPT2Seq2Seq(GPT2LMHeadModel):
     def __init__(self, config):
         super().__init__(config)
-        self.sep_token = 50258
         self.end_token = 50259
+        self.sep_token = 50258
         self.pad_token = 50257
+
+    def pad_to_max_length(self, input_sequences: List[List[int]]):
+        """
+        Adds pad tokens before the sep_token
+        """
+        max_length = len(input_sequences[0]) # input is sorted by length
+        copy_input_sequences = []
+        for i in range(len(input_sequences)):
+            sep_token_index = input_sequences[i].index(self.sep_token)
+            copy_input_sequences.append(input_sequences[i][:sep_token_index] + \
+                                        [self.pad_token]*(max_length-len(input_sequences[i])) +\
+                                        input_sequences[i][sep_token_index:])
+
+        return copy_input_sequences
+
+
+    def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
+        """ repetition penalty from CTRL (https://arxiv.org/abs/1909.05858), but much faster on GPU
+        """
+        if repetition_penalty == 1.0:
+            return lprobs
+        m = torch.scatter(input=torch.zeros_like(lprobs), dim=1, index=prev_output_tokens, value=1)
+        m[:self.sep_token] = 0
+        m[:self.pad_token] = 0
+        # logger.info('m = ', m.shape)
+        need_change = m * lprobs
+        need_divide = need_change > 0
+        need_multiply = need_change < 0
+        lprobs = need_divide * lprobs / repetition_penalty + need_multiply * lprobs * repetition_penalty + (1-m) * lprobs
+
+        # old, slow implementation
+        # if repetition_penalty != 1.0:
+            # for i in range(context.shape[0]):
+                # for previous_token in set(generated[i].tolist()):
+                    # if lprobs[i, previous_token] > 0:
+                        # lprobs[i, previous_token] /= repetition_penalty
+                    # else:
+                        # lprobs[i, previous_token] *= repetition_penalty
+
+
     def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
         sep_token_position = (input_ids==self.sep_token).to(torch.long)
-        assert (torch.sum(sep_token_position, dim=1)==1).all(), 'All input_ids must contain exactly one start_token'
+        assert (torch.sum(sep_token_position, dim=1)==1).all(), 'All input_ids must contain exactly one start_token. sep_token_position = %s' % str(sep_token_position)
         token_type_ids = torch.cumsum(sep_token_position, dim=1) - sep_token_position
         attention_mask = (input_ids!=self.pad_token).to(torch.long) # 0 means mask, 1 means no mask
         position_ids = (torch.cumsum(attention_mask, dim=1)-1)*(1-token_type_ids)+(torch.cumsum(token_type_ids, dim=1)-1)*token_type_ids
diff --git a/genienlp/paraphrase/evaluate_bart.py b/genienlp/paraphrase/evaluate_bart.py
index a30fb4e2..474c7284 100644
--- a/genienlp/paraphrase/evaluate_bart.py
+++ b/genienlp/paraphrase/evaluate_bart.py
@@ -20,10 +20,9 @@ def chunks(lst, n):
 def generate_summaries(
     examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
 ):
-    fout = Path(out_file).open("w")
-    # b = BartSystem.load_from_checkpoint('./workdir/models/bart-large-mw6/checkpointepoch=1.ckpt')
-    # b.model.save_pretrained('./workdir/models/bart-large-mw6/')
-    # b.tokenizer.save_pretrained('./workdir/models/bart-large-mw6/')
+    # b = BartSystem.load_from_checkpoint('./workdir/models/bart-large-2to1/checkpointcheckpoint_ckpt_epoch_1.ckpt')
+    # b.model.save_pretrained('./workdir/models/bart-large-2to1/')
+    # b.tokenizer.save_pretrained('./workdir/models/bart-large-2to1/')
     model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
     model.eval()
     model = model.to(device)
@@ -32,6 +31,7 @@ def generate_summaries(
     max_length = 140
     min_length = 1
+    fout = Path(out_file).open("w")
     for batch in tqdm(list(chunks(examples, batch_size))):
         dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
         # bad = ['which', 'Which', 'restaurant', 'restaurants']
@@ -39,7 +39,7 @@ def generate_summaries(
         summaries = model.generate(
             input_ids=dct["input_ids"].to(device),
             attention_mask=dct["attention_mask"].to(device),
-            num_beams=1,
+            num_beams=16,
             do_sample=False,
             temperature=1,
             length_penalty=1,
@@ -48,7 +48,7 @@ def generate_summaries(
             no_repeat_ngram_size=3,
             early_stopping=True,
             decoder_start_token_id=model.config.eos_token_id,
-            num_return_sequences=1
+            num_return_sequences=4
             # bad_words_ids=bad
         )
         # print(bad)
diff --git a/genienlp/run_generation.py b/genienlp/run_generation.py
index cae37d03..e3a783d1 100644
--- a/genienlp/run_generation.py
+++ b/genienlp/run_generation.py
@@ -38,16 +38,15 @@ except RuntimeError:
     pass
 import torch
-import torch.nn.functional as F
 from transformers import GPT2Config, BartConfig
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
+from transformers import GPT2Tokenizer
 from transformers import BartForConditionalGeneration, BartTokenizer
 from .util import set_seed, get_number_of_lines, combine_files_on_disk, split_file_on_disk, get_part_path, detokenize, tokenize, lower_case, \
-                  top_k_top_p_filtering, SpecialTokenMap, remove_thingtalk_quotes
+                  SpecialTokenMap, remove_thingtalk_quotes
 from .metrics import computeBLEU
-# from .models.common import BeamHypotheses
+from .GPT2Seq2Seq import GPT2Seq2Seq
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -55,159 +54,16 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
                     level = logging.INFO)
 logger = logging.getLogger(__name__)
-MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
+MAX_LENGTH = int(1000) # Hardcoded max length to avoid infinite loop
 ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, BartConfig)), ())
 MODEL_CLASSES = {
-    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
+    'gpt2': (GPT2Seq2Seq, GPT2Tokenizer),
     'bart': (BartForConditionalGeneration, BartTokenizer)
 }
-def apply_repetition_penalty(logits, context, repetition_penalty, prompt_token_id, pad_token_id):
-    """ repetition penalty from CTRL (https://arxiv.org/abs/1909.05858), but much faster on GPU
-        we penalize only the tokens that appear in the context, not in the generated text
-    """
-    if repetition_penalty == 1.0:
-        return logits
-    m = torch.scatter(input=torch.zeros_like(logits), dim=1, index=context, value=1)
-    m[:prompt_token_id] = 0
-    m[:pad_token_id] = 0
-    # logger.info('m = ', m.shape)
-    need_change = m * logits
-    need_divide = need_change > 0
-    need_multiply = need_change < 0
-    logits = need_divide * logits / repetition_penalty + need_multiply * logits * repetition_penalty + (1-m) * logits
-
-    # Old, slow implementation
-    # if repetition_penalty != 1.0:
-        # for i in range(context.shape[0]):
-            # for _ in set(generated[i].tolist()):
-                # if logits[i, _] > 0:
-                    # logits[i, _] /= repetition_penalty
-                # else:
-                    # logits[i, _] *= repetition_penalty
-    return logits
-
-
-def sample_sequence(model, length, min_output_length, context, num_samples,
-                    temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0, device='cpu',
-                    stop_token_ids=None, pad_token_id=None, supports_past=False, prompt_token_id=None, segment_token_ids=None,
-                    start_reverse_position_ids=None, output_form=None):
-    """
-    Generates sequence of tokens for the batch of input contexts.
-    Inputs:
-        context: a list of token_ids, sorted by length from longest to shortest
-        num_samples: the number of sequences to output for each input context
-        length: The maximum length of generation in addition to the original sentence's length
-        stop_token_ids: generation of each sequence will stop if we generate any of these tokens
-        supports_past: set to True if the model accepts the 'past' input for more efficient generation. For example, GPT-2/Transfo-XL/XLNet/CTRL do
-        segment_token_ids: a list of two integers that indicate the tokens we should use for each of the two segments
-    """
-    max_length = len(context[0]) # context is sorted by length from longest to shortest
-    min_length = len(context[-1])
-
-    # should not change the elements of context since it will change them outside this function as well.
-    padded_context = []
-    for i in range(len(context)):
-        padded_context.append(context[i] + [pad_token_id] * (max_length-len(context[i]))) # pad to max_length
-
-    next_index = min_length
-    length = max_length + (max_length - min_length) + length # generate till max_length, then generate another max_length+length tokens
-    max_index = length + next_index
-
-    segment_ids = []
-    position_ids = []
-    for i in range(len(context)):
-        prompt_token_position = context[i].index(prompt_token_id)
-        p = list(range(prompt_token_position+1))
-        segment_ids.append([segment_token_ids[0]]*len(p) + [segment_token_ids[1]]*(max_index - len(p)))
-        if start_reverse_position_ids is None:
-            position_ids.append(p + list(range(max_index - len(p))))
-        else:
-            position_ids.append(p + list(reversed(range(start_reverse_position_ids+len(p)))) + [0]*(max_index-start_reverse_position_ids-2*len(p)))
-
-    position_ids = torch.tensor(position_ids, dtype=torch.long, device=device)
-    position_ids = position_ids.repeat(num_samples, 1)
-    segment_ids = torch.tensor(segment_ids, dtype=torch.long, device=device)
-    segment_ids = segment_ids.repeat(num_samples, 1)
-
-    # logger.info('context = ', context)
-    # logger.info('position_ids = ', position_ids)
-    # logger.info('segment_ids = ', segment_ids)
-
-    context = torch.tensor(padded_context, dtype=torch.long, device=device)
-    context = context.repeat(num_samples, 1)
-    generated = context[:, :next_index]
-    generated_length = torch.zeros((context.shape[0], 1), dtype=torch.long, device=device)
-    should_finish = None
-    generated_logits = None
-    past = None
-    next_token = None
-    with torch.no_grad():
-        for _ in range(length):
-            inputs = {'input_ids': generated, 'position_ids': position_ids[:, :next_index], 'token_type_ids': segment_ids[:, :next_index]}
-            if supports_past:
-                inputs['past'] = past
-                if past is not None:
-                    inputs['input_ids'] = next_token
-                    inputs['position_ids'] = position_ids[:, next_index-1]
-                    inputs['token_type_ids'] = segment_ids[:, next_index-1]
-
-            outputs = model(**inputs)
-            original_next_token_logits = outputs[0][:, -1, :]
-            next_token_logits = original_next_token_logits / (temperature if temperature > 0 else 1.)
-            past = outputs[1]
-
-            next_token_logits = apply_repetition_penalty(next_token_logits, context, repetition_penalty,
-                                                         prompt_token_id=prompt_token_id, pad_token_id=pad_token_id)
-
-            if next_index < context.shape[1]:
-                m = (context[:, next_index:next_index+1] != pad_token_id).long() # m==0 is where next_token should be kept
-            else:
-                m = torch.zeros(1, device=device)
-
-            # prevent stop_tokens if generated_length < min_output_length
-            should_remove_stop_tokens = (generated_length < min_output_length)
-            next_token_logits[:, stop_token_ids] = next_token_logits[:, stop_token_ids].masked_fill(should_remove_stop_tokens, -float('Inf'))
-            # logger.info('after ', next_token_logits[:, stop_token_ids])
-            generated_length = generated_length + (1-m)
-
-            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-
-            if temperature == 0: # greedy sampling:
-                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
-            else:
-                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
-
-            if output_form == 'logprob':
-                generated_token_logit = F.log_softmax(original_next_token_logits, dim=-1).gather(1, next_token)
-            else:
-                assert output_form == 'logit'
-                generated_token_logit = original_next_token_logits.gather(1, next_token)
-
-            # throw away the tokens that we already have from the context
-            if next_index < context.shape[1]:
-                next_token = m*context[:, next_index:next_index+1] + (1-m)*next_token
-                generated_token_logit = (1-m)*generated_token_logit
-
-            for stop_token_id in stop_token_ids:
-                if should_finish is None:
-                    should_finish = ((next_token == stop_token_id) & (1-m).bool())
-                else:
-                    should_finish = should_finish | ((next_token == stop_token_id) & (1-m).bool())
-            next_index += 1
-            generated = torch.cat((generated, next_token), dim=1)
-            if generated_logits is None:
-                generated_logits = generated_token_logit
-            else:
-                generated_logits = torch.cat((generated_logits, generated_token_logit), dim=1)
-            if should_finish.all():
-                break
-    return generated, generated_logits
-
-
 special_pattern_mapping = [
     SpecialTokenMap('PHONE_NUMBER_([0-9]+)', ['888-8888', '777-8888']),
     SpecialTokenMap('NUMBER_([0-9]+)', ['2', '3'], [['2', 'two'], ['3', 'three']]),
@@ -225,7 +81,7 @@ special_pattern_mapping = [
     SpecialTokenMap('GENERIC_ENTITY_uk.ac.cam.multiwoz.Restaurant:Restaurant_([0-9]+)', ["restaurant1", "restaurant2", "restaurant3"]) # TODO the only reason we can get away with this unnatural replacement is that actual backward is not going to be called for this
 ]
-def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, prompt_token,
+def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, sep_token,
                                   skip_heuristics, is_cased):
     """
     Read a tsv file (this includes a text file with one example per line) and returns input features that the model needs
@@ -267,7 +123,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
             input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
             # logger.info('input_sequence = %s', input_sequence)
             reverse_maps.append(reverse_map)
-        input_sequence += prompt_token
+        input_sequence += sep_token
         prompt = '' # includes the first few tokens of the output
         if prompt_column is not None and len(row) > prompt_column:
             prompt = row[prompt_column]
@@ -278,6 +134,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
         prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
         context_tokens = input_sequence_tokens + prompt_tokens
         if copy > 0:
+            assert prompt == ''
             context_tokens.extend(context_tokens[0 : min(copy, len(context_tokens)-1)]) # -1 since we should not copy prompt_token
         all_input_sequences.append(input_sequence)
         all_input_sequence_lengths.append(len(input_sequence_tokens))
@@ -289,7 +146,6 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
     return all_input_sequences, all_input_sequence_lengths, all_context_tokens, all_context_lengths, all_golds, reverse_maps
-
 def is_question(sentence: str):
     question_words = ['which', 'what', 'where', 'how', 'who', 'when', 'is', 'are', 'am', \
                       'can', 'could', 'would', 'will', 'have', 'did', 'do', 'does', 'no is', 'yes is']
@@ -435,20 +291,16 @@ def parse_argv(parser):
     parser.add_argument("--metric_reduction", type=str, choices=['average', 'max'], default='average',
                         help="How we should calculate metrics where there are multiple generations per example.")
-    # These can be used for improving the quality of the output
     parser.add_argument("--num_samples", type=int, default=1)
-    parser.add_argument("--selection_criterion", type=str, choices=['none', 'average_logit', 'average_logprob', 'bleu'], default='none',
-                        help='Select one of --num_sample outputs that maximizes this criterion')
     # These are generation hyperparameters. Each one can be a list of values in which case, we generate num_samples outputs for each set of hyperparameters.
-    parser.add_argument("--start_reverse_position_ids", type=int, nargs='+', default=[None],
-                        help='If provided, position ids will be the number of tokens left in generation and will start from len(input) + args.start_reverse_position_ids')
     parser.add_argument("--temperature", type=float, nargs='+', default=[1.0],
                         help="temperature of 0 implies greedy sampling")
     parser.add_argument("--repetition_penalty", type=float, nargs='+', default=[1.0],
                         help="primarily useful for CTRL model; in that case, use 1.2")
     parser.add_argument("--top_k", type=int, nargs='+', default=[0], help='0 disables top-k filtering')
     parser.add_argument("--top_p", type=float, nargs='+', default=[0.9], help='1.0 disables top-p filtering')
+    parser.add_argument("--num_beams", type=int, nargs='+', default=[1], help='1 disables beam search')
     parser.add_argument("--copy", type=int, default=0,
                         help='Number of tokens that will be copied at the beginning of generation. Helps preserve the original meaning of the input sequence.')
@@ -456,7 +308,7 @@ def parse_argv(parser):
                         help="Avoid using CUDA when available")
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
-    parser.add_argument('--prompt_token', type=str, default='',
+    parser.add_argument('--sep_token', type=str, default='',
                         help="Token after which text generation starts. We add this to the end of all inputs.")
     parser.add_argument('--stop_tokens', type=str, nargs='+', default=[''],
                         help="Token at which text generation is stopped. The first element of the list is used as segment id as well.")
@@ -466,7 +318,7 @@ def parse_argv(parser):
 def main(args):
     if args.prompt_column is not None and args.copy is not None and args.copy != 0:
         raise ValueError('Cannot copy from the input and use prompt at the same time. Disable either --copy or --prompt_column.')
-    hyperparameters = ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'start_reverse_position_ids']
+    hyperparameters = ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams']
     max_hyperparameter_len = max([len(getattr(args, h)) for h in hyperparameters])
     valid_len = [1, max_hyperparameter_len]
     for h in hyperparameters:
@@ -534,7 +386,7 @@ def run_generation(args):
     pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
-    prompt_token_id = tokenizer.convert_tokens_to_ids(args.prompt_token)
+    sep_token_id = tokenizer.convert_tokens_to_ids(args.sep_token)
     if pad_token_id is None:
         logger.error('Your tokenizer does not have a padding token')
@@ -543,7 +395,7 @@ def run_generation(args):
                                          input_column=args.input_column, gold_column=args.gold_column, prompt_column=args.prompt_column,
                                          copy=args.copy, thingtalk_column=args.thingtalk_column,
-                                         prompt_token=args.prompt_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased)
+                                         sep_token=args.sep_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased)
     # sort contexts based on their context length so that less generated tokens are thrown away and generation can be done faster
@@ -561,38 +413,32 @@ def run_generation(args):
         batch_context_tokens = all_context_tokens[batch_slice[0]: batch_slice[1]]
         batch_reverse_maps = reverse_maps[batch_slice[0]: batch_slice[1]]
+        batch_context_tensor = input_tensor = torch.tensor(model.pad_to_max_length(batch_context_tokens), dtype=torch.long, device=args.device)
+
         batch_outputs = [[] for _ in range(batch_size)]
-        batch_criterion = [[] for _ in range(batch_size)]
         for hyperparameter_idx in range(len(args.temperature)):
-            out, out_logits = sample_sequence(
-                model=model,
-                context=batch_context_tokens,
-                num_samples=args.num_samples,
-                length=args.length,
-                min_output_length=args.min_output_length,
-                temperature=args.temperature[hyperparameter_idx],
-                top_k=args.top_k[hyperparameter_idx],
-                top_p=args.top_p[hyperparameter_idx],
-                repetition_penalty=args.repetition_penalty[hyperparameter_idx],
-                device=args.device,
-                stop_token_ids=stop_token_ids,
-                pad_token_id=pad_token_id,
-                supports_past=args.model_type in ['gpt2'],
-                prompt_token_id=prompt_token_id,
-                segment_token_ids=[tokenizer.convert_tokens_to_ids(args.prompt_token), tokenizer.convert_tokens_to_ids(args.stop_tokens[0])] if args.model_type=='gpt2' else [0, 1],
-                start_reverse_position_ids=args.start_reverse_position_ids[hyperparameter_idx],
-                output_form='logit' if args.selection_criterion=='average_logit' else 'logprob'
-            )
+            out = model.generate(input_ids=batch_context_tensor,
+                                 min_length=args.min_output_length,
+                                 max_length=batch_context_tensor.shape[1]+args.length,
+                                 num_beams=args.num_beams[hyperparameter_idx],
+                                 top_k=args.top_k[hyperparameter_idx],
+                                 top_p=args.top_p[hyperparameter_idx],
+                                 early_stopping=True,
+                                 num_return_sequences=args.num_samples,
+                                 repetition_penalty=args.repetition_penalty[hyperparameter_idx],
+                                 do_sample=args.temperature[hyperparameter_idx]!=0,
+                                 temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, # if temperature==0, we do not sample
+                                 eos_token_id=stop_token_ids[0],
+                                 pad_token_id=pad_token_id
+                                 )
             out = out[:, :].tolist()
-            out_logits = out_logits[:, :].tolist()
             for i, o in enumerate(out):
-                o_logits = out_logits[i]
-                # logger.info('all output tokens: %s', o)
+                # logger.info('all output tokens: %s', str(o))
                 # logger.info('all output tokens detokenized: %s', str(tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)))
-                o = o[batch_input_sequence_lengths[i % batch_size]:]
-                # logger.info('original context tokens: %s', str(batch_context_tokens[i % batch_size]))
-                # logger.info('original input sequence: %s', str(batch_input_sequences[i % batch_size]))
+                o = [x for x in o if x!=pad_token_id][batch_input_sequence_lengths[(i//args.num_samples) % batch_size]:]
+                # logger.info('original context tokens: %s', str(batch_context_tokens[(i//args.num_samples) % batch_size]))
+                # logger.info('original input sequence: %s', str(batch_input_sequences[(i//args.num_samples) % batch_size]))
                 if args.stop_tokens is not None:
                     min_index = len(o)
@@ -605,9 +451,6 @@ def run_generation(args):
                     if min_index < len(o) and o[min_index] == tokenizer.convert_tokens_to_ids('?'):
                         # always include the question mark
                         min_index = min_index + 1
-                    if min_index < len(o) and o[min_index] == tokenizer.convert_tokens_to_ids(args.stop_tokens[0]):
-                        # include in logit calculation
-                        o_logits = o_logits[:len(o_logits)-(len(o)-min_index-1)]
                     o = o[:min_index]
                 text = tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)
@@ -617,34 +460,12 @@ def run_generation(args):
                 text = re.sub('\s\s+', ' ', text) # remove duplicate white spaces
                 text = text.strip()
                 if not args.skip_heuristics:
-                    text = output_heuristics(text, batch_reverse_maps[i % batch_size])
-                batch_outputs[i % batch_size].append(text)
+                    text = output_heuristics(text, batch_reverse_maps[(i//args.num_samples) % batch_size])
+                batch_outputs[(i//args.num_samples) % batch_size].append(text)
-                if args.selection_criterion == 'bleu':
-                    # computeBLEU always converts to lower case first, so do not worry about lower/upper case here
-                    criterion = computeBLEU([text], [[batch_input_sequences[i % batch_size]]])
-                else:
-                    criterion = np.mean(o_logits)
-                batch_criterion[i % batch_size].append(criterion)
-                # logger.info('generated tokens: %s', str(o))
-                # logger.info('o_logits = %s', str(o_logits))
-                # logger.info('generated cirterion: %.2f', criterion)
-                # logger.info('text = %s', text)
-                # logger.info('-'*10)
+        all_outputs.extend(batch_outputs)
-        if args.selection_criterion == 'none':
-            all_outputs.extend(batch_outputs)
-        else:
-            for idx, example in enumerate(batch_outputs):
-                logger.info('input sequence: %s', str(batch_input_sequences[idx % batch_size]))
-                c, example = tuple(zip(*sorted(list(zip(batch_criterion[idx], example)), reverse=True)))
-                logger.info(example)
-                logger.info(c)
-                logger.info('-'*10)
-                selection = example[0]
-                all_outputs.append([selection])
-
     # sort the results back to their original order
     _, all_outputs = tuple(zip(*sorted(list(zip(original_order, all_outputs)))))
@@ -653,8 +474,8 @@ def run_generation(args):
     if args.output_file is not None:
         with open(args.output_file, 'w') as output_file:
            if args.output_file is not None:
-                for _ in all_outputs:
-                    for text in _:
+                for output in all_outputs:
+                    for text in output:
                         output_file.write(text + '\n')
     else:
         print(json.dumps(all_outputs, indent=2))
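
For reference only, and not part of the patch above: a minimal sketch of how the refactored GPT-2 generation path might be exercised after this change. It assumes a hypothetical fine-tuned checkpoint directory whose tokenizer already contains the padding/separator/end special tokens at the ids hard-coded in GPT2Seq2Seq (50257, 50258, 50259); the directory path and example sentences are placeholders.

import torch
from transformers import GPT2Tokenizer
from genienlp.GPT2Seq2Seq import GPT2Seq2Seq

MODEL_DIR = './workdir/models/gpt2-paraphraser'  # hypothetical checkpoint directory

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)
model = GPT2Seq2Seq.from_pretrained(MODEL_DIR).eval()

# Build contexts as "input tokens + sep_token", sorted longest-first,
# because pad_to_max_length() pads every sequence up to the length of the first one.
sentences = ['show me nearby restaurants that serve chinese food', 'find a cheap hotel']
contexts = [tokenizer.encode(s, add_special_tokens=False) + [model.sep_token] for s in sentences]
contexts.sort(key=len, reverse=True)
context_lengths = [len(c) for c in contexts]

input_ids = torch.tensor(model.pad_to_max_length(contexts), dtype=torch.long)
with torch.no_grad():
    out = model.generate(input_ids=input_ids,
                         max_length=input_ids.shape[1] + 40,
                         num_beams=4,
                         do_sample=False,
                         early_stopping=True,
                         eos_token_id=model.end_token,  # assumes end_token is the stop token used during fine-tuning
                         pad_token_id=model.pad_token)

for i, o in enumerate(out.tolist()):
    # Drop padding and end tokens, then strip the original context, mirroring run_generation.py.
    tokens = [t for t in o if t not in (model.pad_token, model.end_token)]
    print(tokenizer.decode(tokens[context_lengths[i]:], clean_up_tokenization_spaces=True))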