diff --git a/genienlp/GPT2Seq2Seq.py b/genienlp/GPT2Seq2Seq.py
index b1cdc393..b7a39592 100644
--- a/genienlp/GPT2Seq2Seq.py
+++ b/genienlp/GPT2Seq2Seq.py
@@ -5,21 +5,23 @@ import torch
 
 class GPT2Seq2Seq(GPT2LMHeadModel):
     def __init__(self, config):
         super().__init__(config)
-        self.end_token = 50259
-        self.sep_token = 50258
-        self.pad_token = 50257
+
+    def set_token_ids(self, end_token_id, sep_token_id, pad_token_id):
+        self.end_token_id = end_token_id
+        self.sep_token_id = sep_token_id
+        self.pad_token_id = pad_token_id
 
     def pad_to_max_length(self, input_sequences: List[List[int]]):
         """
         Adds pad tokens before the sep_token
         """
-        max_length = len(input_sequences[0]) # input is sorted by length
+        max_length = max([len(s) for s in input_sequences])
         copy_input_sequences = []
         for i in range(len(input_sequences)):
-            sep_token_index = input_sequences[i].index(self.sep_token)
+            sep_token_index = input_sequences[i].index(self.sep_token_id)
             copy_input_sequences.append(input_sequences[i][:sep_token_index] + \
-                [self.pad_token]*(max_length-len(input_sequences[i])) +\
+                [self.pad_token_id]*(max_length-len(input_sequences[i])) +\
                 input_sequences[i][sep_token_index:])
 
         return copy_input_sequences
@@ -31,8 +33,8 @@ class GPT2Seq2Seq(GPT2LMHeadModel):
         if repetition_penalty == 1.0:
             return lprobs
         m = torch.scatter(input=torch.zeros_like(lprobs), dim=1, index=prev_output_tokens, value=1)
-        m[:self.sep_token] = 0
-        m[:self.pad_token] = 0
+        m[:self.sep_token_id] = 0
+        m[:self.pad_token_id] = 0
         # logger.info('m = ', m.shape)
         need_change = m * lprobs
         need_divide = need_change > 0
@@ -48,14 +50,28 @@ class GPT2Seq2Seq(GPT2LMHeadModel):
             # else:
                 # lprobs[i, previous_token] *= repetition_penalty
 
+    def generate(self, **kwargs):
+        outputs = super().generate(**kwargs)
+        outputs = outputs[:, :].tolist()
+        for i in range(len(outputs)):
+            outputs[i] = [x for x in outputs[i] if x != self.pad_token_id] # remove padding
+            outputs[i] = outputs[i][outputs[i].index(self.sep_token_id)+1:] # only return the output (i.e. after sep_token)
+
+        return outputs
+
     def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
-        sep_token_position = (input_ids==self.sep_token).to(torch.long)
-        assert (torch.sum(sep_token_position, dim=1)==1).all(), 'All input_ids must contain exactly one start_token. sep_token_position = %s' % str(sep_token_position)
+        sep_token_position = (input_ids==self.sep_token_id).to(torch.long)
+        # for i, s in enumerate(sep_token_position):
+            # if torch.sum(s) != 1:
+                # print(i, s)
+                # print(input_ids[i])
+                # exit()
+        assert (torch.sum(sep_token_position, dim=1)==1).all(), 'All input_ids must contain exactly one start_token. sep_token_position = %s\nsep_token_id = %d' % (str(sep_token_position), self.sep_token_id)
         token_type_ids = torch.cumsum(sep_token_position, dim=1) - sep_token_position
-        attention_mask = (input_ids!=self.pad_token).to(torch.long) # 0 means mask, 1 means no mask
+        attention_mask = (input_ids!=self.pad_token_id).to(torch.long) # 0 means mask, 1 means no mask
         position_ids = (torch.cumsum(attention_mask, dim=1)-1)*(1-token_type_ids)+(torch.cumsum(token_type_ids, dim=1)-1)*token_type_ids
