Fix bug in generation arguments

This commit is contained in:
Sina 2021-01-04 23:07:06 -08:00
parent 18561f9bac
commit 83552d4a68
2 changed files with 21 additions and 17 deletions

View File

@ -198,6 +198,24 @@ def parse_argv(parser):
help='growth strategy for curriculum')
def check_and_update_generation_args(args):
    """
    Check all generation command-line arguments and expand shorthand values.

    Since these arguments are all lists, a single-value shorthand can be used;
    for instance, [1.0] becomes [1.0, 1.0] if all other generation arguments
    are of length 2.

    Args:
        args: parsed argument namespace holding the generation hyperparameter lists.

    Returns:
        The same ``args`` object, with every generation hyperparameter list
        expanded in place to the common maximum length.

    Raises:
        ValueError: if a hyperparameter list is neither of length 1 nor of the
            maximum length among all generation hyperparameters.
    """
    import logging
    logger = logging.getLogger(__name__)

    hyperparameters = ['num_outputs', 'temperature', 'top_k', 'top_p', 'repetition_penalty',
                       'num_beams', 'num_beam_groups', 'diversity_penalty', 'no_repeat_ngram_size']
    max_hyperparameter_len = max(len(getattr(args, h)) for h in hyperparameters)
    valid_len = [1, max_hyperparameter_len]
    for h in hyperparameters:
        current_len = len(getattr(args, h))
        if current_len not in valid_len:
            # Previously this only logged and fell through to the expansion below,
            # where integer division (max // current_len) truncates and silently
            # produces a wrong-length list. Fail fast instead.
            logger.error('Hyperparameters should either have the same number of values as others or have exactly one value.')
            raise ValueError('Hyperparameter %s has %d values; expected 1 or %d.'
                             % (h, current_len, max_hyperparameter_len))
        # If only one value is provided, use the same value for all samples
        setattr(args, h, getattr(args, h) * (max_hyperparameter_len // current_len))
    logger.info('Will output %d sequences for each input.', sum(args.num_outputs))
    return args
def post_parse(args):
if args.val_task_names is None:
args.val_task_names = []
@ -271,4 +289,6 @@ def post_parse(args):
args.train_tasks = list(train_tasks_dict.values())
val_task_dict = get_tasks(args.val_task_names, args, available_tasks=train_tasks_dict)
args.val_tasks = list(val_task_dict.values())
args = check_and_update_generation_args(args)
return args

View File

@ -52,6 +52,7 @@ from .util import set_seed, load_config_json, make_data_loader, log_model_size,
have_multilingual, combine_folders_on_disk, split_folder_on_disk, get_part_path
from .validate import generate_with_model, calculate_and_reduce_metrics
from .calibrate import ConfidenceEstimator
from .arguments import check_and_update_generation_args
logger = logging.getLogger(__name__)
@ -307,23 +308,6 @@ def adjust_multilingual_eval(args):
args.pred_languages[i] = None
def check_and_update_generation_args(args):
    """
    Validate the generation command-line arguments and expand shorthand lists in place.

    All of these arguments are lists; a single value is shorthand that gets
    tiled to the longest list's length (e.g. [1.0] becomes [1.0, 1.0] when
    every other generation argument has two entries).
    """
    generation_params = ['num_outputs', 'temperature', 'top_k', 'top_p',
                         'repetition_penalty', 'num_beams', 'no_repeat_ngram_size']
    lengths = {name: len(getattr(args, name)) for name in generation_params}
    target_len = max(lengths.values())
    for name in generation_params:
        if lengths[name] != 1 and lengths[name] != target_len:
            logger.error('Hyperparameters should either have the same number of values as others or have exactly one value.')
        # A single provided value is repeated so that every sample uses it
        setattr(args, name, getattr(args, name) * (target_len // lengths[name]))
    logger.info('Will output %d sequences for each input.', sum(args.num_outputs))
def set_default_values(args):
"""
sets default values that depend on other input arguments