Add a "decanlp" script that calls out to the different subcommands
Usage:
- decanlp train ...
- decanlp predict ...
- decanlp convert-to-logical-forms ...
parent a5a203b099
commit 5447d0c37c
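Every subcommand module follows the same convention in this commit: it exposes a module-level main(argv) and builds its ArgumentParser with prog=argv[0], so the dispatcher can hand it a synthetic argv whose first element names the subcommand. A minimal sketch of that convention, with a hypothetical `greet` subcommand and a hypothetical --name flag (neither is part of decanlp); the sketch strips the program name before parsing:

    import sys
    from argparse import ArgumentParser

    def main(argv=sys.argv):
        # prog=argv[0] makes --help report the subcommand name
        # (e.g. 'decanlp-greet') rather than the top-level script name
        parser = ArgumentParser(prog=argv[0])
        parser.add_argument('--name', default='world')  # hypothetical flag
        args = parser.parse_args(argv[1:])  # strip the program name
        print('hello', args.name)

    if __name__ == '__main__':
        main()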
@@ -20,11 +20,11 @@ def save_args(args):
         json.dump(vars(args), f, indent=2)
 
 
-def parse():
+def parse(argv):
     """
     Returns the arguments from the command line.
     """
-    parser = ArgumentParser()
+    parser = ArgumentParser(prog=argv[0])
     parser.add_argument('--root', default='./decaNLP', type=str, help='root directory for data, results, embeddings, code, etc.')
     parser.add_argument('--data', default='.data/', type=str, help='where to load data from.')
     parser.add_argument('--save', default='results', type=str, help='where to save results.')
@@ -96,7 +96,7 @@ def parse():
     parser.add_argument('--loss_switch', default=0.666, type=float, help='switch to BLEU loss after certain iterations controlled by this ratio')
 
 
-    args = parser.parse_args()
+    args = parser.parse_args(argv)
     if args.model is None:
         args.model = 'mcqa'
     if args.val_tasks is None:
@@ -3,6 +3,7 @@ from decanlp.text.torchtext.datasets.generic import Query
 from argparse import ArgumentParser
 import os
 import re
+import sys
 import ujson as json
 from decanlp.metrics import to_lf
 
@@ -64,14 +65,14 @@ def write_logical_forms(greedy, args):
             except Exception as e:
                 f.write(json.dumps(correct_format({})) + '\n')
 
-if __name__ == '__main__':
-    parser = ArgumentParser()
+def main(argv=sys.argv):
+    parser = ArgumentParser(prog=argv[0])
     parser.add_argument('data', help='path to the directory containing data for WikiSQL')
     parser.add_argument('predictions', help='path to prediction file, containing one prediction per line')
     parser.add_argument('ids', help='path to file for indices, a list of integers indicating the index into the dev/test set of the predictions on the corresponding line in \'predictions\'')
     parser.add_argument('output', help='path for logical forms output line by line')
     parser.add_argument('evaluate', help='running on the \'validation\' or \'test\' set')
-    args = parser.parse_args()
+    args = parser.parse_args(argv)
     with open(args.predictions) as f:
         greedy = [l for l in f]
     if args.ids is not None:
@@ -79,3 +80,6 @@ if __name__ == '__main__':
             ids = [int(l.strip()) for l in f]
         greedy = [x[1] for x in sorted([(i, g) for i, g in zip(ids, greedy)])]
     write_logical_forms(greedy, args)
+
+if __name__ == '__main__':
+    main()
@@ -1,6 +1,6 @@
 from subprocess import Popen, PIPE, CalledProcessError
 import json
-from text.torchtext.datasets.generic import Query
+from .text.torchtext.datasets.generic import Query
 import logging
 import os
 import re
@@ -7,9 +7,10 @@ import ujson as json
 import torch
 import numpy as np
 import random
+import sys
 from pprint import pformat
 
-from .util import get_splits, set_seed, preprocess_examples, tokenizer
+from .util import get_splits, set_seed, preprocess_examples
 from .metrics import compute_metrics
 from . import models
 
@@ -207,8 +208,8 @@ def run(args, field, val_sets, model):
     print(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n')
 
 
-def get_args():
-    parser = ArgumentParser()
+def get_args(argv):
+    parser = ArgumentParser(prog=argv[0])
     parser.add_argument('--path', required=True)
     parser.add_argument('--evaluate', type=str, required=True)
     parser.add_argument('--tasks', default=['almond', 'squad', 'iwslt.en.de', 'cnn_dailymail', 'multinli.in.out', 'sst', 'srl', 'zre', 'woz.en', 'wikisql', 'schema'], nargs='+')
@@ -227,7 +228,7 @@ def get_args():
     parser.add_argument('--eval_dir', type=str, default=None, help='use this directory to store eval results')
     parser.add_argument('--cached', default='', type=str, help='where to save cached files')
 
-    args = parser.parse_args()
+    args = parser.parse_args(argv)
 
     with open(os.path.join(args.path, 'config.json')) as config_file:
         config = json.load(config_file)
@@ -267,7 +268,7 @@ def get_args():
     return args
 
 
-if __name__ == '__main__':
-    args = get_args()
+def main(argv=sys.argv):
+    args = get_args(argv)
     print(f'Arguments:\n{pformat(vars(args))}')
 
@@ -299,3 +300,6 @@ if __name__ == '__main__':
     model.set_embeddings(field.vocab.vectors)
 
     run(args, field, splits, model)
+
+if __name__ == '__main__':
+    main()
@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
 import os
-from text import torchtext
 from argparse import ArgumentParser
 import ujson as json
 import torch
@@ -11,11 +10,12 @@ import logging
 from copy import deepcopy
 from pprint import pformat
 
-from util import set_seed
-import models
+from .util import set_seed
+from . import models
 
-from text.torchtext.data import Example
-from text.torchtext.datasets.generic import CONTEXT_SPECIAL, QUESTION_SPECIAL, get_context_question, CQA
+from .text import torchtext
+from .text.torchtext.data import Example
+from .text.torchtext.datasets.generic import CONTEXT_SPECIAL, QUESTION_SPECIAL, get_context_question, CQA
 
 logger = logging.getLogger(__name__)
 
@@ -37,14 +37,14 @@ class Server():
     def prepare_data(self):
         print(f'Vocabulary has {len(self.field.vocab)} tokens from training')
 
-        char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings)
-        glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings)
+        char_vectors = torchtext.vocab.CharNGram(cache=self.args.embeddings)
+        glove_vectors = torchtext.vocab.GloVe(cache=self.args.embeddings)
         vectors = [char_vectors, glove_vectors]
         self.field.vocab.load_vectors(vectors, True)
         self.field.decoder_to_vocab = {idx: self.field.vocab.stoi[word] for idx, word in enumerate(self.field.decoder_itos)}
         self.field.vocab_to_decoder = {idx: self.field.decoder_stoi[word] for idx, word in enumerate(self.field.vocab.itos) if word in self.field.decoder_stoi}
 
-        self._limited_idx_to_full_idx = deepcopy(field.decoder_to_vocab) # should avoid this with a conditional in map to full
+        self._limited_idx_to_full_idx = deepcopy(self.field.decoder_to_vocab) # should avoid this with a conditional in map to full
         self._oov_to_limited_idx = {}
 
         assert self.field.include_lengths
@@ -55,8 +55,8 @@ class Server():
         for name in CQA.fields:
             # batch of size 1
             batch = [getattr(ex, name)]
-            entry, lengths, limited_entry, raw = field.process(batch, device=self.device, train=True,
-                limited=field.decoder_stoi, l2f=self._limited_idx_to_full_idx, oov2l=self._oov_to_limited_idx)
+            entry, lengths, limited_entry, raw = self.field.process(batch, device=self.device, train=True,
+                limited=self.field.decoder_stoi, l2f=self._limited_idx_to_full_idx, oov2l=self._oov_to_limited_idx)
             setattr(processed, name, entry)
             setattr(processed, f'{name}_lengths', lengths)
             setattr(processed, f'{name}_limited', limited_entry)
@@ -81,16 +81,16 @@ class Server():
             tokenize = None
 
         context_question = get_context_question(context, question)
-        fields = [(x, field) for x in CQA.fields]
+        fields = [(x, self.field) for x in CQA.fields]
         ex = Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields, tokenize=tokenize)
 
         batch = self.numericalize_example(ex)
-        _, prediction_batch = model(batch, iteration=0)
+        _, prediction_batch = self.model(batch, iteration=0)
 
         if task == 'almond':
-            predictions = field.reverse(prediction_batch, detokenize=lambda x: ' '.join(x))
+            predictions = self.field.reverse(prediction_batch, detokenize=lambda x: ' '.join(x))
         else:
-            predictions = field.reverse(prediction_batch)
+            predictions = self.field.reverse(prediction_batch)
 
         client_writer.write((json.dumps(dict(id=request['id'], answer=predictions[0])) + '\n').encode('utf-8'))
 
@@ -110,15 +110,15 @@ class Server():
                     this_r *= s
                 r += this_r
             return r
-        params = list(filter(lambda p: p.requires_grad, model.parameters()))
+        params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
         num_param = mult(params)
-        print(f'{args.model} has {num_param:,} parameters')
-        model.to(self.device)
+        print(f'{self.args.model} has {num_param:,} parameters')
+        self.model.to(self.device)
 
-        model.eval()
+        self.model.eval()
         with torch.no_grad():
            loop = asyncio.get_event_loop()
-           server = loop.run_until_complete(asyncio.start_server(self.handle_client, port=args.port))
+           server = loop.run_until_complete(asyncio.start_server(self.handle_client, port=self.args.port))
            try:
                loop.run_forever()
            except KeyboardInterrupt:
@@ -128,8 +128,8 @@ class Server():
            loop.close()
 
 
-def get_args():
-    parser = ArgumentParser()
+def get_args(argv):
+    parser = ArgumentParser(prog=argv[0])
     parser.add_argument('--path', required=True)
     parser.add_argument('--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used (multi-gpu currently WIP)')
     parser.add_argument('--seed', default=123, type=int, help='Random seed.')
@@ -138,7 +138,7 @@ def get_args():
     parser.add_argument('--checkpoint_name', default='best.pth', help='Checkpoint file to use (relative to --path, defaults to best.pth)')
     parser.add_argument('--port', default=8401, type=int, help='TCP port to listen on')
 
-    args = parser.parse_args()
+    args = parser.parse_args(argv)
 
     with open(os.path.join(args.path, 'config.json')) as config_file:
         config = json.load(config_file)
@@ -179,8 +179,8 @@ def get_args():
     return args
 
 
-if __name__ == '__main__':
-    args = get_args()
+def main(argv=sys.argv):
+    args = get_args(argv)
     print(f'Arguments:\n{pformat(vars(args))}')
 
     np.random.seed(args.seed)
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+# Copyright 2019 The Board of Trustees of the Leland Stanford Junior University
+#
+# Author: Giovanni Campagna <gcampagn@cs.stanford.edu>
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import sys
+
+from decanlp import convert_to_logical_forms, train, predict, server
+
+subcommands = {
+    'convert-to-logical-forms': ('Convert to logical forms (for SQL tasks)', convert_to_logical_forms.main),
+    'train': ('Train a model', train.main),
+    'predict': ('Evaluate a model, or compute predictions on a test dataset', predict.main),
+    'server': ('Export RPC interface to predict', server.main)
+}
+
+def usage():
+    print('Usage: %s SUBCOMMAND [OPTIONS]' % (sys.argv[0]))
+    print()
+    print('Available subcommands:')
+    for subcommand, (help_text, _) in subcommands.items():
+        print('  %s - %s' % (subcommand, help_text))
+    sys.exit(1)
+
+def main():
+    if len(sys.argv) < 2 or sys.argv[1] not in subcommands:
+        usage()
+        return
+
+    main_fn = subcommands[sys.argv[1]][1]
+    canned_argv = ['decanlp-' + sys.argv[1]] + sys.argv[2:]
+    main_fn(canned_argv)
+
+if __name__ == '__main__':
+    main()
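Because each entry point takes an explicit argv (defaulting to sys.argv), the subcommands can also be driven programmatically, mirroring the canned_argv list the dispatcher builds; the flag values below are illustrative only:

    from decanlp import train

    # equivalent to running: decanlp train --data .data/ --save results
    train.main(['decanlp-train', '--data', '.data/', '--save', 'results'])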
@@ -5,6 +5,7 @@ import math
 import time
 import random
 import collections
+import sys
 from copy import deepcopy
 
 import logging
@@ -365,8 +366,8 @@ def init_opt(args, model):
     return opt
 
 
-def main():
-    args = arguments.parse()
+def main(argv=sys.argv):
+    args = arguments.parse(argv)
     if args is None:
         return
     set_seed(args)
@@ -1,4 +1,3 @@
-from text import torchtext
 import time
 import os
 import sys
@@ -6,7 +5,8 @@ import torch
 import random
 import numpy as np
 
-from text.torchtext.data.utils import get_tokenizer
+from .text import torchtext
+from .text.torchtext.data.utils import get_tokenizer
 
 def tokenizer(s):
     return s.split()
@@ -1,7 +1,7 @@
 import torch
 from .util import pad, tokenizer
 from .metrics import compute_metrics
-from text.torchtext.data.utils import get_tokenizer
+from .text.torchtext.data.utils import get_tokenizer
 
 
 def compute_validation_outputs(model, val_iter, field, iteration, optional_names=[]):