easy inference on a custom dataset

This commit is contained in:
Bryan Marcus McCann 2018-08-16 19:42:37 +00:00
parent 345836fbf4
commit 0e33e91dca
3 changed files with 57 additions and 0 deletions

View File

@ -118,6 +118,17 @@ docker run -it --rm -v `pwd`:/decaNLP/ decanlp bash -c "python /decaNLP/WikiSQL
docker run -it --rm -v `pwd`:/decaNLP/ decanlp bash -c "python /decaNLP/WikiSQL/evaluate.py /decaNLP/.data/wikisql/data/dev.jsonl /decaNLP/.data/wikisql/data/dev.db /decaNLP/mqan_wikisql/model/test/wikisql_logical_forms.jsonl" # assumes that you have data stored in .data
```
## Inference on a Custom Dataset
Using a pretrained model or a model you have trained yourself, you can run on new, custom datasets easily by following the instructions below. In this example, we use the checkpoint for the best MQAN trained on the entirety of decaNLP (see the section on Pretrained Models to see how to get this checkpoint) to run on my_custom_dataset.
```bash
mkdir .data/my_custom_dataset/
touch .datda/my_custom_dataset/val.jsonl
#TODO add examples line by line to val.jsonl in the form of a JSON dict: {"context": "The answer is answer.", "question": "What is the answer?", "answer": "answer"}
nvidia-docker run -it --rm -v `pwd`:/decaNLP/ decanlp bash -c "python /decaNLP/predict.py --evaluate valid --path /decaNLP/mqan_decanlp_qa_first --checkpoint_name model.pth --gpu 0 --tasks my_custom_dataset"
```
## Citation
If you use this in your work, please cite [*The Natural Language Decathlon: Multitask Learning as Question Answering*](https://arxiv.org/abs/1806.08730).

View File

@ -1424,3 +1424,46 @@ class SNLI(CQA, data.Dataset):
return tuple(d for d in (train_data, validation_data, test_data)
if d is not None)
class JSON(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, field, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))
examples = []
if os.path.exists(cache_name):
examples = torch.load(cache_name)
else:
with open(os.path.expanduser(path)) as f:
lines = f.readlines()
for line in lines:
ex = json.loads(line)
context, question, answer = ex['context'], ex['question'], ex['answer']
context_question = get_context_question(context, question)
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
examples.append(ex)
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
torch.save(examples, cache_name)
super(JSON, self).__init__(examples, fields, **kwargs)
@classmethod
def splits(cls, fields, name, root='.data',
train='train', validation='val', test='test', **kwargs):
path = os.path.join(root, name)
train_data = None if train is None else cls(
os.path.join(path, 'train.jsonl'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, 'val.jsonl'), fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, 'test.jsonl'), fields, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data)
if d is not None)

View File

@ -172,6 +172,9 @@ def get_splits(args, task, FIELD, **kwargs):
if 'zre' in task:
split = torchtext.datasets.generic.ZeroShotRE.splits(
fields=FIELD, root=args.data, **kwargs)
elif os.path.exists(os.path.join(args.data, task)):
split = torchtext.datasets.generic.JSON.splits(
fields=FIELD, root=args.data, name=task, **kwargs)
return split