#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
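"""Command-line argument handling for decaNLP training.

Defines all training, validation, model, and optimization flags;
post-processes the parsed values; and saves the resulting configuration
to config.json in the log directory.
"""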
import os
from argparse import ArgumentParser
import subprocess
import json
import datetime
import logging

from .tasks.registry import get_tasks

logger = logging.getLogger(__name__)

def get_commit():
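    """Return the hash of the git commit that this copy of the code is checked out at."""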
    directory = os.path.dirname(__file__)
    # the first line of `git log` is "commit <hash>"; take the second token to get the hash
    return subprocess.check_output('cd {} && git log | head -n 1'.format(directory),
                                   shell=True).split()[1].decode()

def save_args(args):
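    """Write the parsed arguments to <log_dir>/config.json so the run can be reproduced."""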
    os.makedirs(args.log_dir, exist_ok=args.exist_ok)
    with open(os.path.join(args.log_dir, 'config.json'), 'wt') as f:
        json.dump(vars(args), f, indent=2)

def parse(argv):
    """
    Returns the arguments from the command line.
    """
    parser = ArgumentParser(prog=argv[0])

    # paths
    parser.add_argument('--root', default='./decaNLP', type=str, help='root directory for data, results, embeddings, code, etc.')
    parser.add_argument('--data', default='.data/', type=str, help='where to load data from')
    parser.add_argument('--save', default='results', type=str, help='where to save results')
    parser.add_argument('--embeddings', default='.embeddings', type=str, help='where to save embeddings')
    parser.add_argument('--cached', default='', type=str, help='where to save cached files')
    parser.add_argument('--saved_models', default='./saved_models', type=str, help='directory where cached models should be loaded from')

    # training
    parser.add_argument('--train_tasks', nargs='+', type=str, dest='train_task_names', help='tasks to use for training', required=True)
    parser.add_argument('--train_iterations', nargs='+', type=int, help='number of iterations to focus on each task')
    parser.add_argument('--train_batch_tokens', nargs='+', default=[9000], type=int, help='number of tokens to use for dynamic batching, corresponding to the tasks in --train_tasks')
    parser.add_argument('--jump_start', default=0, type=int, help='number of iterations to give jump-started tasks')
    parser.add_argument('--n_jump_start', default=0, type=int, help='how many tasks to jump-start (presented in order)')
    parser.add_argument('--num_print', default=15, type=int, help='how many validation examples with greedy output to print to stdout')
    parser.add_argument('--no_tensorboard', action='store_false', dest='tensorboard', help='turn off tensorboard logging')
    parser.add_argument('--tensorboard_dir', default=None, help='directory where to save tensorboard logs (defaults to --save)')
    parser.add_argument('--max_to_keep', default=5, type=int, help='number of checkpoints to keep')
    parser.add_argument('--log_every', default=int(1e2), type=int, help='how often to log results, in # of iterations')
    parser.add_argument('--save_every', default=int(1e3), type=int, help='how often to save a checkpoint, in # of iterations')

    # validation
    parser.add_argument('--val_tasks', nargs='+', type=str, dest='val_task_names', help='tasks to collect evaluation metrics for')
    parser.add_argument('--val_every', default=int(1e3), type=int, help='how often to run validation, in # of iterations')
    parser.add_argument('--val_no_filter', action='store_false', dest='val_filter', help='whether to allow filtering on the validation sets')
    parser.add_argument('--val_batch_size', nargs='+', default=[256], type=int, help='batch size for validation, corresponding to the tasks in --val_tasks')

    # data and vocabulary limits
    parser.add_argument('--vocab_tasks', nargs='+', type=str, help='tasks to use in the construction of the vocabulary')
    parser.add_argument('--max_output_length', default=100, type=int, help='maximum output length for generation')
    parser.add_argument('--max_generative_vocab', default=50000, type=int, help='max vocabulary for the generative softmax')
    parser.add_argument('--max_train_context_length', default=500, type=int, help='maximum length of the contexts during training')
    parser.add_argument('--max_val_context_length', default=500, type=int, help='maximum length of the contexts during validation')
    parser.add_argument('--max_answer_length', default=50, type=int, help='maximum length of answers during training and validation')
    parser.add_argument('--subsample', default=20000000, type=int, help='subsample the datasets')
    parser.add_argument('--preserve_case', action='store_false', dest='lower', help='whether to preserve casing for all text')

    # model architecture
    parser.add_argument('--model', type=str, choices=['Seq2Seq'], default='Seq2Seq', help='which model to import')
    parser.add_argument('--seq2seq_encoder', type=str, choices=['MQANEncoder', 'BiLSTM', 'Identity'],
                        default='MQANEncoder', help='which encoder to use for the Seq2Seq model')
    parser.add_argument('--seq2seq_decoder', type=str, choices=['MQANDecoder'], default='MQANDecoder',
                        help='which decoder to use for the Seq2Seq model')
    parser.add_argument('--dimension', default=200, type=int, help='output dimensions for all layers')
    parser.add_argument('--rnn_dimension', default=None, type=int, help='output dimensions for RNN layers')
    parser.add_argument('--rnn_layers', default=1, type=int, help='number of layers for RNN modules')
    parser.add_argument('--rnn_zero_state', default='zero', choices=['zero', 'average'],
                        help='how to construct the RNN zero state (for the Identity encoder)')
    parser.add_argument('--transformer_layers', default=2, type=int, help='number of layers for transformer modules')
    parser.add_argument('--transformer_hidden', default=150, type=int, help='hidden size of the transformer modules')
    parser.add_argument('--transformer_heads', default=3, type=int, help='number of heads for transformer modules')
    parser.add_argument('--dropout_ratio', default=0.2, type=float, help='dropout for the model')

    # embeddings
    parser.add_argument('--encoder_embeddings', default='glove+char', help='which word embeddings to use on the encoder side; use a bert-* pretrained model for BERT; multiple embeddings can be concatenated with +')
    parser.add_argument('--train_encoder_embeddings', action='store_true', default=False, help='back-propagate into the pretrained encoder embeddings (recommended for BERT)')
    parser.add_argument('--decoder_embeddings', default='glove+char', help='which pretrained word embeddings to use on the decoder side')
    parser.add_argument('--trainable_decoder_embeddings', default=0, type=int, help='size of the trainable portion of the decoder embeddings (0 or omit to disable)')

    # optimization
    parser.add_argument('--warmup', default=800, type=int, help='warmup for the learning rate')
    parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')
    parser.add_argument('--beta0', default=0.9, type=float, help='alternative momentum for Adam (only when not using transformer_lr)')
    parser.add_argument('--optimizer', default='adam', type=str, help='Adam or SGD')
    parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turn off the transformer learning rate strategy')
    parser.add_argument('--transformer_lr_multiply', default=1.0, type=float, help='multiplier for the transformer learning rate (if using Adam)')
    parser.add_argument('--lr_rate', default=0.001, type=float, help='fixed learning rate (if not using warmup)')
    parser.add_argument('--weight_decay', default=0.0, type=float, help='weight L2 regularization')

    # checkpointing and reproducibility
    parser.add_argument('--load', default=None, type=str, help='path to the checkpoint to load the model from, inside args.save')
    parser.add_argument('--resume', action='store_true', help='whether to resume training with past optimizers')
    parser.add_argument('--seed', default=123, type=int, help='random seed')
    parser.add_argument('--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used for training')
    parser.add_argument('--no_commit', action='store_false', dest='commit', help='do not track the git commit associated with this training run')
    parser.add_argument('--exist_ok', action='store_true', help='ok if the save directory already exists, i.e. overwriting is ok')
    parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='ignore existing cached splits and generate new ones')

    # curriculum learning
    parser.add_argument('--use_curriculum', action='store_true', help='use curriculum learning')
    parser.add_argument('--aux_dataset', default='', type=str, help='path to the auxiliary dataset (ignored if the curriculum is not used)')
    parser.add_argument('--curriculum_max_frac', default=1.0, type=float, help='max fraction of the harder dataset to keep for the curriculum')
    parser.add_argument('--curriculum_rate', default=0.1, type=float, help='growth rate for the curriculum')
    parser.add_argument('--curriculum_strategy', default='linear', type=str, choices=['linear', 'exp'], help='growth strategy for the curriculum')

    parser.add_argument('--question', type=str, help='provide a fixed question')
    parser.add_argument('--use_google_translate', action='store_true', help='use Google Translate instead of a pre-trained machine translator')

    args = parser.parse_args(argv[1:])

    # by default, validate on the training tasks, minus imdb
    if args.val_task_names is None:
        args.val_task_names = []
        for t in args.train_task_names:
            if t not in args.val_task_names:
                args.val_task_names.append(t)
    if 'imdb' in args.val_task_names:
        args.val_task_names.remove('imdb')

    args.timestamp = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()

    if args.use_google_translate:
        args.data = args.data + '_google_translate'

    # when per-task lists are shorter than the number of tasks, repeat them so
    # that every task gets a value
    if len(args.train_task_names) > 1:
        if args.train_iterations is None:
            args.train_iterations = [1]
        if len(args.train_iterations) < len(args.train_task_names):
            args.train_iterations = len(args.train_task_names) * args.train_iterations
        if len(args.train_batch_tokens) < len(args.train_task_names):
            args.train_batch_tokens = len(args.train_task_names) * args.train_batch_tokens
    if len(args.val_batch_size) < len(args.val_task_names):
        args.val_batch_size = len(args.val_task_names) * args.val_batch_size

    # postprocess arguments
    if args.commit:
        args.commit = get_commit()
    else:
        args.commit = ''
    if args.rnn_dimension is None:
        args.rnn_dimension = args.dimension
    args.log_dir = args.save
    if args.tensorboard_dir is None:
        args.tensorboard_dir = args.log_dir
    args.dist_sync_file = os.path.join(args.log_dir, 'distributed_sync_file')

    # make all paths relative to --root
    for x in ['data', 'save', 'embeddings', 'log_dir', 'dist_sync_file']:
        setattr(args, x, os.path.join(args.root, getattr(args, x)))

    save_args(args)

    # create the task objects after we saved the configuration to the JSON file,
    # because tasks are not JSON serializable
    args.train_tasks = get_tasks(args.train_task_names, args)
    args.val_tasks = get_tasks(args.val_task_names, args)

    return args
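

# A minimal sketch (not part of the original module) of how parse() is meant to
# be driven from a training entry point; 'train.py' and the flag values below
# are illustrative assumptions, not calls this package necessarily makes:
#
#     import sys
#     args = parse(sys.argv)
#     # e.g. sys.argv == ['train.py', '--train_tasks', 'squad', '--save', 'results/squad']
#     print(args.log_dir, args.commit, args.train_tasks)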