Add a "decanlp" script that calls out to the different subcommands

Usage:
- decanlp train ...
- decanlp predict ...
- decanlp convert-to-logical-forms ...
This commit is contained in:
Giovanni Campagna 2019-01-23 12:08:41 -08:00
parent a5a203b099
commit 5447d0c37c
10 changed files with 113 additions and 43 deletions

View File

@ -20,11 +20,11 @@ def save_args(args):
json.dump(vars(args), f, indent=2)
def parse():
def parse(argv):
"""
Returns the arguments from the command line.
"""
parser = ArgumentParser()
parser = ArgumentParser(prog=argv[0])
parser.add_argument('--root', default='./decaNLP', type=str, help='root directory for data, results, embeddings, code, etc.')
parser.add_argument('--data', default='.data/', type=str, help='where to load data from.')
parser.add_argument('--save', default='results', type=str, help='where to save results.')
@ -96,7 +96,7 @@ def parse():
parser.add_argument('--loss_switch', default=0.666, type=float, help='switch to BLEU loss after certain iterations controlled by this ratio')
args = parser.parse_args()
args = parser.parse_args(argv)
if args.model is None:
args.model = 'mcqa'
if args.val_tasks is None:

View File

@ -3,6 +3,7 @@ from decanlp.text.torchtext.datasets.generic import Query
from argparse import ArgumentParser
import os
import re
import sys
import ujson as json
from decanlp.metrics import to_lf
@ -64,14 +65,14 @@ def write_logical_forms(greedy, args):
except Exception as e:
f.write(json.dumps(correct_format({})) + '\n')
if __name__ == '__main__':
parser = ArgumentParser()
def main(argv=sys.argv):
parser = ArgumentParser(prog=argv[0])
parser.add_argument('data', help='path to the directory containing data for WikiSQL')
parser.add_argument('predictions', help='path to prediction file, containing one prediction per line')
parser.add_argument('ids', help='path to file for indices, a list of integers indicating the index into the dev/test set of the predictions on the corresponding line in \'predicitons\'')
parser.add_argument('output', help='path for logical forms output line by line')
parser.add_argument('evaluate', help='running on the \'validation\' or \'test\' set')
args = parser.parse_args()
args = parser.parse_args(argv)
with open(args.predictions) as f:
greedy = [l for l in f]
if args.ids is not None:
@ -79,3 +80,6 @@ if __name__ == '__main__':
ids = [int(l.strip()) for l in f]
greedy = [x[1] for x in sorted([(i, g) for i, g in zip(ids, greedy)])]
write_logical_forms(greedy, args)
if __name__ == '__main__':
main()

View File

@ -1,6 +1,6 @@
from subprocess import Popen, PIPE, CalledProcessError
import json
from text.torchtext.datasets.generic import Query
from .text.torchtext.datasets.generic import Query
import logging
import os
import re

View File

