Add CEV as a confidence feature
This commit is contained in:
parent
83552d4a68
commit
927d052a6e
|
@ -14,6 +14,7 @@ dataset/*
|
||||||
generated/*
|
generated/*
|
||||||
*events.out.tfevents*
|
*events.out.tfevents*
|
||||||
*.pkl
|
*.pkl
|
||||||
|
*.png
|
||||||
|
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.idea/
|
.idea/
|
||||||
|
|
|
@ -23,6 +23,9 @@ def max_of(f: Callable) -> Callable:
|
||||||
def min_of(f: Callable) -> Callable:
|
def min_of(f: Callable) -> Callable:
|
||||||
return lambda x: f(x).min().view(-1)
|
return lambda x: f(x).min().view(-1)
|
||||||
|
|
||||||
|
def neg_of(f: Callable) -> Callable:
|
||||||
|
return lambda x: -f(x)
|
||||||
|
|
||||||
def cv_drop_logit(i: int) -> Callable:
|
def cv_drop_logit(i: int) -> Callable:
|
||||||
def f(x):
|
def f(x):
|
||||||
a = torch.sqrt(var_drop_logit(i)(x)) / mean_drop_logit(i)(x)
|
a = torch.sqrt(var_drop_logit(i)(x)) / mean_drop_logit(i)(x)
|
||||||
|
@ -51,10 +54,14 @@ def input_length(i: int) -> Callable:
|
||||||
def nodrop_avg_logprob(i: int):
|
def nodrop_avg_logprob(i: int):
|
||||||
return lambda x: torch.mean(x[i].nodrop_logits).view(-1)
|
return lambda x: torch.mean(x[i].nodrop_logits).view(-1)
|
||||||
|
|
||||||
def variance_of_beams(x):
|
def variance_of_beam_logits(x):
|
||||||
a = torch.var(torch.tensor([torch.mean(x[i].nodrop_logits).item() for i in range(1, 5)])).view(-1)
|
a = torch.var(torch.tensor([torch.mean(x[i].nodrop_logits).item() for i in range(1, 5)])).view(-1)
|
||||||
return a
|
return a
|
||||||
|
|
||||||
|
def variance_of_beam_probs(x):
|
||||||
|
a = torch.var(torch.tensor([torch.mean(x[i].nodrop_probs).item() for i in range(1, 5)])).view(-1)
|
||||||
|
return a
|
||||||
|
|
||||||
def mean_drop_avg_logprob(i):
|
def mean_drop_avg_logprob(i):
|
||||||
return lambda x: torch.mean(x[i].drop_logits).view(-1)
|
return lambda x: torch.mean(x[i].drop_logits).view(-1)
|
||||||
|
|
||||||
|
@ -74,9 +81,15 @@ def nodrop_seq_prob(i):
|
||||||
def mean_drop_seq_prob(i):
|
def mean_drop_seq_prob(i):
|
||||||
return lambda x: torch.mean(torch.prod(x[i].drop_probs, dim=1)).view(-1)
|
return lambda x: torch.mean(torch.prod(x[i].drop_probs, dim=1)).view(-1)
|
||||||
|
|
||||||
|
def mean_drop_prob(i):
|
||||||
|
return lambda x: torch.mean(x[i].drop_probs, dim=0).view(-1)
|
||||||
|
|
||||||
def var_drop_seq_prob(i):
|
def var_drop_seq_prob(i):
|
||||||
return lambda x: torch.var(torch.prod(x[i].drop_probs, dim=1)).view(-1)
|
return lambda x: torch.var(torch.prod(x[i].drop_probs, dim=1)).view(-1)
|
||||||
|
|
||||||
|
def var_drop_prob(i):
|
||||||
|
return lambda x: torch.var(x[i].drop_probs, dim=0).view(-1)
|
||||||
|
|
||||||
def cv_drop_seq_prob(i):
|
def cv_drop_seq_prob(i):
|
||||||
def f(x):
|
def f(x):
|
||||||
a = torch.sqrt(var_drop_seq_prob(i)(x)) / mean_drop_seq_prob(i)(x)
|
a = torch.sqrt(var_drop_seq_prob(i)(x)) / mean_drop_seq_prob(i)(x)
|
||||||
|
@ -84,6 +97,26 @@ def cv_drop_seq_prob(i):
|
||||||
return a
|
return a
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
def cev_drop_seq_prob(i):
|
||||||
|
"""
|
||||||
|
Introduced in https://arxiv.org/pdf/1909.00157.pdf
|
||||||
|
"""
|
||||||
|
def f(x):
|
||||||
|
a = torch.square(1 - (var_drop_seq_prob(i)(x) / mean_drop_seq_prob(i)(x)))
|
||||||
|
a[a!=a] = 0 # set NaNs to zero
|
||||||
|
return a
|
||||||
|
return f
|
||||||
|
|
||||||
|
def cev_drop_prob(i):
|
||||||
|
"""
|
||||||
|
Introduced in https://arxiv.org/pdf/1909.00157.pdf
|
||||||
|
"""
|
||||||
|
def f(x):
|
||||||
|
a = torch.square(1 - (var_drop_prob(i)(x) / mean_drop_prob(i)(x)))
|
||||||
|
a[a!=a] = 0 # set NaNs to zero
|
||||||
|
return a
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
def accuracy_at_pass_rate(labels, confidence_scores):
|
def accuracy_at_pass_rate(labels, confidence_scores):
|
||||||
sorted_confidence_scores, sorted_labels = zip(*sorted(zip(confidence_scores, labels)))
|
sorted_confidence_scores, sorted_labels = zip(*sorted(zip(confidence_scores, labels)))
|
||||||
|
@ -330,11 +363,15 @@ def main(args):
|
||||||
train_confidences, dev_confidences = train_test_split(confidences, test_size=args.dev_split, random_state=args.seed)
|
train_confidences, dev_confidences = train_test_split(confidences, test_size=args.dev_split, random_state=args.seed)
|
||||||
for f, name in [
|
for f, name in [
|
||||||
# (oracle_score, 'raw_oracle'),
|
# (oracle_score, 'raw_oracle'),
|
||||||
(nodrop_avg_logprob(0), 'raw_avg_logprob'),
|
# (nodrop_avg_logprob(0), 'raw_avg_logprob'),
|
||||||
(nodrop_seq_prob(0), 'raw_seq_prob'),
|
# (nodrop_seq_prob(0), 'raw_seq_prob'),
|
||||||
(mean_drop_seq_prob(0), 'raw_mean_seq_prob'),
|
(mean_drop_seq_prob(0), 'raw_mean_seq_prob'),
|
||||||
(var_drop_seq_prob(0), 'raw_var_seq_prob'),
|
(neg_of(var_drop_seq_prob(0)), 'raw_var_seq_prob'),
|
||||||
(cv_drop_seq_prob(0), 'raw_cv_seq_prob')
|
# (neg_of(cv_drop_seq_prob(0)), 'raw_cv_seq_prob'),
|
||||||
|
# (cev_drop_seq_prob(0), 'raw_cev_drop_seq_prob'),
|
||||||
|
# ([mean_drop_prob(0)], 'mean_prob'),
|
||||||
|
# ([var_drop_prob(0)], 'var_prob'),
|
||||||
|
# ([cev_drop_prob(0)], 'cev_drop_prob'),
|
||||||
# ([mean_drop_logit(0)], 'mean'),
|
# ([mean_drop_logit(0)], 'mean'),
|
||||||
# ([nodrop_entropies(0)], 'entropy'),
|
# ([nodrop_entropies(0)], 'entropy'),
|
||||||
# ([(mean_drop_logit(0), nodrop_entropies(0))], 'mean + entropy'),
|
# ([(mean_drop_logit(0), nodrop_entropies(0))], 'mean + entropy'),
|
||||||
|
@ -347,10 +384,13 @@ def main(args):
|
||||||
# ([prediction_length(0), max_of(nodrop_logit(0)), max_of(nodrop_entropies(0)), max_of(cv_drop_logit(0)), min_of(nodrop_logit(0)), min_of(nodrop_entropies(0)), min_of(cv_drop_logit(0)), input_length(0)], 'prediction_length + max_logit + max_entropy + max_cv + min_logit + min_entropy + min_cv + input_length'),
|
# ([prediction_length(0), max_of(nodrop_logit(0)), max_of(nodrop_entropies(0)), max_of(cv_drop_logit(0)), min_of(nodrop_logit(0)), min_of(nodrop_entropies(0)), min_of(cv_drop_logit(0)), input_length(0)], 'prediction_length + max_logit + max_entropy + max_cv + min_logit + min_entropy + min_cv + input_length'),
|
||||||
# ([nodrop_avg_logprob(0), prediction_length(0), max_of(nodrop_logit(0)), max_of(nodrop_entropies(0)), max_of(cv_drop_logit(0)), min_of(nodrop_logit(0)), min_of(nodrop_entropies(0)), min_of(cv_drop_logit(0)), input_length(0)], 'logprob + prediction_length + max_logit + max_entropy + max_cv + min_logit + min_entropy + min_cv + input_length'),
|
# ([nodrop_avg_logprob(0), prediction_length(0), max_of(nodrop_logit(0)), max_of(nodrop_entropies(0)), max_of(cv_drop_logit(0)), min_of(nodrop_logit(0)), min_of(nodrop_entropies(0)), min_of(cv_drop_logit(0)), input_length(0)], 'logprob + prediction_length + max_logit + max_entropy + max_cv + min_logit + min_entropy + min_cv + input_length'),
|
||||||
# ([nodrop_avg_logprob(0), prediction_length(0), max_of(nodrop_logit(0)), max_of(nodrop_entropies(0)), max_of(cv_drop_logit(0)), min_of(nodrop_logit(0)), input_length(0)], 'logprob + prediction_length + max_logit + max_entropy + max_cv + min_logit + input_length'),
|
# ([nodrop_avg_logprob(0), prediction_length(0), max_of(nodrop_logit(0)), max_of(nodrop_entropies(0)), max_of(cv_drop_logit(0)), min_of(nodrop_logit(0)), input_length(0)], 'logprob + prediction_length + max_logit + max_entropy + max_cv + min_logit + input_length'),
|
||||||
# ([variance_of_beams], 'var_beams'),
|
([nodrop_avg_logprob(0), prediction_length(0), max_of(nodrop_logit(0)), max_of(nodrop_entropies(0)), max_of(cv_drop_logit(0)), min_of(nodrop_logit(0)), input_length(0), cev_drop_seq_prob(0), mean_drop_seq_prob(0)], 'logprob + prediction_length + max_logit + max_entropy + max_cv + min_logit + input_length + cev_seq_prob + mean_seq_prob'),
|
||||||
|
# ([variance_of_beam_logits], 'var_beam_logits'),
|
||||||
|
# ([variance_of_beam_probs], 'var_beam_probs'),
|
||||||
# ([mean_drop_avg_logprob(0)], 'mean_drop_avg_logprob'),
|
# ([mean_drop_avg_logprob(0)], 'mean_drop_avg_logprob'),
|
||||||
# ([var_drop_avg_logprob(0)], 'var_drop_avg_logprob'),
|
# ([var_drop_avg_logprob(0)], 'var_drop_avg_logprob'),
|
||||||
# ([cv_drop_avg_logprob(0)], 'cv_drop_avg_logprob'),
|
# ([cv_drop_avg_logprob(0)], 'cv_drop_avg_logprob'),
|
||||||
|
|
||||||
]:
|
]:
|
||||||
estimator = ConfidenceEstimator(name=name, featurizers=f, eval_metric=args.eval_metric)
|
estimator = ConfidenceEstimator(name=name, featurizers=f, eval_metric=args.eval_metric)
|
||||||
logger.info('name = %s', name)
|
logger.info('name = %s', name)
|
||||||
|
|
Loading…
Reference in New Issue