arguments of the generation script are now model-agnostic

This commit is contained in:
Sina 2020-04-27 23:57:23 -07:00
parent 78fb8ab2bc
commit fa0ef5b687
1 changed file with 36 additions and 26 deletions


@@ -43,6 +43,7 @@ from transformers import GPT2Config, BartConfig
from transformers import GPT2Tokenizer
from transformers import BartForConditionalGeneration, BartTokenizer
+from transformers import PretrainedConfig
from .util import set_seed, get_number_of_lines, combine_files_on_disk, split_file_on_disk, get_part_path, detokenize, tokenize, lower_case, \
SpecialTokenMap, remove_thingtalk_quotes
from .metrics import computeBLEU
@@ -58,8 +59,8 @@ logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, BartConfig)), ())
MODEL_CLASSES = {
-'gpt2': (GPT2Seq2Seq, GPT2Tokenizer),
-'bart': (BartForConditionalGeneration, BartTokenizer)
+'gpt2': (GPT2Seq2Seq, GPT2Tokenizer, {'sep_token': '<paraphrase>', 'end_token': '</paraphrase>'}),
+'bart': (BartForConditionalGeneration, BartTokenizer, {'sep_token': '<s>', 'end_token': '</s>'}) # sep_token will not be used for BART
}
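The per-model special-token mapping is what lets the CLI flags below go away: the separator and end tokens now travel with each MODEL_CLASSES entry. A minimal sketch of the lookup, using only names from the diff (standalone use here is illustrative, not part of the commit):

model_class, tokenizer_class, special_tokens = MODEL_CLASSES['gpt2']
# '<paraphrase>' marks where generation starts, '</paraphrase>' where it should stop.
print(special_tokens['sep_token'], special_tokens['end_token'])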
@@ -122,7 +123,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
# logger.info('input_sequence = %s', input_sequence)
reverse_maps.append(reverse_map)
-input_sequence_tokens = tokenizer.encode(input_sequence,add_special_tokens=True) # add_special_tokens=True for gpt2 should have no effect, but as of transformers==2.8.0, a bug results in token_ids getting changed
+input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=True)
prompt_tokens = [] # includes the first few tokens of the output
if prompt_column is not None and len(row) > prompt_column:
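Keeping add_special_tokens=True for both models works because the tokenizers treat it differently: GPT-2's tokenizer normally inserts nothing (the removed comment points at a transformers==2.8.0 bug around this), while BART's wraps the sequence in <s> ... </s>, which matches the end token registered for BART above. A hedged check, with pretrained names that are illustrative only:

from transformers import GPT2Tokenizer, BartTokenizer

gpt2_tok = GPT2Tokenizer.from_pretrained('gpt2')
bart_tok = BartTokenizer.from_pretrained('facebook/bart-large')

# GPT-2: no bos/eos defined, so the ids are the plain BPE ids.
print(gpt2_tok.encode('show me chinese restaurants', add_special_tokens=True))
# BART: the same text comes back wrapped in <s> ... </s>.
print(bart_tok.encode('show me chinese restaurants', add_special_tokens=True))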
@@ -268,8 +269,6 @@ def compute_metrics(generations, golds, reduction='average'):
return {'bleu': total_bleu/count, 'em': total_exact_match/count*100}
def parse_argv(parser):
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--input_file", type=str, help="The file from which we read prompts. Defaults to stdin.")
@@ -307,14 +306,20 @@ def parse_argv(parser):
help="Avoid using CUDA when available")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
-parser.add_argument('--sep_token', type=str, default='<paraphrase>',
-help="Token after which text generation starts. We add this to the end of all inputs.")
-parser.add_argument('--stop_tokens', type=str, nargs='+', default=['</paraphrase>'],
-help="Token at which text generation is stopped. The first element of the list is used as segment id as well.")
+parser.add_argument('--stop_tokens', type=str, nargs='+', default=[],
+help="Token at which text generation is stopped.")
parser.add_argument('--batch_size', type=int, default=4,
help="Batch size for text generation for each GPU.")
def main(args):
+config = PretrainedConfig.from_pretrained(args.model_name_or_path)
+if config.architectures[0] == 'BartForConditionalGeneration':
+args.model_type = 'bart'
+elif config.architectures[0] == 'GPT2LMHeadModel':
+args.model_type = 'gpt2'
+else:
+raise ValueError('Model should be either GPT2 or BART')
if args.prompt_column is not None and args.copy is not None and args.copy != 0:
raise ValueError('Cannot copy from the input and use prompt at the same time. Disable either --copy or --prompt_column.')
hyperparameters = ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams']
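The detection block above is what makes --model_type unnecessary: a checkpoint saved with transformers records its architecture in config.json, and PretrainedConfig.from_pretrained exposes it as config.architectures. A small sketch of what main() now inspects (the checkpoint path is hypothetical):

from transformers import PretrainedConfig

# architectures is read from the checkpoint's config.json, e.g.
# ['GPT2LMHeadModel'] for a fine-tuned GPT-2 or ['BartForConditionalGeneration'] for BART.
config = PretrainedConfig.from_pretrained('./paraphraser-checkpoint')  # hypothetical path
print(config.architectures)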
@@ -367,7 +372,7 @@ def main(args):
def run_generation(args):
-model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+model_class, tokenizer_class, special_tokens = MODEL_CLASSES[args.model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
model.to(args.device)
@@ -378,8 +383,8 @@ def run_generation(args):
logger.error('Your tokenizer does not have a padding token')
if args.model_type == 'gpt2':
-model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(args.stop_tokens[0]),
-sep_token_id=tokenizer.convert_tokens_to_ids(args.sep_token),
+model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(special_tokens['end_token']),
+sep_token_id=tokenizer.convert_tokens_to_ids(special_tokens['sep_token']),
pad_token_id=pad_token_id)
logger.info(args)
@@ -389,7 +394,7 @@ def run_generation(args):
input_column=args.input_column, gold_column=args.gold_column, prompt_column=args.prompt_column,
copy=args.copy,
thingtalk_column=args.thingtalk_column,
-sep_token=args.sep_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
+sep_token=special_tokens['sep_token'], skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
model_type=args.model_type)
@@ -399,6 +404,7 @@ def run_generation(args):
all_outputs = []
stop_token_ids = [tokenizer.convert_tokens_to_ids(stop_token) for stop_token in args.stop_tokens]
+end_token_id = tokenizer.convert_tokens_to_ids(special_tokens['end_token'])
for batch in tqdm(range(math.ceil(len(all_context_tokens) / args.batch_size)), desc="Batch"):
batch_slice = (batch*args.batch_size, min((batch+1)*args.batch_size, len(all_context_tokens)))
@@ -419,11 +425,13 @@ def run_generation(args):
padded_batch_context_tokens.append(batch_context_tokens[i]+[pad_token_id]*(max_length-len(batch_context_tokens[i])))
batch_context_tensor = torch.tensor(padded_batch_context_tokens, dtype=torch.long, device=args.device)
attention_mask = (batch_context_tensor!=pad_token_id).to(torch.long)
# logger.info('context text = %s', [tokenizer.decode(b, clean_up_tokenization_spaces=False, skip_special_tokens=False) for b in batch_context_tensor])
# logger.info('batch_context_tensor = %s', str(batch_context_tensor))
batch_outputs = [[] for _ in range(batch_size)]
for hyperparameter_idx in range(len(args.temperature)):
out = model.generate(input_ids=batch_context_tensor,
+bad_words_ids=[[tokenizer.convert_tokens_to_ids(special_tokens['sep_token'])]] if args.model_type=='gpt2' else None,
attention_mask=attention_mask,
min_length=args.min_output_length,
max_length=batch_context_tensor.shape[1]+args.length,
@@ -435,24 +443,26 @@ def run_generation(args):
repetition_penalty=args.repetition_penalty[hyperparameter_idx],
do_sample=args.temperature[hyperparameter_idx]!=0,
temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, # if temperature==0, we do not sample
-eos_token_id=stop_token_ids[0],
+eos_token_id=end_token_id,
pad_token_id=pad_token_id
)
# logger.info('out = %s', str(out))
# logger.info('out text = %s', [tokenizer.decode(o, clean_up_tokenization_spaces=False, skip_special_tokens=False) for o in out])
if not isinstance(out, list):
out = out[:, :].tolist()
for i, o in enumerate(out):
-if args.stop_tokens is not None:
-min_index = len(o)-1
-for stop_token_id in stop_token_ids:
-try:
-index = o.index(stop_token_id)
-min_index = min(index, min_index)
-except ValueError:
-pass
-if o[min_index] != stop_token_ids[0]:
-min_index = min_index + 1 # include stop_token if it is not end_token
-o = o[:min_index]
+if args.model_type=='bart':
+o = o[1:]
+min_index = len(o)-1
+for stop_token_id in stop_token_ids+[end_token_id]:
+try:
+index = o.index(stop_token_id)
+min_index = min(index, min_index)
+except ValueError:
+pass
+if o[min_index] != end_token_id:
+min_index = min_index + 1 # include the last token if it is not end_token
+o = o[:min_index]
text = tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=True)
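The rewritten post-processing always treats the model's own end token as a stop, keeps any user-supplied --stop_tokens, and still includes a stop token in the output when it is not the end token. Read in isolation, the truncation part is equivalent to this small helper (function name and framing are mine, not part of the commit):

def truncate_at_stop_tokens(token_ids, stop_token_ids, end_token_id):
    # Cut at the earliest stop or end token found in the sequence.
    min_index = len(token_ids) - 1
    for stop_token_id in stop_token_ids + [end_token_id]:
        try:
            min_index = min(token_ids.index(stop_token_id), min_index)
        except ValueError:
            pass
    if token_ids[min_index] != end_token_id:
        min_index += 1  # keep a stop token that is not the end token
    return token_ids[:min_index]

# truncate_at_stop_tokens([5, 9, 7, 3], stop_token_ids=[7], end_token_id=3) -> [5, 9, 7]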