@ -7,9 +7,10 @@ import ujson as json
import torch
import numpy as np
import random
import sys
from pprint import pformat
from .util import get_splits, set_seed, preprocess_examples, tokenizer
from .util import get_splits, set_seed, preprocess_examples
from .metrics import compute_metrics
from . import models
@ -207,8 +208,8 @@ def run(args, field, val_sets, model):
print(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n')
def get_args():
parser = ArgumentParser()
def get_args(argv):
parser = ArgumentParser(prog=argv[0])
parser.add_argument('--path', required=True)
parser.add_argument('--evaluate', type=str, required=True)
parser.add_argument('--tasks', default=['almond', 'squad', 'iwslt.en.de', 'cnn_dailymail', 'multinli.in.out', 'sst', 'srl', 'zre', 'woz.en', 'wikisql', 'schema'], nargs='+')
@ -227,7 +228,7 @@ def get_args():
parser.add_argument('--eval_dir', type=str, default=None, help='use this directory to store eval results')
parser.add_argument('--cached', default='', type=str, help='where to save cached files')
args = parser.parse_args()
args = parser.parse_args(argv)
with open(os.path.join(args.path, 'config.json')) as config_file:
config = json.load(config_file)
@ -267,7 +268,7 @@ def get_args():
return args
if __name__ == '__main__':
def main(argv=sys.argv):
args = get_args()
print(f'Arguments:\n{pformat(vars(args))}')
@ -299,3 +300,6 @@ if __name__ == '__main__':
model.set_embeddings(field.vocab.vectors)
run(args, field, splits, model)
if __name__ == '__main__':
main()

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python3
import os
from text import torchtext
from argparse import ArgumentParser
import ujson as json
import torch
@ -11,11 +10,12 @@ import logging
from copy import deepcopy
from pprint import pformat
from util import set_seed
import models
from .util import set_seed
from . import models
from text.torchtext.data import Example
from text.torchtext.datasets.generic import CONTEXT_SPECIAL, QUESTION_SPECIAL, get_context_question, CQA
from .text import torchtext
from .text.torchtext.data import Example
from .text.torchtext.datasets.generic import CONTEXT_SPECIAL, QUESTION_SPECIAL, get_context_question, CQA
logger = logging.getLogger(__name__)
@ -37,14 +37,14 @@ class Server():
def prepare_data(self):
print(f'Vocabulary has {len(self.field.vocab)} tokens from training')
char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings)
glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings)
char_vectors = torchtext.vocab.CharNGram(cache=self.args.embeddings)
glove_vectors = torchtext.vocab.GloVe(cache=self.args.embeddings)
vectors = [char_vectors, glove_vectors]
self.field.vocab.load_vectors(vectors, True)
self.field.decoder_to_vocab = {idx: self.field.vocab.stoi[word] for idx, word in enumerate(self.field.decoder_itos)}
self.field.vocab_to_decoder = {idx: self.field.decoder_stoi[word] for idx, word in enumerate(self.field.vocab.itos) if word in self.field.decoder_stoi}
self._limited_idx_to_full_idx = deepcopy(field.decoder_to_vocab) # should avoid this with a conditional in map to full
self._limited_idx_to_full_idx = deepcopy(self.field.decoder_to_vocab) # should avoid this with a conditional in map to full
self._oov_to_limited_idx = {}
assert self.field.include_lengths
@ -55,8 +55,8 @@ class Server():
for name in CQA.fields:
# batch of size 1
batch = [getattr(ex, name)]
entry, lengths, limited_entry, raw = field.process(batch, device=self.device, train=True,
limited=field.decoder_stoi, l2f=self._limited_idx_to_full_idx, oov2l=self._oov_to_limited_idx)
entry, lengths, limited_entry, raw = self.field.process(batch, device=self.device, train=True,
limited=self.field.decoder_stoi, l2f=self._limited_idx_to_full_idx, oov2l=self._oov_to_limited_idx)
setattr(processed, name, entry)
setattr(processed, f'{name}_lengths', lengths)
setattr(processed, f'{name}_limited', limited_entry)
@ -81,16 +81,16 @@ class Server():
tokenize = None
context_question = get_context_question(context, question)
fields = [(x, field) for x in CQA.fields]
fields = [(x, self.field) for x in CQA.fields]
ex = Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields, tokenize=tokenize)
batch = self.numericalize_example(ex)
_, prediction_batch = model(batch, iteration=0)
_, prediction_batch = self.model(batch, iteration=0)
if task == 'almond':
predictions = field.reverse(prediction_batch, detokenize=lambda x: ' '.join(x))
predictions = self.field.reverse(prediction_batch, detokenize=lambda x: ' '.join(x))
else:
predictions = field.reverse(prediction_batch)
predictions = self.field.reverse(prediction_batch)
client_writer.write((json.dumps(dict(id=request['id'], answer=predictions[0])) + '\n').encode('utf-8'))
@ -110,15 +110,15 @@ class Server():
this_r *= s
r += this_r
return r
params = list(filter(lambda p: p.requires_grad, model.parameters()))
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
num_param = mult(params)
print(f'{args.model} has {num_param:,} parameters')
model.to(self.device)
print(f'{self.args.model} has {num_param:,} parameters')
self.model.to(self.device)
model.eval()
self.model.eval()
with torch.no_grad():
loop = asyncio.get_event_loop()
server = loop.run_until_complete(asyncio.start_server(self.handle_client, port=args.port))
server = loop.run_until_complete(asyncio.start_server(self.handle_client, port=self.args.port))
try:
loop.run_forever()
except KeyboardInterrupt:
@ -128,8 +128,8 @@ class Server():
loop.close()
def get_args():
parser = ArgumentParser()
def get_args(argv):
parser = ArgumentParser(prog=argv[0])
parser.add_argument('--path', required=True)
parser.add_argument('--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used (multi-gpu currently WIP)')
parser.add_argument('--seed', default=123, type=int, help='Random seed.')
@ -138,7 +138,7 @@ def get_args():
parser.add_argument('--checkpoint_name', default='best.pth', help='Checkpoint file to use (relative to --path, defaults to best.pth)')
parser.add_argument('--port', default=8401, type=int, help='TCP port to listen on')
args = parser.parse_args()
args = parser.parse_args(argv)
with open(os.path.join(args.path, 'config.json')) as config_file:
config = json.load(config_file)
@ -179,8 +179,8 @@ def get_args():
return args
if __name__ == '__main__':
args = get_args()
def main(argv=sys.argv):
args = get_args(argv)
print(f'Arguments:\n{pformat(vars(args))}')
np.random.seed(args.seed)

