diff --git a/genienlp/run_generation.py b/genienlp/run_generation.py
index 12c44d60..310ca3b5 100644
--- a/genienlp/run_generation.py
+++ b/genienlp/run_generation.py
@@ -43,6 +43,7 @@
 from transformers import GPT2Config, BartConfig
 from transformers import GPT2Tokenizer
 from transformers import BartForConditionalGeneration, BartTokenizer
+from transformers import PretrainedConfig
 from .util import set_seed, get_number_of_lines, combine_files_on_disk, split_file_on_disk, get_part_path, detokenize, tokenize, lower_case, \
     SpecialTokenMap, remove_thingtalk_quotes
 from .metrics import computeBLEU
@@ -58,8 +59,8 @@
 logger = logging.getLogger(__name__)
 
 ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, BartConfig)), ())
 
 MODEL_CLASSES = {
-    'gpt2': (GPT2Seq2Seq, GPT2Tokenizer),
-    'bart': (BartForConditionalGeneration, BartTokenizer)
+    'gpt2': (GPT2Seq2Seq, GPT2Tokenizer, {'sep_token': '', 'end_token': ''}),
+    'bart': (BartForConditionalGeneration, BartTokenizer, {'sep_token': '', 'end_token': ''}) # sep_token will not be used for BART
 }
@@ -122,7 +123,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
                 input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
                 # logger.info('input_sequence = %s', input_sequence)
                 reverse_maps.append(reverse_map)
 
-            input_sequence_tokens = tokenizer.encode(input_sequence,add_special_tokens=True) # add_special_tokens=True for gpt2 should have no effect, but as of transformers==2.8.0, a bug results in token_ids getting changed
+            input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=True)
             prompt_tokens = [] # includes the first few tokens of the output
             if prompt_column is not None and len(row) > prompt_column:
@@ -268,8 +269,6 @@ def compute_metrics(generations, golds, reduction='average'):
     return {'bleu': total_bleu/count, 'em': total_exact_match/count*100}
 
 def parse_argv(parser):
-    parser.add_argument("--model_type", default=None, type=str, required=True,
-                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
     parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--input_file", type=str, help="The file from which we read prompts. Defaults to stdin.")
@@ -307,14 +306,20 @@ def parse_argv(parser):
                         help="Avoid using CUDA when available")
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")
-    parser.add_argument('--sep_token', type=str, default='',
-                        help="Token after which text generation starts. We add this to the end of all inputs.")
-    parser.add_argument('--stop_tokens', type=str, nargs='+', default=[''],
-                        help="Token at which text generation is stopped. The first element of the list is used as segment id as well.")
+    parser.add_argument('--stop_tokens', type=str, nargs='+', default=[],
+                        help="Token at which text generation is stopped.")
     parser.add_argument('--batch_size', type=int, default=4,
                         help="Batch size for text generation for each GPU.")
 
 def main(args):
+    config = PretrainedConfig.from_pretrained(args.model_name_or_path)
+    if config.architectures[0] == 'BartForConditionalGeneration':
+        args.model_type = 'bart'
+    elif config.architectures[0] == 'GPT2LMHeadModel':
+        args.model_type = 'gpt2'
+    else:
+        raise ValueError('Model should be either GPT2 or BART')
+
     if args.prompt_column is not None and args.copy is not None and args.copy != 0:
         raise ValueError('Cannot copy from the input and use prompt at the same time. Disable either --copy or --prompt_column.')
     hyperparameters = ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams']
@@ -367,7 +372,7 @@ def main(args):
 
 
 def run_generation(args):
-    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    model_class, tokenizer_class, special_tokens = MODEL_CLASSES[args.model_type]
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
     model = model_class.from_pretrained(args.model_name_or_path)
     model.to(args.device)
@@ -378,8 +383,8 @@ def run_generation(args):
         logger.error('Your tokenizer does not have a padding token')
 
     if args.model_type == 'gpt2':
-        model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(args.stop_tokens[0]),
-                            sep_token_id=tokenizer.convert_tokens_to_ids(args.sep_token),
+        model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(special_tokens['end_token']),
+                            sep_token_id=tokenizer.convert_tokens_to_ids(special_tokens['sep_token']),
                             pad_token_id=pad_token_id)
 
     logger.info(args)
@@ -389,7 +394,7 @@ def run_generation(args):
                                       input_column=args.input_column, gold_column=args.gold_column,
                                       prompt_column=args.prompt_column, copy=args.copy,
                                       thingtalk_column=args.thingtalk_column,
-                                      sep_token=args.sep_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
+                                      sep_token=special_tokens['sep_token'], skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
                                       model_type=args.model_type)
 
@@ -399,6 +404,7 @@ def run_generation(args):
 
     all_outputs = []
     stop_token_ids = [tokenizer.convert_tokens_to_ids(stop_token) for stop_token in args.stop_tokens]
+    end_token_id = tokenizer.convert_tokens_to_ids(special_tokens['end_token'])
 
     for batch in tqdm(range(math.ceil(len(all_context_tokens) / args.batch_size)), desc="Batch"):
         batch_slice = (batch*args.batch_size, min((batch+1)*args.batch_size, len(all_context_tokens)))
@@ -419,11 +425,13 @@ def run_generation(args):
             padded_batch_context_tokens.append(batch_context_tokens[i]+[pad_token_id]*(max_length-len(batch_context_tokens[i])))
         batch_context_tensor = torch.tensor(padded_batch_context_tokens, dtype=torch.long, device=args.device)
         attention_mask = (batch_context_tensor!=pad_token_id).to(torch.long)
+        # logger.info('context text = %s', [tokenizer.decode(b, clean_up_tokenization_spaces=False, skip_special_tokens=False) for b in batch_context_tensor])
         # logger.info('batch_context_tensor = %s', str(batch_context_tensor))
 
         batch_outputs = [[] for _ in range(batch_size)]
         for hyperparameter_idx in range(len(args.temperature)):
             out = model.generate(input_ids=batch_context_tensor,
+                                 bad_words_ids=[[tokenizer.convert_tokens_to_ids(special_tokens['sep_token'])]] if args.model_type=='gpt2' else None,
                                  attention_mask=attention_mask,
                                  min_length=args.min_output_length,
                                  max_length=batch_context_tensor.shape[1]+args.length,
@@ -435,24 +443,26 @@ def run_generation(args):
                                  repetition_penalty=args.repetition_penalty[hyperparameter_idx],
                                  do_sample=args.temperature[hyperparameter_idx]!=0,
                                  temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, # if temperature==0, we do not sample
-                                 eos_token_id=stop_token_ids[0],
+                                 eos_token_id=end_token_id,
                                  pad_token_id=pad_token_id
                                  )
-
+            # logger.info('out = %s', str(out))
+            # logger.info('out text = %s', [tokenizer.decode(o, clean_up_tokenization_spaces=False, skip_special_tokens=False) for o in out])
             if not isinstance(out, list):
                 out = out[:, :].tolist()
             for i, o in enumerate(out):
-                if args.stop_tokens is not None:
-                    min_index = len(o)-1
-                    for stop_token_id in stop_token_ids:
-                        try:
-                            index = o.index(stop_token_id)
-                            min_index = min(index, min_index)
-                        except ValueError:
-                            pass
-                    if o[min_index] != stop_token_ids[0]:
-                        min_index = min_index + 1 # include stop_token if it is not end_token
-                    o = o[:min_index]
+                if args.model_type=='bart':
+                    o = o[1:]
+                min_index = len(o)-1
+                for stop_token_id in stop_token_ids+[end_token_id]:
+                    try:
+                        index = o.index(stop_token_id)
+                        min_index = min(index, min_index)
+                    except ValueError:
+                        pass
+                if o[min_index] != end_token_id:
+                    min_index = min_index + 1 # include the last token if it is not end_token
+                o = o[:min_index]
                 text = tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=True)
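
Note (not part of the patch): the removed --model_type flag is replaced by auto-detection from the saved checkpoint's config. A minimal sketch of that logic, assuming the checkpoint directory contains a config.json with an architectures field; the helper name infer_model_type is illustrative only.

```python
from transformers import PretrainedConfig

def infer_model_type(model_name_or_path):
    """Pick the generation backend from the architectures field of the saved config."""
    config = PretrainedConfig.from_pretrained(model_name_or_path)
    architecture = (config.architectures or [None])[0]
    if architecture == 'BartForConditionalGeneration':
        return 'bart'
    if architecture == 'GPT2LMHeadModel':
        return 'gpt2'
    raise ValueError('Model should be either GPT2 or BART')
```

With this in place, --model_type no longer needs to be passed on the command line.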
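Note (not part of the patch): the reworked loop in run_generation() now trims every generated sequence at the earliest stop token or end token, drops the end token, keeps any other stop token as the last token, and for BART first strips the leading decoder start token (o = o[1:]). A rough standalone restatement of that truncation, with an illustrative helper name and an extra empty-input guard:

```python
def truncate_generation(token_ids, stop_token_ids, end_token_id):
    """Cut token_ids at the earliest stop/end token; drop end_token, keep other stop tokens."""
    if not token_ids:  # guard not in the patch; generate() never returns an empty sequence
        return token_ids
    min_index = len(token_ids) - 1
    for stop_token_id in stop_token_ids + [end_token_id]:
        try:
            min_index = min(token_ids.index(stop_token_id), min_index)
        except ValueError:
            pass
    if token_ids[min_index] != end_token_id:
        min_index += 1  # include the last token if it is not end_token
    return token_ids[:min_index]

# Example: truncate_generation([5, 9, 7, 2, 0], stop_token_ids=[7], end_token_id=2) -> [5, 9, 7]
```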