-        token_type_ids = self.sep_token * (1-token_type_ids) + self.end_token * token_type_ids
+        token_type_ids = self.sep_token_id * (1-token_type_ids) + self.end_token_id * token_type_ids
         # print('input_ids = ', input_ids)
         # print('position_ids = ', position_ids)
         # print('token_type_ids = ', token_type_ids)
@@ -67,25 +83,4 @@ class GPT2Seq2Seq(GPT2LMHeadModel):
             attention_mask = attention_mask[:, -1].unsqueeze(-1)
 
         inputs = {"input_ids": input_ids, "position_ids": position_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask, "past": past}
-        return inputs
-
-if __name__ == '__main__':
-    model = GPT2Seq2Seq.from_pretrained('workdir/models/gpt2-medium-5')
-    model.eval()
-    tokenizer = GPT2Tokenizer.from_pretrained('workdir/models/gpt2-medium-5')
-    # print(tokenizer.convert_tokens_to_ids(''))
-    # print(tokenizer.convert_tokens_to_ids(''))
-    dct = tokenizer.batch_encode_plus(['show me restaurants around here. ', 'where is it? '], return_tensors="pt", pad_to_max_length=True)
-    outputs = model.generate(input_ids=dct['input_ids'],
-                             max_length=40,
-                             num_beams=16,
-                             early_stopping=True,
-                             num_return_sequences=4,
-                             do_sample=False,
-                             temperature=1.0,
-                             eos_token_id=50259,
-                             pad_token_id=tokenizer.convert_tokens_to_ids(tokenizer.pad_token)) # do greedy decoding
-    print('outputs = ', outputs)
-    for output in outputs:
-        print('Generated: {}'.format(tokenizer.decode(output, skip_special_tokens=True)))
-    
\ No newline at end of file
+        return inputs
\ No newline at end of file
diff --git a/genienlp/paraphrase/evaluate_bart.py b/genienlp/paraphrase/evaluate_bart.py
deleted file mode 100644
index 474c7284..00000000
--- a/genienlp/paraphrase/evaluate_bart.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import argparse
-from pathlib import Path
-
-import torch
-from tqdm import tqdm
-
-from transformers import BartForConditionalGeneration, BartTokenizer
-from genienlp.paraphrase.train_bart import BartSystem
-
-
-DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-def chunks(lst, n):
-    """Yield successive n-sized chunks from lst."""
-    for i in range(0, len(lst), n):
-        yield lst[i : i + n]
-
-
-def generate_summaries(
-    examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE
-):
-    # b = BartSystem.load_from_checkpoint('./workdir/models/bart-large-2to1/checkpointcheckpoint_ckpt_epoch_1.ckpt')
-    # b.model.save_pretrained('./workdir/models/bart-large-2to1/')
-    # b.tokenizer.save_pretrained('./workdir/models/bart-large-2to1/')
-    model = BartForConditionalGeneration.from_pretrained(model_name).to(device)
-    model.eval()
-    model = model.to(device)
-    tokenizer = BartTokenizer.from_pretrained(model_name)
-
-    max_length = 140
-    min_length = 1
-
-    fout = Path(out_file).open("w")
-    for batch in tqdm(list(chunks(examples, batch_size))):
-        dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
-        # bad = ['which', 'Which', 'restaurant', 'restaurants']
-        # bad = [tokenizer.encode(b, add_prefix_space=True, add_special_tokens=False) for b in bad]
-        summaries = model.generate(
-            input_ids=dct["input_ids"].to(device),
-            attention_mask=dct["attention_mask"].to(device),
-            num_beams=16,
-            do_sample=False,
-            temperature=1,
-            length_penalty=1,
-            max_length=max_length + 2,  # +2 from original because we start at step=1 and stop before max_length
-            min_length=min_length + 1,  # +1 from original because we start at step=1
-            no_repeat_ngram_size=3,
-            early_stopping=True,
-            decoder_start_token_id=model.config.eos_token_id,
-            num_return_sequences=4
-            # bad_words_ids=bad
-        )
-        # print(bad)
-        # print(summaries)
-        dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summaries]
-        for hypothesis in dec:
-            fout.write(hypothesis + "\n")
-            fout.flush()
-
-
-def run_generate():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "source_path", type=str, help="like cnn_dm/test.source",
-    )
-    parser.add_argument(
-        "output_path", type=str, help="where to save summaries",
-    )
-    parser.add_argument(
-        "model_name", type=str, default="bart-large-cnn", help="like bart-large-cnn",
-    )
-    parser.add_argument(
-        "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
-    )
-    parser.add_argument(
-        "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
-    )
-    args = parser.parse_args()
-    examples = [" " + x.rstrip() for x in open(args.source_path).readlines()]
-    generate_summaries(examples, args.output_path, args.model_name, batch_size=args.bs, device=args.device)
-
-
-if __name__ == "__main__":
-    run_generate()
diff --git a/genienlp/run_generation.py b/genienlp/run_generation.py
index e3a783d1..12c44d60 100644
--- a/genienlp/run_generation.py
+++ b/genienlp/run_generation.py
@@ -54,7 +54,6 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
                     level = logging.INFO)
 logger = logging.getLogger(__name__)
 
