Allow prediction on test sets without prepped bootleg features

Useful for calibration, as new out-of-domain (OOD) data is fed in at runtime.
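
For context, the new runtime path can be exercised roughly as follows. This is a minimal sketch, not part of the commit: it assumes args, task, and examples are the usual genienlp argument namespace, task object, and list of Example objects, and it uses the two helpers this commit adds to data_utils/bootleg.py:

import torch
from genienlp.data_utils.bootleg import init_bootleg_annotator, extract_features_with_annotator

# pick the same device the model runs on
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# build the annotator once; it loads the bootleg config and candidate maps
bootleg_annotator = init_bootleg_annotator(args, device)

# annotate the examples in place instead of reading prepped features from disk
extract_features_with_annotator(examples, bootleg_annotator, args, task)
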
mehrad 2021-02-16 00:15:35 -08:00
parent 02e1681fb6
commit 1f1ceef39e
5 changed files with 105 additions and 86 deletions


@@ -554,7 +554,7 @@ def main(args):
     if args.plot:
         from matplotlib import pyplot # lazy import
 
-    confidences = torch.load(args.confidence_path, map_location=torch.device('cpu'))
+    confidences = torch.load(args.confidence_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
     all_estimators = []
     train_confidences, dev_confidences = train_test_split(confidences, test_size=args.dev_split, random_state=args.seed)


@@ -27,11 +27,13 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+import functools
 import os
 import ujson
 import numpy as np
 import logging
 import torch
+from bootleg.annotator import Annotator
 from .database_utils import is_banned
@@ -39,6 +41,7 @@ from bootleg.extract_mentions import extract_mentions
 from bootleg.utils.parser_utils import get_full_config
 from bootleg import run
+from .progbar import progress_bar
 
 logger = logging.getLogger(__name__)
@@ -60,6 +63,73 @@ def reverse_bisect_left(a, x, lo=None, hi=None):
     return lo
+
+
+def bootleg_process_examples(ex, bootleg_annotator, args, label, task):
+    line = {}
+    line['sentence'] = getattr(ex, task.utterance_field())
+    assert len(label) == 7
+    line['cands'] = label[3]
+    line['cand_probs'] = list(map(lambda item: list(item), label[4]))
+    line['spans'] = label[5]
+    line['aliases'] = label[6]
+    tokens_type_ids, tokens_type_probs = bootleg_annotator.bootleg.collect_features_per_line(line, args.bootleg_prob_threshold)
+    if task.utterance_field() == 'question':
+        for i in range(len(tokens_type_ids)):
+            ex.question_feature[i].type_id = tokens_type_ids[i]
+            ex.question_feature[i].type_prob = tokens_type_probs[i]
+            ex.context_plus_question_feature[i + len(ex.context.split(' '))].type_id = tokens_type_ids[i]
+            ex.context_plus_question_feature[i + len(ex.context.split(' '))].type_prob = tokens_type_probs[i]
+    else:
+        for i in range(len(tokens_type_ids)):
+            ex.context_feature[i].type_id = tokens_type_ids[i]
+            ex.context_feature[i].type_prob = tokens_type_probs[i]
+            ex.context_plus_question_feature[i].type_id = tokens_type_ids[i]
+            ex.context_plus_question_feature[i].type_prob = tokens_type_probs[i]
+    context_plus_question_with_types = task.create_sentence_plus_types_tokens(ex.context_plus_question,
+                                                                              ex.context_plus_question_feature,
+                                                                              args.add_types_to_text)
+    ex = ex._replace(context_plus_question_with_types=context_plus_question_with_types)
+    return ex
+
+
+def extract_features_with_annotator(examples, bootleg_annotator, args, task):
+    with torch.no_grad():
+        bootleg_inputs = []
+        for ex in examples:
+            bootleg_inputs.append(getattr(ex, task.utterance_field()))
+        bootleg_labels = bootleg_annotator.label_mentions(bootleg_inputs)
+        bootleg_labels_unpacked = list(zip(*bootleg_labels))
+        for i in range(len(examples)):
+            ex = examples[i]
+            label = bootleg_labels_unpacked[i]
+            examples[i] = bootleg_process_examples(ex, bootleg_annotator, args, label, task)
+
+
+def init_bootleg_annotator(args, device):
+    # instantiate a bootleg object to load the config and relevant databases
+    bootleg = Bootleg(args)
+    bootleg_config = bootleg.create_config(bootleg.fixed_overrides)
+    # instantiate the annotator class; the annotator is used only in server mode.
+    # for training, we use bootleg functions that preprocess and cache data, using multiprocessing and batching to speed up NED
+    bootleg_annotator = Annotator(config_args=bootleg_config,
+                                  device='cpu' if device.type == 'cpu' else 'cuda',
+                                  max_alias_len=args.max_entity_len,
+                                  cand_map=bootleg.cand_map,
+                                  threshold=args.bootleg_prob_threshold,
+                                  progbar_func=functools.partial(progress_bar, disable=True))
+    # collect all outputs now; we will filter later
+    bootleg_annotator.set_threshold(0.0)
+    setattr(bootleg_annotator, 'bootleg', bootleg)
+    return bootleg_annotator
+
+
 def post_process_bootleg_types(qid, type, title, almond_domains):
     # TODO: if training on multiple domains (in one run) these mappings should be modified
     # e.g. song is mapped to book, which is not correct if training on the music domain too


@@ -39,7 +39,7 @@ import shutil
 # multiprocessing with CUDA
 from torch.multiprocessing import Process, set_start_method
-from .data_utils.bootleg import Bootleg
+from .data_utils.bootleg import Bootleg, init_bootleg_annotator, extract_features_with_annotator
 from .run_bootleg import bootleg_process_splits
 
 try:
@@ -48,7 +48,6 @@
     pass
 
 import torch
-import pickle
 from . import models
 from .tasks.registry import get_tasks
@@ -61,7 +60,8 @@ from .arguments import check_and_update_generation_args
 logger = logging.getLogger(__name__)
 
-def prepare_data(args):
+def prepare_data(args, device):
+    # initialize bootleg
     bootleg = None
     if args.do_ned and args.ned_retrieve_method == 'bootleg':
@@ -110,7 +110,13 @@
             data = split.test
             path = path.test
         if bootleg:
-            bootleg_process_splits(args, data.examples, path, task, bootleg)
+            if split.train or split.eval:
+                bootleg_process_splits(args, data.examples, path, task, bootleg)
+            else:
+                # no prepped bootleg features are available;
+                # extract features on the fly using the bootleg annotator
+                bootleg_annotator = init_bootleg_annotator(args, device)
+                extract_features_with_annotator(data.examples, bootleg_annotator, args, task)
         task_data_processed.append(data)
         task_path_processed.append(path)
     datasets.append(task_data_processed)
@@ -165,7 +171,7 @@ def run(args, device):
         locale=locale
     )
 
-    val_sets = prepare_data(args)
+    val_sets = prepare_data(args, device)
     model.add_new_vocab_from_data(args.tasks)
     iters = prepare_data_iterators(args, val_sets, model.numericalizer, device)


@@ -35,25 +35,22 @@ import logging
 import sys
 import os
 from pprint import pformat
-import functools
 import torch
 from . import models
 from .data_utils.example import Example, NumericalizedExamples
+from .data_utils.bootleg import init_bootleg_annotator, extract_features_with_annotator
 from .tasks.registry import get_tasks
 from .util import set_seed, init_devices, load_config_json, log_model_size
 from .validate import generate_with_model
 from .calibrate import ConfidenceEstimator
-from bootleg.annotator import Annotator
-from .data_utils.bootleg import Bootleg
-from .data_utils.progbar import progress_bar
 
 logger = logging.getLogger(__name__)
 
-class Server:
+class Server(object):
     def __init__(self, args, numericalizer, model, device, confidence_estimators, estimator_filenames, bootleg_annotator=None):
         self.args = args
         self.device = device
@@ -71,37 +68,6 @@
         # make a single batch with all examples
         return NumericalizedExamples.collate_batches(all_features, self.numericalizer, device=self.device, db_unk_id=self.args.db_unk_id)
-
-    def bootleg_process_examples(self, ex, label, task):
-        line = {}
-        line['sentence'] = getattr(ex, task.utterance_field())
-        assert len(label) == 7
-        line['cands'] = label[3]
-        line['cand_probs'] = list(map(lambda item: list(item), label[4]))
-        line['spans'] = label[5]
-        line['aliases'] = label[6]
-        tokens_type_ids, tokens_type_probs = self.bootleg_annotator.bootleg.collect_features_per_line(line, self.args.bootleg_prob_threshold)
-        if task.utterance_field() == 'question':
-            for i in range(len(tokens_type_ids)):
-                ex.question_feature[i].type_id = tokens_type_ids[i]
-                ex.question_feature[i].type_prob = tokens_type_probs[i]
-                ex.context_plus_question_feature[i + len(ex.context.split(' '))].type_id = tokens_type_ids[i]
-                ex.context_plus_question_feature[i + len(ex.context.split(' '))].type_prob = tokens_type_probs[i]
-        else:
-            for i in range(len(tokens_type_ids)):
-                ex.context_feature[i].type_id = tokens_type_ids[i]
-                ex.context_feature[i].type_prob = tokens_type_probs[i]
-                ex.context_plus_question_feature[i].type_id = tokens_type_ids[i]
-                ex.context_plus_question_feature[i].type_prob = tokens_type_probs[i]
-        context_plus_question_with_types = task.create_sentence_plus_types_tokens(ex.context_plus_question,
-                                                                                  ex.context_plus_question_feature,
-                                                                                  self.args.add_types_to_text)
-        ex = ex._replace(context_plus_question_with_types=context_plus_question_with_types)
-        return ex
 
     def handle_request(self, request):
         task_name = request['task'] if 'task' in request else 'generic'
@@ -125,36 +91,28 @@
             ex = Example.from_raw(str(example_id), context, question, answer, preprocess=task.preprocess_field, lower=self.args.lower)
             examples.append(ex)
 
-        with torch.no_grad():
-            bootleg_inputs = []
-            if self.bootleg_annotator:
-                for ex in examples:
-                    bootleg_inputs.append(getattr(ex, task.utterance_field()))
-                bootleg_labels = self.bootleg_annotator.label_mentions(bootleg_inputs)
-                bootleg_labels_unpacked = list(zip(*bootleg_labels))
-                for i in range(len(examples)):
-                    ex = examples[i]
-                    label = bootleg_labels_unpacked[i]
-                    examples[i] = self.bootleg_process_examples(ex, label, task)
+        # process bootleg features
+        if self.bootleg_annotator:
+            extract_features_with_annotator(examples, self.bootleg_annotator, self.args, task)
 
         self.model.add_new_vocab_from_data([task])
         batch = self.numericalize_examples(examples)
-        if self.args.calibrator_paths is not None:
-            output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
-                                         output_predictions_only=True,
-                                         confidence_estimators=self.confidence_estimators)
-            response = []
-            for idx, p in enumerate(output.predictions):
-                instance = {'answer': p[0], 'score': {}}
-                for e_idx, estimator_scores in enumerate(output.confidence_scores):
-                    instance['score'][self.estimator_filenames[e_idx]] = float(estimator_scores[idx])
-                response.append(instance)
-        else:
-            output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args, output_predictions_only=True)
-            response = [{'answer': p[0]} for p in output.predictions]
+        with torch.no_grad():
+            if self.args.calibrator_paths is not None:
+                output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args,
+                                             output_predictions_only=True,
+                                             confidence_estimators=self.confidence_estimators)
+                response = []
+                for idx, p in enumerate(output.predictions):
+                    instance = {'answer': p[0], 'score': {}}
+                    for e_idx, estimator_scores in enumerate(output.confidence_scores):
+                        instance['score'][self.estimator_filenames[e_idx]] = float(estimator_scores[idx])
+                    response.append(instance)
+            else:
+                output = generate_with_model(self.model, [batch], self.numericalizer, task, self.args, output_predictions_only=True)
+                response = [{'answer': p[0]} for p in output.predictions]
         return response
@@ -244,22 +202,7 @@ def init(args):
     bootleg_annotator = None
     if args.do_ned and args.ned_retrieve_method == 'bootleg':
-        # instantiate a bootleg object to load config and relevant databases
-        bootleg = Bootleg(args)
-        bootleg_config = bootleg.create_config(bootleg.fixed_overrides)
-        # instantiate the annotator class. we use annotator only in server mode
-        # for training we use bootleg functions which preprocess and cache data using multiprocessing, and batching to speed up NED
-        bootleg_annotator = Annotator(config_args=bootleg_config,
-                                      device='cpu' if device.type=='cpu' else 'cuda',
-                                      max_alias_len=args.max_entity_len,
-                                      cand_map=bootleg.cand_map,
-                                      threshold=args.bootleg_prob_threshold,
-                                      progbar_func=functools.partial(progress_bar, disable=True))
-        # collect all outputs now; we will filter later
-        bootleg_annotator.set_threshold(0.0)
-        setattr(bootleg_annotator, 'bootleg', bootleg)
+        bootleg_annotator = init_bootleg_annotator(args, device)
 
     logger.info(f'Arguments:\n{pformat(vars(args))}')
     logger.info(f'Loading from {args.best_checkpoint}')


@@ -128,7 +128,7 @@ do
     # test the server for bootleg
    # due to Travis memory limitations, uncomment and run this test locally
-    # echo '{"id": "dummy_example_1", "context": "show me .", "question": "translate to thingtalk", "answer": "now => () => notify"}' | pipenv run python3 -m genienlp server --database_dir $SRCDIR/database/ --path $workdir/model_$i --stdin
+    # echo '{"task": "almond", "id": "dummy_example_1", "context": "show me .", "question": "translate to thingtalk", "answer": "now => () => notify"}' | pipenv run python3 -m genienlp server --database_dir $SRCDIR/database/ --path $workdir/model_$i --stdin
 
     rm -rf $workdir/model_$i
     i=$((i+1))