Add tests for calibrator
This commit is contained in:
parent
a1b413b3cc
commit
4d481ad472
|
@ -196,6 +196,7 @@ def parse_argv(parser):
|
|||
parser.add_argument('--seed', type=int, default=123, help='Random seed to use for reproducibility')
|
||||
parser.add_argument('--save', type=str, help='The directory to save the calibrator model and plots after training')
|
||||
parser.add_argument('--plot', action='store_true', help='If True, will plot metrics and save them. Requires Matplotlib installation.')
|
||||
parser.add_argument('--testing', action='store_true', help='If True, will change labels so that not all of them are equal. This is only used for testing purposes.')
|
||||
|
||||
|
||||
class ConfidenceEstimator():
|
||||
|
@ -492,17 +493,20 @@ def main(args):
|
|||
]:
|
||||
if name.startswith('raw'):
|
||||
estimator_class = RawConfidenceEstimator
|
||||
mc_dropout = False
|
||||
mc_dropout_num = 0
|
||||
else:
|
||||
estimator_class = TreeConfidenceEstimator
|
||||
mc_dropout = train_confidences[0][0].mc_dropout
|
||||
mc_dropout_num = train_confidences[0][0].mc_dropout_num
|
||||
mc_dropout = train_confidences[0][0].mc_dropout
|
||||
mc_dropout_num = train_confidences[0][0].mc_dropout_num
|
||||
estimator = estimator_class(name=name, featurizers=f, eval_metric=args.eval_metric, mc_dropout=mc_dropout, mc_dropout_num=mc_dropout_num)
|
||||
logger.info('name = %s', name)
|
||||
|
||||
train_features, train_labels = estimator.convert_to_dataset(train_confidences, train=True)
|
||||
dev_features, dev_labels = estimator.convert_to_dataset(dev_confidences, train=False)
|
||||
if args.testing:
|
||||
if train_labels.all() or (~train_labels).all():
|
||||
train_labels[0] = ~train_labels[0]
|
||||
if dev_labels.all() or (~dev_labels).all():
|
||||
dev_labels[0] = ~dev_labels[0]
|
||||
estimator.train_and_validate(train_features, train_labels, dev_features, dev_labels)
|
||||
precision, recall, pass_rate, accuracies, thresholds = estimator.evaluate(dev_features, dev_labels)
|
||||
if args.plot:
|
||||
|
|
|
@ -118,7 +118,7 @@ class Server:
|
|||
output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
|
||||
output_predictions_only=True,
|
||||
confidence_estimator=self.confidence_estimator)
|
||||
response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0], score=output.confidence_scores[0]))
|
||||
response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0], score=float(output.confidence_scores[0])))
|
||||
else:
|
||||
output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
|
||||
output_predictions_only=True)
|
||||
|
|
|
@ -63,6 +63,40 @@ do
|
|||
i=$((i+1))
|
||||
done
|
||||
|
||||
# test calibration
|
||||
for hparams in \
|
||||
"--model TransformerSeq2Seq --pretrained_model sshleifer/bart-tiny-random" ;
|
||||
do
|
||||
|
||||
# train
|
||||
pipenv run python3 -m genienlp train --train_tasks almond --train_batch_tokens 100 --val_batch_size 100 --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --embeddings $embedding_dir --no_commit
|
||||
|
||||
# greedy prediction
|
||||
pipenv run python3 -m genienlp predict --tasks almond --evaluate test --path $workdir/model_$i --overwrite --eval_dir $workdir/model_$i/eval_results/ --data $SRCDIR/dataset/ --embeddings $embedding_dir --skip_cache --save_confidence_features --confidence_feature_path $workdir/model_$i/confidences.pkl --mc_dropout --mc_dropout_num 10
|
||||
|
||||
# check if confidence file exists
|
||||
if test ! -f $workdir/model_$i/confidences.pkl ; then
|
||||
echo "File not found!"
|
||||
exit
|
||||
fi
|
||||
|
||||
# calibrate
|
||||
pipenv run python3 -m genienlp calibrate --confidence_path $workdir/model_$i/confidences.pkl --save $workdir/model_$i --testing
|
||||
|
||||
# check if calibrator exists
|
||||
if test ! -f $workdir/model_$i/calibrator.pkl ; then
|
||||
echo "File not found!"
|
||||
exit
|
||||
fi
|
||||
|
||||
echo "Testing the server mode after calibration"
|
||||
echo '{"id": "dummy_example_1", "context": "show me .", "question": "translate to thingtalk", "answer": "now => () => notify"}' | pipenv run python3 -m genienlp server --path $workdir/model_$i --stdin
|
||||
|
||||
rm -rf $workdir/model_$i $workdir/model_$i_exported
|
||||
|
||||
i=$((i+1))
|
||||
done
|
||||
|
||||
# test almond_multilingual task
|
||||
for hparams in \
|
||||
"--model TransformerLSTM --pretrained_model bert-base-multilingual-cased --trainable_decoder_embeddings=50" \
|
||||
|
|
Loading…
Reference in New Issue