-MAX_LENGTH = int(1000)  # Hardcoded max length to avoid infinite loop
 
 ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, BartConfig)), ())
 
@@ -82,7 +81,7 @@ special_pattern_mapping = [
 ]
 
 def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_column, prompt_column, copy, thingtalk_column, sep_token,
-                                  skip_heuristics, is_cased):
+                                  skip_heuristics, is_cased, model_type):
     """
     Read a tsv file (this includes a text file with one example per line) and returns input features that the model needs
     Outputs:
@@ -91,7 +90,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
     """
     all_input_sequences = []
     all_input_sequence_lengths = []
     all_context_tokens = []
-    all_context_lengths = []
+    estimated_output_lengths = []
    all_golds = []
     reverse_maps = []
 
@@ -106,7 +105,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
 
     for line in tqdm(input_file, desc='Reading Input File', total=number_of_lines, disable=disable_tqdm):
-        row = line.split('\t')
+        row = [r.strip() for r in line.split('\t')]
         input_sequence = row[input_column]
         gold = row[gold_column]
         # logger.info('gold = %s', gold)
@@ -123,28 +122,28 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
                 input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
                 # logger.info('input_sequence = %s', input_sequence)
                 reverse_maps.append(reverse_map)
-        input_sequence += sep_token
-        prompt = '' # includes the first few tokens of the output
+        input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=True) # add_special_tokens=True for gpt2 should have no effect, but as of transformers==2.8.0, a bug results in token_ids getting changed
+
+        prompt_tokens = [] # includes the first few tokens of the output
         if prompt_column is not None and len(row) > prompt_column:
             prompt = row[prompt_column]
             if not skip_heuristics:
                 prompt, _ = input_heuristics(prompt, thingtalk, is_cased)
                 # logger.info('prompt = %s', prompt)
-        input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=False)
-        prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
-        context_tokens = input_sequence_tokens + prompt_tokens
+            prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
         if copy > 0:
-            assert prompt == ''
-            context_tokens.extend(context_tokens[0 : min(copy, len(context_tokens)-1)]) # -1 since we should not copy prompt_token
+            assert len(prompt_tokens) == 0
+            prompt_tokens = context_tokens[0 : min(copy, len(context_tokens)-1)] # -1 since we should not copy sep_token
+        context_tokens = input_sequence_tokens + [tokenizer.convert_tokens_to_ids(sep_token)] + prompt_tokens
         all_input_sequences.append(input_sequence)
         all_input_sequence_lengths.append(len(input_sequence_tokens))
         all_context_tokens.append(context_tokens)
-        all_context_lengths.append(len(context_tokens))
+        estimated_output_lengths.append(len(input_sequence_tokens)-len(prompt_tokens))
 
     if file_path is not None:
         input_file.close()
