Fix tests
This commit is contained in:
parent
25ebee3964
commit
c046ec0eb0
|
@ -135,7 +135,7 @@ def parse_argv(parser):
|
|||
parser.add_argument(
|
||||
'--val_batch_size',
|
||||
nargs='+',
|
||||
default=[4000],
|
||||
default=None,
|
||||
type=int,
|
||||
help='Batch size for validation corresponding to tasks in val tasks',
|
||||
)
|
||||
|
@ -171,10 +171,9 @@ def parse_argv(parser):
|
|||
default=[0],
|
||||
help='ngrams of this size cannot be repeated in the output. 0 disables it.',
|
||||
)
|
||||
parser.add_argument('--max_output_length', default=150, type=int, help='maximum output length for generation')
|
||||
parser.add_argument('--max_output_length', type=int, help='maximum output length for generation')
|
||||
parser.add_argument(
|
||||
'--min_output_length',
|
||||
default=3,
|
||||
type=int,
|
||||
help='maximum output length for generation; '
|
||||
'default is 3 for most multilingual models: BOS, language code, and one token. otherwise it is 2',
|
||||
|
@ -325,9 +324,9 @@ def check_args(args):
|
|||
if not args.pred_tgt_languages:
|
||||
setattr(args, 'pred_tgt_languages', [args.eval_tgt_languages])
|
||||
|
||||
if args.is_hf_model and (not args.pred_src_languages or not args.model):
|
||||
if args.is_hf_model and (not args.pred_src_languages or not args.model or not args.min_output_length or not args.max_output_length or not args.val_batch_size):
|
||||
# because for HF models we are not getting these values from genienlp's training script
|
||||
raise ValueError('You need to specify --pred_languages and --model when directly loading a HuggingFace model.')
|
||||
raise ValueError('You need to specify --pred_languages, --model, --min_output_length, --max_output_length and --val_batch_size when directly loading a HuggingFace model.')
|
||||
|
||||
if len(args.task_names) != len(args.pred_src_languages):
|
||||
raise ValueError('You have to define prediction languages for each task in the same order you provided the tasks.')
|
||||
|
|
|
@ -585,7 +585,11 @@ def have_multilingual(task_names):
|
|||
return any(['multilingual' in name for name in task_names])
|
||||
|
||||
|
||||
def load_config_file_to_args(args) -> bool:
|
||||
def load_config_file_to_args(args):
|
||||
if not hasattr(args, 'is_hf_model'):
|
||||
# --is_hf_model might not exist if this function is called by anything other than predict.py
|
||||
setattr(args, 'is_hf_model', False)
|
||||
|
||||
if args.is_hf_model:
|
||||
# no config file found, treat `args.path` as a model name on HuggingFace model hub
|
||||
args.pretrained_model = args.path
|
||||
|
|
|
@ -26,7 +26,8 @@ do
|
|||
--train_iterations 4 \
|
||||
--min_output_length 2 \
|
||||
--save $workdir/model_$i \
|
||||
--data $SRCDIR/dataset/bitod
|
||||
--data $SRCDIR/dataset/bitod \
|
||||
${hparams[i]}
|
||||
|
||||
# greedy prediction
|
||||
genienlp predict \
|
||||
|
|
|
@ -42,6 +42,9 @@ for model in \
|
|||
--embeddings $EMBEDDING_DIR \
|
||||
--pred_languages en \
|
||||
--model TransformerSeq2Seq \
|
||||
--min_output_length 1 \
|
||||
--max_output_length 150 \
|
||||
--val_batch_size 100 \
|
||||
--is_hf_model
|
||||
|
||||
# check if result file exists
|
||||
|
|
|
@ -20,7 +20,16 @@ genienlp train \
|
|||
--num_print 0
|
||||
|
||||
# greedy prediction
|
||||
genienlp predict --tasks ood_task --evaluate valid --pred_set_name eval --path $workdir/model --overwrite --eval_dir $workdir/model/eval_results/ --data $SRCDIR/dataset/ood/ --embeddings $EMBEDDING_DIR --val_batch_size 200
|
||||
genienlp predict \
|
||||
--tasks ood_task \
|
||||
--evaluate valid \
|
||||
--pred_set_name eval \
|
||||
--path $workdir/model \
|
||||
--overwrite \
|
||||
--eval_dir $workdir/model/eval_results/ \
|
||||
--data $SRCDIR/dataset/ood/ \
|
||||
--embeddings $EMBEDDING_DIR \
|
||||
--val_batch_size 200
|
||||
|
||||
# check if result file exists
|
||||
if test ! -f $workdir/model/eval_results/valid/ood_task.tsv ; then
|
||||
|
|
|
@ -20,7 +20,8 @@ do
|
|||
--val_batch_size 200 \
|
||||
--train_iterations 4 \
|
||||
--save $workdir/model_$i \
|
||||
--data $SRCDIR/dataset/cross_ner/
|
||||
--data $SRCDIR/dataset/cross_ner/ \
|
||||
$hparams
|
||||
|
||||
# greedy prediction
|
||||
genienlp predict \
|
||||
|
@ -66,7 +67,8 @@ do
|
|||
--val_batch_size 100 \
|
||||
--train_iterations 4 \
|
||||
--save $workdir/model_$i \
|
||||
--data $SRCDIR/dataset/cross_ner/
|
||||
--data $SRCDIR/dataset/cross_ner/ \
|
||||
$hparams
|
||||
|
||||
# greedy prediction
|
||||
genienlp predict \
|
||||
|
|
Loading…
Reference in New Issue