60
decanlp/tool.py Executable file
View File

@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright 2019 The Board of Trustees of the Leland Stanford Junior University
#
# Author: Giovanni Campagna <gcampagn@cs.stanford.edu>
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import sys
from decanlp import convert_to_logical_forms, train, predict, server
subcommands = {
'convert-to-logical-froms': ('Convert to logical forms (for SQL tasks)', convert_to_logical_forms.main),
'train': ('Train a model', train.main),
'predict': ('Evaluate a model, or compute predictions on a test dataset', predict.main),
'server': ('Export RPC interface to predict', server.main)
}
def usage():
print('Usage: %s SUBCOMMAND [OPTIONS]' % (sys.argv[0]))
print()
print('Available subcommands:')
for subcommand,(help_text,_) in subcommands.items():
print(' %s - %s' % (subcommand, help_text))
sys.exit(1)
def main():
if len(sys.argv) < 2 or sys.argv[1] not in subcommands:
usage()
return
main_fn = subcommands[sys.argv[1]][1]
canned_argv = ['decanlp-' + sys.argv[1]] + sys.argv[2:]
main_fn(canned_argv)
if __name__ == '__main__':
main()

View File

@ -5,6 +5,7 @@ import math
import time
import random
import collections
import sys
from copy import deepcopy
import logging
@ -365,8 +366,8 @@ def init_opt(args, model):
return opt
def main():
args = arguments.parse()
def main(argv=sys.argv):
args = arguments.parse(argv)
if args is None:
return
set_seed(args)

View File

@ -1,4 +1,3 @@
from text import torchtext
import time
import os
import sys
@ -6,7 +5,8 @@ import torch
import random
import numpy as np
from text.torchtext.data.utils import get_tokenizer
from .text import torchtext
from .text.torchtext.data.utils import get_tokenizer
def tokenizer(s):
return s.split()

View File

@ -1,7 +1,7 @@
import torch
from .util import pad, tokenizer
from .metrics import compute_metrics
from text.torchtext.data.utils import get_tokenizer
from .text.torchtext.data.utils import get_tokenizer
def compute_validation_outputs(model, val_iter, field, iteration, optional_names=[]):

View File

@ -37,8 +37,9 @@ setuptools.setup(
version='0.1dev',
packages=setuptools.find_packages(exclude=['tests']),
scripts=[],
entry_points= {
'console_scripts': ['decanlp=decanlp.tool:main'],
},
license='BSD-3-Clause',
author="Salesforce Inc.",
long_description=long_description,