-    return all_input_sequences, all_input_sequence_lengths, all_context_tokens, all_context_lengths, all_golds, reverse_maps
+    return all_input_sequences, all_input_sequence_lengths, all_context_tokens, estimated_output_lengths, all_golds, reverse_maps
 
 def is_question(sentence: str):
     question_words = ['which', 'what', 'where', 'how', 'who', 'when', 'is', 'are', 'am', \
@@ -372,52 +371,60 @@ def run_generation(args):
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
     model = model_class.from_pretrained(args.model_name_or_path)
     model.to(args.device)
     model.eval()
 
-    if args.length < 0 and model.config.max_position_embeddings > 0:
-        args.length = model.config.max_position_embeddings
-    elif 0 < model.config.max_position_embeddings < args.length:
-        args.length = model.config.max_position_embeddings  # No generation bigger than model size
-    elif args.length < 0:
-        args.length = MAX_LENGTH  # avoid infinite loop
-
-    logger.info(args)
-
     pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
-    sep_token_id = tokenizer.convert_tokens_to_ids(args.sep_token)
     if pad_token_id is None:
         logger.error('Your tokenizer does not have a padding token')
 
-    all_input_sequences, all_input_sequence_lengths, all_context_tokens, all_context_lengths, all_golds, reverse_maps = \
+    if args.model_type == 'gpt2':
+        model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(args.stop_tokens[0]),
+                            sep_token_id=tokenizer.convert_tokens_to_ids(args.sep_token),
+                            pad_token_id=pad_token_id)
+
+    logger.info(args)
+
+    all_input_sequences, all_input_sequence_lengths, all_context_tokens, estimated_output_lengths, all_golds, reverse_maps = \
        create_features_from_tsv_file(file_path=args.input_file, tokenizer=tokenizer, input_column=args.input_column,
                                       gold_column=args.gold_column, prompt_column=args.prompt_column, copy=args.copy, thingtalk_column=args.thingtalk_column,
-                                      sep_token=args.sep_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased)
+                                      sep_token=args.sep_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
+                                      model_type=args.model_type)
 
     # sort contexts based on their context length so that less generated tokens are thrown away and generation can be done faster
-    all_context_lengths, all_input_sequence_lengths, all_input_sequences, all_context_tokens, original_order, reverse_maps = \
-        tuple(zip(*sorted(list(zip(all_context_lengths, all_input_sequence_lengths, all_input_sequences, all_context_tokens, range(len(all_context_tokens)), reverse_maps)), reverse=True)))
+    estimated_output_lengths, all_input_sequence_lengths, all_input_sequences, all_context_tokens, original_order, reverse_maps = \
+        tuple(zip(*sorted(list(zip(estimated_output_lengths, all_input_sequence_lengths, all_input_sequences, all_context_tokens, range(len(all_context_tokens)), reverse_maps)), reverse=True)))
     all_outputs = []
 
     stop_token_ids = [tokenizer.convert_tokens_to_ids(stop_token) for stop_token in args.stop_tokens]
 
-    for batch in trange(math.ceil(len(all_context_tokens) / args.batch_size), desc="Batch"):
+    for batch in tqdm(range(math.ceil(len(all_context_tokens) / args.batch_size)), desc="Batch"):
         batch_slice = (batch*args.batch_size, min((batch+1)*args.batch_size, len(all_context_tokens)))
         batch_size = batch_slice[1] - batch_slice[0]
         batch_input_sequences = all_input_sequences[batch_slice[0]: batch_slice[1]]
         batch_input_sequence_lengths = all_input_sequence_lengths[batch_slice[0]: batch_slice[1]]
         batch_context_tokens = all_context_tokens[batch_slice[0]: batch_slice[1]]
         batch_reverse_maps = reverse_maps[batch_slice[0]: batch_slice[1]]
+        # logger.info('batch_context_tokens = %s', str(batch_context_tokens))
 
