bugs in overwrite; new best mqan model
This commit is contained in:
parent
1dc5f7d28e
commit
c096a1f5ba
|
@ -13,6 +13,7 @@ While the research direction associated with this repository focused on multitas
|
|||
|
||||
| Model | decaNLP | [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/) | [IWSLT](https://wit3.fbk.eu/mt.php?release=2016-01) | [CNN/DM](https://cs.nyu.edu/~kcho/DMQA/) | [MNLI](https://www.nyu.edu/projects/bowman/multinli/) | [SST](https://nlp.stanford.edu/sentiment/treebank.html) | [QA‑SRL](https://dada.cs.washington.edu/qasrl/) | [QA‑ZRE](http://nlp.cs.washington.edu/zeroshot/) | [WOZ](https://github.com/nmrksic/neural-belief-tracker/tree/master/data/woz) | [WikiSQL](https://github.com/salesforce/WikiSQL) | [MWSC](https://s3.amazonaws.com/research.metamind.io/decaNLP/data/schema.txt) |
|
||||
| --- | --- | --- | --- | --- | --- | --- | ---- | ---- | --- | --- |--- |
|
||||
| [MQAN](https://arxiv.org/abs/1806.08730)(Better Sampling+[CoVe](http://papers.nips.cc/paper/7209-learned-in-translation-contextualized-word-vectors)) | 608.9687963636 | 77.0417861486 | 21.3824477725 | 24.3900766008 | 74.02 | 86.4678899083 | 80.935100999 | 40.9381663113 | 84.8192771084 | 70.1935637098 | 48.7804878050 |
|
||||
| [MQAN](https://arxiv.org/abs/1806.08730)(QA‑first+[CoVe](http://papers.nips.cc/paper/7209-learned-in-translation-contextualized-word-vectors)) | 599.9 | 75.5 | 18.9 | 24.4 | 73.6 | 86.4 | 80.8 | 37.4 | 85.8 | 68.5 | 48.8 |
|
||||
| [MQAN](https://arxiv.org/abs/1806.08730)(QA‑first) | 590.5 | 74.4 | 18.6 | 24.3 | 71.5 | 87.4 | 78.4 | 37.6 | 84.8 | 64.8 | 48.7 |
|
||||
| [S2S](https://arxiv.org/abs/1806.08730) | 513.6 | 47.5 | 14.2 | 25.7 | 60.9 | 85.9 | 68.7 | 28.5 | 84.0 | 45.8 | 52.4 |
|
||||
|
@ -116,9 +117,9 @@ For test performance, please use the original [SQuAD](https://rajpurkar.github.i
|
|||
This model is the best MQAN trained on decaNLP so far. It was trained first on SQuAD and then on all of decaNLP. It uses [CoVe](http://papers.nips.cc/paper/7209-learned-in-translation-contextualized-word-vectors.pdf) as well. You can obtain this model and run it on the validation sets with the following.
|
||||
|
||||
```bash
|
||||
wget https://s3.amazonaws.com/research.metamind.io/decaNLP/pretrained/mqan_decanlp_qa_first_cove_cpu.tar.gz
|
||||
tar -xvzf mqan_decanlp_qa_first_cove_cpu.tar.gz
|
||||
nvidia-docker run -it --rm -v `pwd`:/decaNLP/ -u $(id -u):$(id -g) bmccann/decanlp:cuda9_torch041 bash -c "python /decaNLP/predict.py --evaluate validation --path /decaNLP/mqan_decanlp_qa_first_cove_cpu/ --device 0 --silent"
|
||||
wget https://s3.amazonaws.com/research.metamind.io/decaNLP/pretrained/mqan_decanlp_better_sampling_cove_cpu.tgz
|
||||
tar -xvzf mqan_decanlp_better_sampling_cove_cpu.tgz
|
||||
nvidia-docker run -it --rm -v `pwd`:/decaNLP/ -u $(id -u):$(id -g) bmccann/decanlp:cuda9_torch041 bash -c "python /decaNLP/predict.py --evaluate validation --path /decaNLP/mqan_decanlp_better_sampling_cove_cpu/ --checkpoint_name iteration_560000.pth --device 0 --silent"
|
||||
```
|
||||
|
||||
This model is the best MQAN trained on WikiSQL alone, which established [a new state-of-the-art performance by several points on that task](https://github.com/salesforce/WikiSQL): 73.2 / 75.4 / 81.4 (ordered test logical form accuracy, unordered test logical form accuracy, test execution accuracy).
|
||||
|
|
|
@ -9,7 +9,7 @@ from torch.nn import functional as F
|
|||
from util import get_trainable_params
|
||||
|
||||
from cove import MTLSTM
|
||||
from allennlp.modules.elmo import Elmo, batch_to_ids
|
||||
#from allennlp.modules.elmo import Elmo, batch_to_ids
|
||||
|
||||
from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask, CoattentiveLayer
|
||||
|
||||
|
|
37
predict.py
37
predict.py
|
@ -93,23 +93,29 @@ def run(args, field, val_sets, model):
|
|||
ids_file_name = answer_file_name.replace('gold', 'ids')
|
||||
if os.path.exists(prediction_file_name):
|
||||
print('** ', prediction_file_name, ' already exists -- this is where predictions are stored **')
|
||||
if args.overwrite:
|
||||
print('**** overwriting ', prediction_file_name, ' ****')
|
||||
if os.path.exists(answer_file_name):
|
||||
print('** ', answer_file_name, ' already exists -- this is where ground truth answers are stored **')
|
||||
if args.overwrite:
|
||||
print('**** overwriting ', answer_file_name, ' ****')
|
||||
if os.path.exists(results_file_name):
|
||||
print('** ', results_file_name, ' already exists -- this is where metrics are stored **')
|
||||
with open(results_file_name) as results_file:
|
||||
for l in results_file:
|
||||
print(l)
|
||||
if not args.overwrite_predictions and args.silent:
|
||||
if args.overwrite:
|
||||
print('**** overwriting ', results_file_name, ' ****')
|
||||
else:
|
||||
with open(results_file_name) as results_file:
|
||||
metrics = json.loads(results_file.readlines()[0])
|
||||
decaScore.append(metrics[args.task_to_metric[task]])
|
||||
if not args.silent:
|
||||
for l in results_file:
|
||||
print(l)
|
||||
metrics = json.loads(results_file.readlines()[0])
|
||||
decaScore.append(metrics[args.task_to_metric[task]])
|
||||
continue
|
||||
|
||||
for x in [prediction_file_name, answer_file_name, results_file_name]:
|
||||
os.makedirs(os.path.dirname(x), exist_ok=True)
|
||||
|
||||
if not os.path.exists(prediction_file_name) or args.overwrite_predictions:
|
||||
if not os.path.exists(prediction_file_name) or args.overwrite:
|
||||
with open(prediction_file_name, 'w') as prediction_file:
|
||||
predictions = []
|
||||
ids = []
|
||||
|
@ -141,7 +147,7 @@ def run(args, field, val_sets, model):
|
|||
def from_all_answers(an):
|
||||
return [it.dataset.all_answers[sid] for sid in an.tolist()]
|
||||
|
||||
if not os.path.exists(answer_file_name):
|
||||
if not os.path.exists(answer_file_name) or args.overwrite:
|
||||
with open(answer_file_name, 'w') as answer_file:
|
||||
answers = []
|
||||
for batch_idx, batch in enumerate(it):
|
||||
|
@ -161,7 +167,7 @@ def run(args, field, val_sets, model):
|
|||
answers = [json.loads(x.strip()) for x in answer_file.readlines()]
|
||||
|
||||
if len(answers) > 0:
|
||||
if not os.path.exists(results_file_name):
|
||||
if not os.path.exists(results_file_name) or args.overwrite:
|
||||
metrics, answers = compute_metrics(predictions, answers, bleu='iwslt' in task or 'multi30k' in task or args.bleu, dialogue='woz' in task,
|
||||
rouge='cnn' in task or 'dailymail' in task or args.rouge, logical_form='sql' in task, corpus_f1='zre' in task, args=args)
|
||||
with open(results_file_name, 'w') as results_file:
|
||||
|
@ -173,14 +179,14 @@ def run(args, field, val_sets, model):
|
|||
if not args.silent:
|
||||
for i, (p, a) in enumerate(zip(predictions, answers)):
|
||||
print(f'Prediction {i+1}: {p}\nAnswer {i+1}: {a}\n')
|
||||
print(metrics)
|
||||
print(metrics)
|
||||
decaScore.append(metrics[args.task_to_metric[task]])
|
||||
|
||||
print(f'Evaluated Tasks:\n')
|
||||
for i, (task, _) in enumerate(iters):
|
||||
print(f'{task}: {decaScore[i]}')
|
||||
print(f'-------------------')
|
||||
print(f'DecaScore: {sum(decaScore)}\n')
|
||||
|
||||
print(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n')
|
||||
|
||||
|
||||
|
@ -196,7 +202,7 @@ def get_args():
|
|||
parser.add_argument('--checkpoint_name')
|
||||
parser.add_argument('--bleu', action='store_true', help='whether to use the bleu metric (always on for iwslt)')
|
||||
parser.add_argument('--rouge', action='store_true', help='whether to use the bleu metric (always on for cnn, dailymail, and cnn_dailymail)')
|
||||
parser.add_argument('--overwrite_predictions', action='store_true', help='whether to overwrite previously written predictions')
|
||||
parser.add_argument('--overwrite', action='store_true', help='whether to overwrite previously written predictions')
|
||||
parser.add_argument('--silent', action='store_true', help='whether to print predictions to stdout')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -232,10 +238,11 @@ def get_args():
|
|||
'zre': 'corpus_f1',
|
||||
'schema': 'em'}
|
||||
|
||||
if os.path.exists(os.path.join(args.path, 'process_0.log')):
|
||||
args.best_checkpoint = get_best(args)
|
||||
else:
|
||||
if not args.checkpoint_name is None:
|
||||
args.best_checkpoint = os.path.join(args.path, args.checkpoint_name)
|
||||
else:
|
||||
assert os.path.exists(os.path.join(args.path, 'process_0.log'))
|
||||
args.best_checkpoint = get_best(args)
|
||||
|
||||
return args
|
||||
|
||||
|
|
Loading…
Reference in New Issue