Add test for natural_seq2seq and paraphrase tasks

This commit is contained in:
Sina 2021-01-28 00:50:04 -08:00
parent faa1f30e36
commit 680dff582a
1 changed files with 20 additions and 0 deletions

View File

@ -122,6 +122,26 @@ for hparams in \
i=$((i+1))
done
# test natural_seq2seq and paraphrase tasks
for hparams in \
"--model TransformerSeq2Seq --pretrained_model sshleifer/bart-tiny-random"; do
# train
pipenv run python3 -m genienlp train --train_tasks natural_seq2seq --train_batch_tokens 100 --val_batch_size 100 --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --embeddings $embedding_dir --no_commit
# greedy prediction
pipenv run python3 -m genienlp predict --tasks paraphrase --evaluate test --path $workdir/model_$i --overwrite --eval_dir $workdir/model_$i/eval_results/ --data $SRCDIR/dataset/ --embeddings $embedding_dir --skip_cache
# check if result file exists
if test ! -f $workdir/model_$i/eval_results/test/paraphrase.tsv || test ! -f $workdir/model_$i/eval_results/test/paraphrase.results.json; then
echo "File not found!"
exit
fi
rm -rf $workdir/model_$i
i=$((i+1))
done
# paraphrasing tests
cp -r $SRCDIR/dataset/paraphrasing/ $workdir/paraphrasing/