-        batch_context_tensor = input_tensor = torch.tensor(model.pad_to_max_length(batch_context_tokens), dtype=torch.long, device=args.device)
+        if args.model_type == 'gpt2':
+            batch_context_tensor = torch.tensor(model.pad_to_max_length(batch_context_tokens), dtype=torch.long, device=args.device)
+            attention_mask = None
+        elif args.model_type == 'bart':
+            padded_batch_context_tokens = []
+            max_length = max([len(s) for s in batch_context_tokens])
+            for i in range(len(batch_context_tokens)):
+                padded_batch_context_tokens.append(batch_context_tokens[i]+[pad_token_id]*(max_length-len(batch_context_tokens[i])))
+            batch_context_tensor = torch.tensor(padded_batch_context_tokens, dtype=torch.long, device=args.device)
+            attention_mask = (batch_context_tensor!=pad_token_id).to(torch.long)
+        # logger.info('batch_context_tensor = %s', str(batch_context_tensor))
 
         batch_outputs = [[] for _ in range(batch_size)]
         for hyperparameter_idx in range(len(args.temperature)):
             out = model.generate(input_ids=batch_context_tensor,
+                                 attention_mask=attention_mask,
                                  min_length=args.min_output_length,
                                  max_length=batch_context_tensor.shape[1]+args.length,
                                  num_beams=args.num_beams[hyperparameter_idx],
@@ -431,32 +438,24 @@ def run_generation(args):
                                  eos_token_id=stop_token_ids[0],
                                  pad_token_id=pad_token_id
                                  )
-
-            out = out[:, :].tolist()
-            for i, o in enumerate(out):
-                # logger.info('all output tokens: %s', str(o))
-                # logger.info('all output tokens detokenized: %s', str(tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)))
-                o = [x for x in o if x!=pad_token_id][batch_input_sequence_lengths[(i//args.num_samples) % batch_size]:]
-                # logger.info('original context tokens: %s', str(batch_context_tokens[(i//args.num_samples) % batch_size]))
-                # logger.info('original input sequence: %s', str(batch_input_sequences[(i//args.num_samples) % batch_size]))
+            if not isinstance(out, list):
+                out = out[:, :].tolist()
+            for i, o in enumerate(out):
                 if args.stop_tokens is not None:
-                    min_index = len(o)
+                    min_index = len(o)-1
                     for stop_token_id in stop_token_ids:
                         try:
                             index = o.index(stop_token_id)
                             min_index = min(index, min_index)
                         except ValueError:
                             pass
-                    if min_index < len(o) and o[min_index] == tokenizer.convert_tokens_to_ids('?'):
-                        # always include the question mark
-                        min_index = min_index + 1
+                    if o[min_index] != stop_token_ids[0]:
+                        min_index = min_index + 1 # include stop_token if it is not end_token
                     o = o[:min_index]
 
-                text = tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=False)
+                text = tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=True)
 
-                # assert tokenizer.pad_token not in text
-                text = text.replace(tokenizer.pad_token, '')
                 text = re.sub('\s\s+', ' ', text) # remove duplicate white spaces
                 text = text.strip()
                 if not args.skip_heuristics:
@@ -468,17 +467,16 @@ def run_generation(args):
 
     # sort the results back to their original order
     _, all_outputs = tuple(zip(*sorted(list(zip(original_order, all_outputs)))))
-
-    metrics = compute_metrics(all_outputs, all_golds, reduction=args.metric_reduction)
 
     if args.output_file is not None:
         with open(args.output_file, 'w') as output_file:
-            if args.output_file is not None:
-                for output in all_outputs:
-                    for text in output:
-                        output_file.write(text + '\n')
+            for output in all_outputs:
+                for text in output:
+                    output_file.write(text + '\n')
     else:
         print(json.dumps(all_outputs, indent=2))
+
+    metrics = compute_metrics(all_outputs, all_golds, reduction=args.metric_reduction)
 
     logger.info('Average BLEU score = %.2f', metrics['bleu'])
     logger.info('Exact match score = %.2f', metrics['em'])