easy inference on a custom dataset
This commit is contained in:
parent 345836fbf4
commit 0e33e91dca
README.md | 11

@@ -118,6 +118,17 @@ docker run -it --rm -v `pwd`:/decaNLP/ decanlp bash -c "python /decaNLP/WikiSQL
docker run -it --rm -v `pwd`:/decaNLP/ decanlp bash -c "python /decaNLP/WikiSQL/evaluate.py /decaNLP/.data/wikisql/data/dev.jsonl /decaNLP/.data/wikisql/data/dev.db /decaNLP/mqan_wikisql/model/test/wikisql_logical_forms.jsonl" # assumes that you have data stored in .data
```
## Inference on a Custom Dataset
Using a pretrained model, or a model you have trained yourself, you can easily run inference on new, custom datasets by following the instructions below. In this example, we use the checkpoint for the best MQAN trained on the entirety of decaNLP (see the Pretrained Models section for how to obtain this checkpoint) to run on a dataset called my_custom_dataset.
```bash
mkdir -p .data/my_custom_dataset/
touch .data/my_custom_dataset/val.jsonl
# TODO: add examples to val.jsonl line by line, one JSON dict per line: {"context": "The answer is answer.", "question": "What is the answer?", "answer": "answer"}
nvidia-docker run -it --rm -v `pwd`:/decaNLP/ decanlp bash -c "python /decaNLP/predict.py --evaluate valid --path /decaNLP/mqan_decanlp_qa_first --checkpoint_name model.pth --gpu 0 --tasks my_custom_dataset"
```
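
If you prefer to populate `val.jsonl` from a script rather than by hand, a minimal sketch (not part of this commit; the `examples` list is a placeholder for your own data) could look like this:

```python
import json

# Placeholder data: each line of val.jsonl is one JSON dict with
# exactly these three keys, matching the TODO above.
examples = [
    {"context": "The answer is answer.",
     "question": "What is the answer?",
     "answer": "answer"},
]

with open(".data/my_custom_dataset/val.jsonl", "w") as f:
    for ex in examples:
        f.write(json.dumps(ex) + "\n")
```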
## Citation
If you use this in your work, please cite [*The Natural Language Decathlon: Multitask Learning as Question Answering*](https://arxiv.org/abs/1806.08730).

text/torchtext/datasets/generic.py | 43

@@ -1424,3 +1424,46 @@ class SNLI(CQA, data.Dataset):
        return tuple(d for d in (train_data, validation_data, test_data)
                     if d is not None)


class JSON(CQA, data.Dataset):

    @staticmethod
    def sort_key(ex):
        return data.interleave_keys(len(ex.context), len(ex.answer))

    def __init__(self, path, field, subsample=None, **kwargs):
        fields = [(x, field) for x in self.fields]
        cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))

        examples = []
        if os.path.exists(cache_name):
            examples = torch.load(cache_name)
        else:
            with open(os.path.expanduser(path)) as f:
                lines = f.readlines()
            for line in lines:
                ex = json.loads(line)
                context, question, answer = ex['context'], ex['question'], ex['answer']
                context_question = get_context_question(context, question)
                ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
                examples.append(ex)
                if subsample is not None and len(examples) >= subsample:
                    break
            os.makedirs(os.path.dirname(cache_name), exist_ok=True)
            torch.save(examples, cache_name)

        super(JSON, self).__init__(examples, fields, **kwargs)

    @classmethod
    def splits(cls, fields, name, root='.data',
               train='train', validation='val', test='test', **kwargs):
        path = os.path.join(root, name)

        train_data = None if train is None else cls(
            os.path.join(path, 'train.jsonl'), fields, **kwargs)
        validation_data = None if validation is None else cls(
            os.path.join(path, 'val.jsonl'), fields, **kwargs)
        test_data = None if test is None else cls(
            os.path.join(path, 'test.jsonl'), fields, **kwargs)
        return tuple(d for d in (train_data, validation_data, test_data)
                     if d is not None)
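
The new class follows the pattern of the other loaders in this file: parsed examples are cached with `torch.save` under a `.cache` directory beside the data file, so repeat runs skip re-parsing the JSON. As a usage sketch (hypothetical, for illustration only; `FIELD` is assumed to be the shared field object decaNLP builds for every task, and `get_splits` in util.py normally makes this call for you):

```python
# Load only val.jsonl, as created in the README instructions above;
# passing train=None and test=None drops those splits from the tuple.
validation, = JSON.splits(fields=FIELD, name='my_custom_dataset',
                          root='.data', train=None, test=None)
```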
util.py | 3

@@ -172,6 +172,9 @@ def get_splits(args, task, FIELD, **kwargs):
    if 'zre' in task:
        split = torchtext.datasets.generic.ZeroShotRE.splits(
            fields=FIELD, root=args.data, **kwargs)
    elif os.path.exists(os.path.join(args.data, task)):
        split = torchtext.datasets.generic.JSON.splits(
            fields=FIELD, root=args.data, name=task, **kwargs)
    return split
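
With this branch in place, any `--tasks` name that matches a directory under the data root is routed to the generic `JSON` loader instead of one of the named decaNLP tasks. Schematically (a sketch, not repo code; it relies on the fact, visible above, that extra keyword arguments are forwarded unchanged to `JSON.splits`):

```python
# '.data/my_custom_dataset/' exists on disk, so get_splits falls through
# to torchtext.datasets.generic.JSON.splits; train=None and test=None
# restrict loading to val.jsonl.
split = get_splits(args, 'my_custom_dataset', FIELD, train=None, test=None)
```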