diff --git a/decanlp/arguments.py b/decanlp/arguments.py index 129b31c2..dcc3bc0a 100644 --- a/decanlp/arguments.py +++ b/decanlp/arguments.py @@ -20,11 +20,11 @@ def save_args(args): json.dump(vars(args), f, indent=2) -def parse(): +def parse(argv): """ Returns the arguments from the command line. """ - parser = ArgumentParser() + parser = ArgumentParser(prog=argv[0]) parser.add_argument('--root', default='./decaNLP', type=str, help='root directory for data, results, embeddings, code, etc.') parser.add_argument('--data', default='.data/', type=str, help='where to load data from.') parser.add_argument('--save', default='results', type=str, help='where to save results.') @@ -96,7 +96,7 @@ def parse(): parser.add_argument('--loss_switch', default=0.666, type=float, help='switch to BLEU loss after certain iterations controlled by this ratio') - args = parser.parse_args() + args = parser.parse_args(argv) if args.model is None: args.model = 'mcqa' if args.val_tasks is None: diff --git a/convert_to_logical_forms.py b/decanlp/convert_to_logical_forms.py similarity index 96% rename from convert_to_logical_forms.py rename to decanlp/convert_to_logical_forms.py index 9a4e0a3b..37dadb9e 100644 --- a/convert_to_logical_forms.py +++ b/decanlp/convert_to_logical_forms.py @@ -3,6 +3,7 @@ from decanlp.text.torchtext.datasets.generic import Query from argparse import ArgumentParser import os import re +import sys import ujson as json from decanlp.metrics import to_lf @@ -64,14 +65,14 @@ def write_logical_forms(greedy, args): except Exception as e: f.write(json.dumps(correct_format({})) + '\n') -if __name__ == '__main__': - parser = ArgumentParser() +def main(argv=sys.argv): + parser = ArgumentParser(prog=argv[0]) parser.add_argument('data', help='path to the directory containing data for WikiSQL') parser.add_argument('predictions', help='path to prediction file, containing one prediction per line') parser.add_argument('ids', help='path to file for indices, a list of integers indicating the index into the dev/test set of the predictions on the corresponding line in \'predicitons\'') parser.add_argument('output', help='path for logical forms output line by line') parser.add_argument('evaluate', help='running on the \'validation\' or \'test\' set') - args = parser.parse_args() + args = parser.parse_args(argv) with open(args.predictions) as f: greedy = [l for l in f] if args.ids is not None: @@ -79,3 +80,6 @@ if __name__ == '__main__': ids = [int(l.strip()) for l in f] greedy = [x[1] for x in sorted([(i, g) for i, g in zip(ids, greedy)])] write_logical_forms(greedy, args) + +if __name__ == '__main__': + main() diff --git a/decanlp/metrics.py b/decanlp/metrics.py index 74b29a2f..08563762 100644 --- a/decanlp/metrics.py +++ b/decanlp/metrics.py @@ -1,6 +1,6 @@ from subprocess import Popen, PIPE, CalledProcessError import json -from text.torchtext.datasets.generic import Query +from .text.torchtext.datasets.generic import Query import logging import os import re diff --git a/decanlp/predict.py b/decanlp/predict.py index 09156c6a..2782aa26 100644 --- a/decanlp/predict.py +++ b/decanlp/predict.py @@ -7,9 +7,10 @@ import ujson as json import torch import numpy as np import random +import sys from pprint import pformat -from .util import get_splits, set_seed, preprocess_examples, tokenizer +from .util import get_splits, set_seed, preprocess_examples from .metrics import compute_metrics from . import models @@ -207,8 +208,8 @@ def run(args, field, val_sets, model): print(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n') -def get_args(): - parser = ArgumentParser() +def get_args(argv): + parser = ArgumentParser(prog=argv[0]) parser.add_argument('--path', required=True) parser.add_argument('--evaluate', type=str, required=True) parser.add_argument('--tasks', default=['almond', 'squad', 'iwslt.en.de', 'cnn_dailymail', 'multinli.in.out', 'sst', 'srl', 'zre', 'woz.en', 'wikisql', 'schema'], nargs='+') @@ -227,7 +228,7 @@ def get_args(): parser.add_argument('--eval_dir', type=str, default=None, help='use this directory to store eval results') parser.add_argument('--cached', default='', type=str, help='where to save cached files') - args = parser.parse_args() + args = parser.parse_args(argv) with open(os.path.join(args.path, 'config.json')) as config_file: config = json.load(config_file) @@ -267,7 +268,7 @@ def get_args(): return args -if __name__ == '__main__': +def main(argv=sys.argv): args = get_args() print(f'Arguments:\n{pformat(vars(args))}') @@ -299,3 +300,6 @@ if __name__ == '__main__': model.set_embeddings(field.vocab.vectors) run(args, field, splits, model) + +if __name__ == '__main__': + main() diff --git a/server.py b/decanlp/server.py similarity index 82% rename from server.py rename to decanlp/server.py index 97bcc6a0..1c3345ea 100644 --- a/server.py +++ b/decanlp/server.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import os -from text import torchtext from argparse import ArgumentParser import ujson as json import torch @@ -11,11 +10,12 @@ import logging from copy import deepcopy from pprint import pformat -from util import set_seed -import models +from .util import set_seed +from . import models -from text.torchtext.data import Example -from text.torchtext.datasets.generic import CONTEXT_SPECIAL, QUESTION_SPECIAL, get_context_question, CQA +from .text import torchtext +from .text.torchtext.data import Example +from .text.torchtext.datasets.generic import CONTEXT_SPECIAL, QUESTION_SPECIAL, get_context_question, CQA logger = logging.getLogger(__name__) @@ -37,14 +37,14 @@ class Server(): def prepare_data(self): print(f'Vocabulary has {len(self.field.vocab)} tokens from training') - char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings) - glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings) + char_vectors = torchtext.vocab.CharNGram(cache=self.args.embeddings) + glove_vectors = torchtext.vocab.GloVe(cache=self.args.embeddings) vectors = [char_vectors, glove_vectors] self.field.vocab.load_vectors(vectors, True) self.field.decoder_to_vocab = {idx: self.field.vocab.stoi[word] for idx, word in enumerate(self.field.decoder_itos)} self.field.vocab_to_decoder = {idx: self.field.decoder_stoi[word] for idx, word in enumerate(self.field.vocab.itos) if word in self.field.decoder_stoi} - self._limited_idx_to_full_idx = deepcopy(field.decoder_to_vocab) # should avoid this with a conditional in map to full + self._limited_idx_to_full_idx = deepcopy(self.field.decoder_to_vocab) # should avoid this with a conditional in map to full self._oov_to_limited_idx = {} assert self.field.include_lengths @@ -55,8 +55,8 @@ class Server(): for name in CQA.fields: # batch of size 1 batch = [getattr(ex, name)] - entry, lengths, limited_entry, raw = field.process(batch, device=self.device, train=True, - limited=field.decoder_stoi, l2f=self._limited_idx_to_full_idx, oov2l=self._oov_to_limited_idx) + entry, lengths, limited_entry, raw = self.field.process(batch, device=self.device, train=True, + limited=self.field.decoder_stoi, l2f=self._limited_idx_to_full_idx, oov2l=self._oov_to_limited_idx) setattr(processed, name, entry) setattr(processed, f'{name}_lengths', lengths) setattr(processed, f'{name}_limited', limited_entry) @@ -81,16 +81,16 @@ class Server(): tokenize = None context_question = get_context_question(context, question) - fields = [(x, field) for x in CQA.fields] + fields = [(x, self.field) for x in CQA.fields] ex = Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields, tokenize=tokenize) batch = self.numericalize_example(ex) - _, prediction_batch = model(batch, iteration=0) + _, prediction_batch = self.model(batch, iteration=0) if task == 'almond': - predictions = field.reverse(prediction_batch, detokenize=lambda x: ' '.join(x)) + predictions = self.field.reverse(prediction_batch, detokenize=lambda x: ' '.join(x)) else: - predictions = field.reverse(prediction_batch) + predictions = self.field.reverse(prediction_batch) client_writer.write((json.dumps(dict(id=request['id'], answer=predictions[0])) + '\n').encode('utf-8')) @@ -110,15 +110,15 @@ class Server(): this_r *= s r += this_r return r - params = list(filter(lambda p: p.requires_grad, model.parameters())) + params = list(filter(lambda p: p.requires_grad, self.model.parameters())) num_param = mult(params) - print(f'{args.model} has {num_param:,} parameters') - model.to(self.device) + print(f'{self.args.model} has {num_param:,} parameters') + self.model.to(self.device) - model.eval() + self.model.eval() with torch.no_grad(): loop = asyncio.get_event_loop() - server = loop.run_until_complete(asyncio.start_server(self.handle_client, port=args.port)) + server = loop.run_until_complete(asyncio.start_server(self.handle_client, port=self.args.port)) try: loop.run_forever() except KeyboardInterrupt: @@ -128,8 +128,8 @@ class Server(): loop.close() -def get_args(): - parser = ArgumentParser() +def get_args(argv): + parser = ArgumentParser(prog=argv[0]) parser.add_argument('--path', required=True) parser.add_argument('--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used (multi-gpu currently WIP)') parser.add_argument('--seed', default=123, type=int, help='Random seed.') @@ -138,7 +138,7 @@ def get_args(): parser.add_argument('--checkpoint_name', default='best.pth', help='Checkpoint file to use (relative to --path, defaults to best.pth)') parser.add_argument('--port', default=8401, type=int, help='TCP port to listen on') - args = parser.parse_args() + args = parser.parse_args(argv) with open(os.path.join(args.path, 'config.json')) as config_file: config = json.load(config_file) @@ -179,8 +179,8 @@ def get_args(): return args -if __name__ == '__main__': - args = get_args() +def main(argv=sys.argv): + args = get_args(argv) print(f'Arguments:\n{pformat(vars(args))}') np.random.seed(args.seed) diff --git a/decanlp/tool.py b/decanlp/tool.py new file mode 100755 index 00000000..93ddb09d --- /dev/null +++ b/decanlp/tool.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# Copyright 2019 The Board of Trustees of the Leland Stanford Junior University +# +# Author: Giovanni Campagna +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys + +from decanlp import convert_to_logical_forms, train, predict, server + +subcommands = { + 'convert-to-logical-froms': ('Convert to logical forms (for SQL tasks)', convert_to_logical_forms.main), + 'train': ('Train a model', train.main), + 'predict': ('Evaluate a model, or compute predictions on a test dataset', predict.main), + 'server': ('Export RPC interface to predict', server.main) +} + +def usage(): + print('Usage: %s SUBCOMMAND [OPTIONS]' % (sys.argv[0])) + print() + print('Available subcommands:') + for subcommand,(help_text,_) in subcommands.items(): + print(' %s - %s' % (subcommand, help_text)) + sys.exit(1) + +def main(): + if len(sys.argv) < 2 or sys.argv[1] not in subcommands: + usage() + return + + main_fn = subcommands[sys.argv[1]][1] + canned_argv = ['decanlp-' + sys.argv[1]] + sys.argv[2:] + main_fn(canned_argv) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/decanlp/train.py b/decanlp/train.py index 0e03181e..56f98d1c 100644 --- a/decanlp/train.py +++ b/decanlp/train.py @@ -5,6 +5,7 @@ import math import time import random import collections +import sys from copy import deepcopy import logging @@ -365,8 +366,8 @@ def init_opt(args, model): return opt -def main(): - args = arguments.parse() +def main(argv=sys.argv): + args = arguments.parse(argv) if args is None: return set_seed(args) diff --git a/decanlp/util.py b/decanlp/util.py index 06f49f1e..2951c6e9 100644 --- a/decanlp/util.py +++ b/decanlp/util.py @@ -1,4 +1,3 @@ -from text import torchtext import time import os import sys @@ -6,7 +5,8 @@ import torch import random import numpy as np -from text.torchtext.data.utils import get_tokenizer +from .text import torchtext +from .text.torchtext.data.utils import get_tokenizer def tokenizer(s): return s.split() diff --git a/decanlp/validate.py b/decanlp/validate.py index 7422089c..f3b363c4 100644 --- a/decanlp/validate.py +++ b/decanlp/validate.py @@ -1,7 +1,7 @@ import torch from .util import pad, tokenizer from .metrics import compute_metrics -from text.torchtext.data.utils import get_tokenizer +from .text.torchtext.data.utils import get_tokenizer def compute_validation_outputs(model, val_iter, field, iteration, optional_names=[]): diff --git a/setup.py b/setup.py index 17ede647..4b2d188e 100644 --- a/setup.py +++ b/setup.py @@ -37,8 +37,9 @@ setuptools.setup( version='0.1dev', packages=setuptools.find_packages(exclude=['tests']), - scripts=[], - + entry_points= { + 'console_scripts': ['decanlp=decanlp.tool:main'], + }, license='BSD-3-Clause', author="Salesforce Inc.", long_description=long_description,