Add tests for calibrator
commit 4d481ad472
parent a1b413b3cc
@@ -196,6 +196,7 @@ def parse_argv(parser):
     parser.add_argument('--seed', type=int, default=123, help='Random seed to use for reproducibility')
     parser.add_argument('--save', type=str, help='The directory to save the calibrator model and plots after training')
     parser.add_argument('--plot', action='store_true', help='If True, will plot metrics and save them. Requires Matplotlib installation.')
+    parser.add_argument('--testing', action='store_true', help='If True, will change labels so that not all of them are equal. This is only used for testing purposes.')


 class ConfidenceEstimator():
@@ -492,17 +493,20 @@ def main(args):
     ]:
         if name.startswith('raw'):
             estimator_class = RawConfidenceEstimator
-            mc_dropout = False
-            mc_dropout_num = 0
         else:
             estimator_class = TreeConfidenceEstimator
         mc_dropout = train_confidences[0][0].mc_dropout
         mc_dropout_num = train_confidences[0][0].mc_dropout_num
         estimator = estimator_class(name=name, featurizers=f, eval_metric=args.eval_metric, mc_dropout=mc_dropout, mc_dropout_num=mc_dropout_num)
         logger.info('name = %s', name)

         train_features, train_labels = estimator.convert_to_dataset(train_confidences, train=True)
         dev_features, dev_labels = estimator.convert_to_dataset(dev_confidences, train=False)
+        if args.testing:
+            if train_labels.all() or (~train_labels).all():
+                train_labels[0] = ~train_labels[0]
+            if dev_labels.all() or (~dev_labels).all():
+                dev_labels[0] = ~dev_labels[0]
         estimator.train_and_validate(train_features, train_labels, dev_features, dev_labels)
         precision, recall, pass_rate, accuracies, thresholds = estimator.evaluate(dev_features, dev_labels)
         if args.plot:
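Note: a minimal sketch of what the new --testing branch does, assuming the labels are NumPy boolean arrays (as the use of ~ and .all() suggests). If a tiny test run yields labels that are all True or all False, one element is flipped so the calibrator still sees both classes.

# Minimal sketch of the --testing label fix-up above; assumes numpy boolean arrays.
import numpy as np

def ensure_both_classes(labels):
    # If every label is identical (all True or all False), flip the first one
    # so that at least one example of each class is present.
    if labels.all() or (~labels).all():
        labels[0] = ~labels[0]
    return labels

print(ensure_both_classes(np.array([True, True, True])))   # [False  True  True]
print(ensure_both_classes(np.array([True, False, True])))  # unchanged: [ True False  True]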
@@ -118,7 +118,7 @@ class Server:
                 output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
                                              output_predictions_only=True,
                                              confidence_estimator=self.confidence_estimator)
-                response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0], score=output.confidence_scores[0]))
+                response = json.dumps(dict(id=request['id'], answer=output.predictions[0][0], score=float(output.confidence_scores[0])))
             else:
                 output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
                                              output_predictions_only=True)
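Note: the float(...) cast in this hunk matters because the confidence score presumably arrives as a NumPy scalar (e.g. numpy.float32), which the standard json encoder refuses to serialize; converting it to a built-in float first keeps json.dumps happy. A small illustration, with a made-up score value:

# Illustration of why the score is cast to float before json.dumps;
# assumes the confidence score is a numpy scalar such as numpy.float32.
import json
import numpy as np

score = np.float32(0.87)          # stand-in for output.confidence_scores[0]
# json.dumps({'score': score})    # raises TypeError: float32 is not JSON serializable
print(json.dumps({'score': float(score)}))  # works: a plain JSON number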
@@ -63,6 +63,40 @@ do
     i=$((i+1))
 done

+# test calibration
+for hparams in \
+      "--model TransformerSeq2Seq --pretrained_model sshleifer/bart-tiny-random" ;
+do
+
+    # train
+    pipenv run python3 -m genienlp train --train_tasks almond --train_batch_tokens 100 --val_batch_size 100 --train_iterations 6 --preserve_case --save_every 2 --log_every 2 --val_every 2 --save $workdir/model_$i --data $SRCDIR/dataset/ $hparams --exist_ok --skip_cache --embeddings $embedding_dir --no_commit
+
+    # greedy prediction
+    pipenv run python3 -m genienlp predict --tasks almond --evaluate test --path $workdir/model_$i --overwrite --eval_dir $workdir/model_$i/eval_results/ --data $SRCDIR/dataset/ --embeddings $embedding_dir --skip_cache --save_confidence_features --confidence_feature_path $workdir/model_$i/confidences.pkl --mc_dropout --mc_dropout_num 10
+
+    # check if confidence file exists
+    if test ! -f $workdir/model_$i/confidences.pkl ; then
+        echo "File not found!"
+        exit
+    fi
+
+    # calibrate
+    pipenv run python3 -m genienlp calibrate --confidence_path $workdir/model_$i/confidences.pkl --save $workdir/model_$i --testing
+
+    # check if calibrator exists
+    if test ! -f $workdir/model_$i/calibrator.pkl ; then
+        echo "File not found!"
+        exit
+    fi
+
+    echo "Testing the server mode after calibration"
+    echo '{"id": "dummy_example_1", "context": "show me .", "question": "translate to thingtalk", "answer": "now => () => notify"}' | pipenv run python3 -m genienlp server --path $workdir/model_$i --stdin
+
+    rm -rf $workdir/model_$i $workdir/model_$i_exported
+
+    i=$((i+1))
+done
+
 # test almond_multilingual task
 for hparams in \
     "--model TransformerLSTM --pretrained_model bert-base-multilingual-cased --trainable_decoder_embeddings=50" \
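Note on the stdin server check above: each input line is a JSON request, and, per the json.dumps call in the Server hunk, the reply is a single JSON object with id, answer, and score fields, where score is now a plain float. A hypothetical sketch of validating one reply line; the answer text and score value are placeholders:

# Hypothetical check of one server reply line; field names come from the
# json.dumps(dict(id=..., answer=..., score=...)) call above, values are made up.
import json

reply = '{"id": "dummy_example_1", "answer": "now => () => notify", "score": 0.5}'
parsed = json.loads(reply)
assert set(parsed) == {'id', 'answer', 'score'}
assert isinstance(parsed['score'], float)  # the float() cast keeps this a JSON number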