Fix bug in generation arguments
This commit is contained in:
parent 18561f9bac
commit 83552d4a68
@@ -198,6 +198,24 @@ def parse_argv(parser):
                         help='growth strategy for curriculum')
 
 
+def check_and_update_generation_args(args):
+    """
+    Checks all generation command-line arguments. Since these arguments are all lists and shorthand can be used, we expand them to match the expected length.
+    For instance, [1.0] becomes [1.0, 1.0] if all other generation arguments are of length 2.
+    """
+    hyperparameters = ['num_outputs', 'temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams', 'num_beam_groups', 'diversity_penalty', 'no_repeat_ngram_size']
+    max_hyperparameter_len = max([len(getattr(args, h)) for h in hyperparameters])
+    valid_len = [1, max_hyperparameter_len]
+    for h in hyperparameters:
+        if len(getattr(args, h)) not in valid_len:
+            logger.error('Hyperparameters should either have the same number of values as others or have exactly one value.')
+        # If only one value is provided, use the same value for all samples
+        setattr(args, h, getattr(args, h) * (max_hyperparameter_len // len(getattr(args, h))))
+
+    logger.info('Will output %d sequences for each input.', sum(args.num_outputs))
+    return args
+
+
 def post_parse(args):
     if args.val_task_names is None:
         args.val_task_names = []
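For reference, the expansion above can be exercised in isolation. A minimal sketch, assuming check_and_update_generation_args and its module-level logger are in scope, with made-up argument values (a plain argparse.Namespace stands in for the parsed arguments):

    from argparse import Namespace

    # Made-up values for illustration; every generation argument is a list.
    args = Namespace(
        num_outputs=[1], temperature=[0.5, 1.0], top_k=[0], top_p=[1.0],
        repetition_penalty=[1.0], num_beams=[1], num_beam_groups=[1],
        diversity_penalty=[0.0], no_repeat_ngram_size=[0],
    )
    args = check_and_update_generation_args(args)

    assert args.temperature == [0.5, 1.0]  # already at the maximum length, left alone
    assert args.num_beams == [1, 1]        # the single value is repeated to match
    assert sum(args.num_outputs) == 2      # the total the log line reports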
@@ -271,4 +289,6 @@ def post_parse(args):
     args.train_tasks = list(train_tasks_dict.values())
     val_task_dict = get_tasks(args.val_task_names, args, available_tasks=train_tasks_dict)
     args.val_tasks = list(val_task_dict.values())
+
+    args = check_and_update_generation_args(args)
     return args
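With this change, post_parse validates and expands the generation arguments before returning them, so every consumer of the parsed arguments sees lists of equal length. A hedged sketch of the surrounding parse flow (only parse_argv and post_parse appear in this diff; the driver code around them is an assumption):

    import argparse

    parser = argparse.ArgumentParser()
    parse_argv(parser)  # registers the flags, per the hunk header above
    args = post_parse(parser.parse_args())
    # args.temperature, args.num_beams, ... are now lists of equal length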
@@ -52,6 +52,7 @@ from .util import set_seed, load_config_json, make_data_loader, log_model_size,
     have_multilingual, combine_folders_on_disk, split_folder_on_disk, get_part_path
 from .validate import generate_with_model, calculate_and_reduce_metrics
 from .calibrate import ConfidenceEstimator
+from .arguments import check_and_update_generation_args
 
 logger = logging.getLogger(__name__)
 
@@ -307,23 +308,6 @@ def adjust_multilingual_eval(args):
            args.pred_languages[i] = None
 
 
-def check_and_update_generation_args(args):
-    """
-    checks all generation commandline arguments. Since these arguments are all lists and shorthand can be used, we expand them to match the expected length
-    for instance, [1.0] becomes [1.0 1.0] if all other generation arguments are of length 2
-    """
-    hyperparameters = ['num_outputs', 'temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams', 'no_repeat_ngram_size']
-    max_hyperparameter_len = max([len(getattr(args, h)) for h in hyperparameters])
-    valid_len = [1, max_hyperparameter_len]
-    for h in hyperparameters:
-        if (len(getattr(args, h)) not in valid_len):
-            logger.error('Hyperparameters should either have the same number of values as others or have exactly one value.')
-        # If only one value is provided, use the same value for all samples
-        setattr(args, h, getattr(args, h) * (max_hyperparameter_len // len(getattr(args, h))))
-
-    logger.info('Will output %d sequences for each input.', sum(args.num_outputs))
-
-
 def set_default_values(args):
     """
     sets default values that depend on other input arguments
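Read together, the two files show the bug the commit title refers to: the copy deleted here omitted num_beam_groups and diversity_penalty from its hyperparameter list, so those two arguments were never expanded, and it did not return args, whereas the new call site in post_parse assigns its return value. A hypothetical illustration of the expansion gap, with made-up values:

    temperature = [0.5, 1.0]   # was in the old list, so it got expanded
    diversity_penalty = [0.0]  # missing from the old list, so it kept length 1

    # Any downstream loop that indexes per output would then fail:
    for i in range(len(temperature)):
        print(temperature[i], diversity_penalty[i])  # IndexError at i == 1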