arguments of the generation script are now model-agnostic
parent 78fb8ab2bc
commit fa0ef5b687
@@ -43,6 +43,7 @@ from transformers import GPT2Config, BartConfig
from transformers import GPT2Tokenizer
from transformers import BartForConditionalGeneration, BartTokenizer
from transformers import PretrainedConfig
from .util import set_seed, get_number_of_lines, combine_files_on_disk, split_file_on_disk, get_part_path, detokenize, tokenize, lower_case, \
SpecialTokenMap, remove_thingtalk_quotes
from .metrics import computeBLEU
@@ -58,8 +59,8 @@ logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, BartConfig)), ())

MODEL_CLASSES = {
'gpt2': (GPT2Seq2Seq, GPT2Tokenizer),
'bart': (BartForConditionalGeneration, BartTokenizer)
'gpt2': (GPT2Seq2Seq, GPT2Tokenizer, {'sep_token': '<paraphrase>', 'end_token': '</paraphrase>'}),
'bart': (BartForConditionalGeneration, BartTokenizer, {'sep_token': '<s>', 'end_token': '</s>'}) # sep_token will not be used for BART
}
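For context, a minimal sketch (not part of this commit) of how a three-element MODEL_CLASSES entry can be consumed: the tuple is unpacked and the per-model special tokens are turned into ids with the tokenizer. The 'gpt2' checkpoint name is a stand-in; in practice the paraphrase tokens are expected to already be in the fine-tuned checkpoint's vocabulary.

    from transformers import GPT2Tokenizer

    # per-model special tokens, mirroring the new MODEL_CLASSES layout
    special_tokens = {'sep_token': '<paraphrase>', 'end_token': '</paraphrase>'}

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')    # stand-in checkpoint
    tokenizer.add_tokens(list(special_tokens.values()))  # only needed if the checkpoint lacks these tokens
    sep_token_id = tokenizer.convert_tokens_to_ids(special_tokens['sep_token'])
    end_token_id = tokenizer.convert_tokens_to_ids(special_tokens['end_token'])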
@@ -122,7 +123,7 @@ def create_features_from_tsv_file(file_path, tokenizer, input_column, gold_colum
input_sequence, reverse_map = input_heuristics(input_sequence, thingtalk, is_cased)
# logger.info('input_sequence = %s', input_sequence)
reverse_maps.append(reverse_map)
input_sequence_tokens = tokenizer.encode(input_sequence,add_special_tokens=True) # add_special_tokens=True for gpt2 should have no effect, but as of transformers==2.8.0, a bug results in token_ids getting changed
input_sequence_tokens = tokenizer.encode(input_sequence, add_special_tokens=True)

prompt_tokens = [] # includes the first few tokens of the output
if prompt_column is not None and len(row) > prompt_column:
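As an aside, a small illustration (ours, assuming a stock BART checkpoint) of what add_special_tokens toggles in tokenizer.encode; for GPT-2 the flag is essentially a no-op, as the removed comment above notes.

    from transformers import BartTokenizer

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')  # stand-in checkpoint
    with_special = tokenizer.encode('show me restaurants', add_special_tokens=True)    # wrapped in <s> ... </s>
    without_special = tokenizer.encode('show me restaurants', add_special_tokens=False)
    print(with_special, without_special)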
@@ -268,8 +269,6 @@ def compute_metrics(generations, golds, reduction='average'):
return {'bleu': total_bleu/count, 'em': total_exact_match/count*100}

def parse_argv(parser):
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--input_file", type=str, help="The file from which we read prompts. Defaults to stdin.")
@@ -307,14 +306,20 @@ def parse_argv(parser):
help="Avoid using CUDA when available")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--sep_token', type=str, default='<paraphrase>',
help="Token after which text generation starts. We add this to the end of all inputs.")
parser.add_argument('--stop_tokens', type=str, nargs='+', default=['</paraphrase>'],
help="Token at which text generation is stopped. The first element of the list is used as segment id as well.")
parser.add_argument('--stop_tokens', type=str, nargs='+', default=[],
help="Token at which text generation is stopped.")
parser.add_argument('--batch_size', type=int, default=4,
help="Batch size for text generation for each GPU.")

def main(args):
config = PretrainedConfig.from_pretrained(args.model_name_or_path)
if config.architectures[0] == 'BartForConditionalGeneration':
args.model_type = 'bart'
elif config.architectures[0] == 'GPT2LMHeadModel':
args.model_type = 'gpt2'
else:
raise ValueError('Model should be either GPT2 or BART')

if args.prompt_column is not None and args.copy is not None and args.copy != 0:
raise ValueError('Cannot copy from the input and use prompt at the same time. Disable either --copy or --prompt_column.')
hyperparameters = ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams']
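The model-type inference added to main() can be exercised on its own. A minimal sketch, assuming the checkpoint's config.json lists its architectures (true for standard GPT-2 and BART checkpoints):

    from transformers import PretrainedConfig

    def infer_model_type(model_name_or_path):
        # a saved checkpoint's config records the model class it was trained with
        config = PretrainedConfig.from_pretrained(model_name_or_path)
        if config.architectures[0] == 'BartForConditionalGeneration':
            return 'bart'
        elif config.architectures[0] == 'GPT2LMHeadModel':
            return 'gpt2'
        raise ValueError('Model should be either GPT2 or BART')

    # e.g. infer_model_type('facebook/bart-large-cnn') should return 'bart'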
@@ -367,7 +372,7 @@ def main(args):

def run_generation(args):
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
model_class, tokenizer_class, special_tokens = MODEL_CLASSES[args.model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
model.to(args.device)
@@ -378,8 +383,8 @@ def run_generation(args):
logger.error('Your tokenizer does not have a padding token')

if args.model_type == 'gpt2':
model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(args.stop_tokens[0]),
sep_token_id=tokenizer.convert_tokens_to_ids(args.sep_token),
model.set_token_ids(end_token_id=tokenizer.convert_tokens_to_ids(special_tokens['end_token']),
sep_token_id=tokenizer.convert_tokens_to_ids(special_tokens['sep_token']),
pad_token_id=pad_token_id)

logger.info(args)
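Relatedly, if the tokenizer has no padding token (the logger.error branch above), one common remedy, sketched here as an aside rather than as what this script does, is to register a pad token and resize the embeddings:

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')   # stand-in checkpoint
    model = GPT2LMHeadModel.from_pretrained('gpt2')

    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '<pad>'})  # registers the token and assigns it an id
        model.resize_token_embeddings(len(tokenizer))         # grow the embedding matrix to match
    pad_token_id = tokenizer.pad_token_id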
@@ -389,7 +394,7 @@ def run_generation(args):
input_column=args.input_column, gold_column=args.gold_column, prompt_column=args.prompt_column,
copy=args.copy,
thingtalk_column=args.thingtalk_column,
sep_token=args.sep_token, skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
sep_token=special_tokens['sep_token'], skip_heuristics=args.skip_heuristics, is_cased=args.is_cased,
model_type=args.model_type)
@@ -399,6 +404,7 @@ def run_generation(args):
all_outputs = []

stop_token_ids = [tokenizer.convert_tokens_to_ids(stop_token) for stop_token in args.stop_tokens]
end_token_id = tokenizer.convert_tokens_to_ids(special_tokens['end_token'])

for batch in tqdm(range(math.ceil(len(all_context_tokens) / args.batch_size)), desc="Batch"):
batch_slice = (batch*args.batch_size, min((batch+1)*args.batch_size, len(all_context_tokens)))
@@ -419,11 +425,13 @@ def run_generation(args):
padded_batch_context_tokens.append(batch_context_tokens[i]+[pad_token_id]*(max_length-len(batch_context_tokens[i])))
batch_context_tensor = torch.tensor(padded_batch_context_tokens, dtype=torch.long, device=args.device)
attention_mask = (batch_context_tensor!=pad_token_id).to(torch.long)
# logger.info('context text = %s', [tokenizer.decode(b, clean_up_tokenization_spaces=False, skip_special_tokens=False) for b in batch_context_tensor])
# logger.info('batch_context_tensor = %s', str(batch_context_tensor))

batch_outputs = [[] for _ in range(batch_size)]
for hyperparameter_idx in range(len(args.temperature)):
out = model.generate(input_ids=batch_context_tensor,
bad_words_ids=[[tokenizer.convert_tokens_to_ids(special_tokens['sep_token'])]] if args.model_type=='gpt2' else None,
attention_mask=attention_mask,
min_length=args.min_output_length,
max_length=batch_context_tensor.shape[1]+args.length,
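A self-contained sketch (our example, using a stock BART checkpoint) of the batching pattern above: pad each context to a common length, mask the padding positions, and pass both to generate(). The bad_words_ids argument can additionally suppress token sequences, which the script uses to ban the GPT-2 separator token during generation.

    import torch
    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')   # stand-in checkpoint
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
    pad_token_id = tokenizer.pad_token_id

    contexts = [tokenizer.encode(t) for t in ['a short prompt', 'a noticeably longer second prompt']]
    max_len = max(len(c) for c in contexts)
    padded = [c + [pad_token_id] * (max_len - len(c)) for c in contexts]

    input_ids = torch.tensor(padded, dtype=torch.long)
    attention_mask = (input_ids != pad_token_id).to(torch.long)   # 0 over padding positions
    out = model.generate(input_ids=input_ids,
                         attention_mask=attention_mask,
                         bad_words_ids=[tokenizer.encode('ugly', add_special_tokens=False)],  # arbitrary example
                         max_length=40,
                         pad_token_id=pad_token_id,
                         eos_token_id=tokenizer.eos_token_id)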
@@ -435,24 +443,26 @@ def run_generation(args):
repetition_penalty=args.repetition_penalty[hyperparameter_idx],
do_sample=args.temperature[hyperparameter_idx]!=0,
temperature=args.temperature[hyperparameter_idx] if args.temperature[hyperparameter_idx] > 0 else 1.0, # if temperature==0, we do not sample
eos_token_id=stop_token_ids[0],
eos_token_id=end_token_id,
pad_token_id=pad_token_id
)

# logger.info('out = %s', str(out))
# logger.info('out text = %s', [tokenizer.decode(o, clean_up_tokenization_spaces=False, skip_special_tokens=False) for o in out])
if not isinstance(out, list):
out = out[:, :].tolist()
for i, o in enumerate(out):
if args.stop_tokens is not None:
min_index = len(o)-1
for stop_token_id in stop_token_ids:
try:
index = o.index(stop_token_id)
min_index = min(index, min_index)
except ValueError:
pass
if o[min_index] != stop_token_ids[0]:
min_index = min_index + 1 # include stop_token if it is not end_token
o = o[:min_index]
if args.model_type=='bart':
o = o[1:]
min_index = len(o)-1
for stop_token_id in stop_token_ids+[end_token_id]:
try:
index = o.index(stop_token_id)
min_index = min(index, min_index)
except ValueError:
pass
if o[min_index] != end_token_id:
min_index = min_index + 1 # include the last token if it is not end_token
o = o[:min_index]

text = tokenizer.decode(o, clean_up_tokenization_spaces=True, skip_special_tokens=True)
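For readability, the trimming logic in this hunk, extracted into a standalone helper (names are ours; behaviour mirrors the new code above):

    def trim_at_stop_tokens(token_ids, stop_token_ids, end_token_id):
        """Cut a generated sequence at the earliest stop or end token.

        The cut is exclusive for the end token and inclusive for any other
        stop token, matching the min_index logic in the diff above.
        """
        min_index = len(token_ids) - 1
        for stop_token_id in stop_token_ids + [end_token_id]:
            try:
                min_index = min(token_ids.index(stop_token_id), min_index)
            except ValueError:
                pass  # this stop token does not occur in the output
        if token_ids[min_index] != end_token_id:
            min_index = min_index + 1  # include the last token if it is not end_token
        return token_ids[:min_index]

    # e.g. with end_token_id=2: trim_at_stop_tokens([5, 7, 9, 2, 0, 0], [], 2) == [5, 7, 9]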