#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
from argparse import ArgumentParser
import subprocess
import json
import datetime
import logging

from .tasks.registry import get_tasks

logger = logging.getLogger(__name__)

def get_commit():
    directory = os.path.dirname(__file__)
    return subprocess.Popen("cd {} && git log | head -n 1".format(directory),
                            shell=True, stdout=subprocess.PIPE).stdout.read().split()[1].decode()
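
# Note (illustrative, not in the original): `git log | head -n 1` prints a line
# of the form 'commit <sha>', so split()[1] extracts the commit hash, e.g.
# get_commit() -> 'd6f43e1...' (hypothetical value).
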
def save_args(args):
    os.makedirs(args.log_dir, exist_ok=args.exist_ok)
    with open(os.path.join(args.log_dir, 'config.json'), 'wt') as f:
        json.dump(vars(args), f, indent=2)
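
# A minimal sketch (an addition, assuming only the config.json layout written
# above) of how a saved configuration could be loaded back into a Namespace:
#
#     from argparse import Namespace
#     with open(os.path.join(log_dir, 'config.json')) as f:  # log_dir is hypothetical here
#         loaded_args = Namespace(**json.load(f))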


def parse(argv):
    """
    Returns the arguments from the command line.
    """
    parser = ArgumentParser(prog=argv[0])

    parser.add_argument('--root', default='./decaNLP', type=str, help='root directory for data, results, embeddings, code, etc.')
    parser.add_argument('--data', default='.data/', type=str, help='where to load data from.')
    parser.add_argument('--save', default='results', type=str, help='where to save results.')
    parser.add_argument('--embeddings', default='.embeddings', type=str, help='where to save embeddings.')
    parser.add_argument('--cached', default='', type=str, help='where to save cached files')
    parser.add_argument('--saved_models', default='./saved_models', type=str, help='directory where cached models should be loaded from')

    parser.add_argument('--train_tasks', nargs='+', type=str, dest='train_task_names', help='tasks to use for training', required=True)
    parser.add_argument('--train_iterations', nargs='+', type=int, help='number of iterations to focus on each task')
    parser.add_argument('--train_batch_tokens', nargs='+', default=[9000], type=int, help='Number of tokens to use for dynamic batching, corresponding to tasks in train tasks')
    parser.add_argument('--jump_start', default=0, type=int, help='number of iterations to give jump started tasks')
    parser.add_argument('--n_jump_start', default=0, type=int, help='how many tasks to jump start (presented in order)')
    parser.add_argument('--num_print', default=15, type=int, help='how many validation examples with greedy output to print to stdout')
    parser.add_argument('--no_tensorboard', action='store_false', dest='tensorboard', help='Turn off tensorboard logging')
    parser.add_argument('--tensorboard_dir', default=None, help='Directory where to save Tensorboard logs (defaults to --save)')
    parser.add_argument('--max_to_keep', default=5, type=int, help='number of checkpoints to keep')
    parser.add_argument('--log_every', default=int(1e2), type=int, help='how often to log results in # of iterations')
    parser.add_argument('--save_every', default=int(1e3), type=int, help='how often to save a checkpoint in # of iterations')

    parser.add_argument('--val_tasks', nargs='+', type=str, dest='val_task_names', help='tasks to collect evaluation metrics for')
    parser.add_argument('--val_every', default=int(1e3), type=int, help='how often to run validation in # of iterations')
    parser.add_argument('--val_no_filter', action='store_false', dest='val_filter', help='turn off filtering on the validation sets')
    parser.add_argument('--val_batch_size', nargs='+', default=[256], type=int, help='Batch size for validation corresponding to tasks in val tasks')
    parser.add_argument('--vocab_tasks', nargs='+', type=str, help='tasks to use in the construction of the vocabulary')
    parser.add_argument('--max_output_length', default=100, type=int, help='maximum output length for generation')
    parser.add_argument('--max_generative_vocab', default=50000, type=int, help='max vocabulary for the generative softmax')
    parser.add_argument('--max_train_context_length', default=500, type=int, help='maximum length of the contexts during training')
    parser.add_argument('--max_val_context_length', default=500, type=int, help='maximum length of the contexts during validation')
    parser.add_argument('--max_answer_length', default=50, type=int, help='maximum length of answers during training and validation')
    parser.add_argument('--subsample', default=20000000, type=int, help='subsample the datasets')
    parser.add_argument('--preserve_case', action='store_false', dest='lower', help='preserve casing for all text instead of lowercasing')

    parser.add_argument('--model', type=str, choices=['Seq2Seq'], default='Seq2Seq', help='which model to import')
    parser.add_argument('--seq2seq_encoder', type=str, choices=['MQANEncoder', 'BiLSTM', 'Identity'],
                        default='MQANEncoder', help='which encoder to use for the Seq2Seq model')
    parser.add_argument('--seq2seq_decoder', type=str, choices=['MQANDecoder'], default='MQANDecoder',
                        help='which decoder to use for the Seq2Seq model')
    parser.add_argument('--dimension', default=200, type=int, help='output dimensions for all layers')
    parser.add_argument('--rnn_dimension', default=None, type=int, help='output dimensions for RNN layers')
    parser.add_argument('--rnn_layers', default=1, type=int, help='number of layers for RNN modules')
    parser.add_argument('--rnn_zero_state', default='zero', choices=['zero', 'average'],
                        help='how to construct RNN zero state (for Identity encoder)')
    parser.add_argument('--transformer_layers', default=2, type=int, help='number of layers for transformer modules')
    parser.add_argument('--transformer_hidden', default=150, type=int, help='hidden size of the transformer modules')
    parser.add_argument('--transformer_heads', default=3, type=int, help='number of heads for transformer modules')
    parser.add_argument('--dropout_ratio', default=0.2, type=float, help='dropout for the model')
    parser.add_argument('--encoder_embeddings', default='glove+char', help='which word embedding to use on the encoder side; use a bert-* pretrained model for BERT; multiple embeddings can be concatenated with +')
    parser.add_argument('--train_encoder_embeddings', action='store_true', default=False, help='back propagate into pretrained encoder embedding (recommended for BERT)')
    parser.add_argument('--decoder_embeddings', default='glove+char', help='which pretrained word embedding to use on the decoder side')
    parser.add_argument('--trainable_decoder_embeddings', default=0, type=int, help='size of trainable portion of decoder embedding (0 or omit to disable)')

    parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate')
    parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')
    parser.add_argument('--beta0', default=0.9, type=float, help='alternative momentum for Adam (only when not using transformer_lr)')
    parser.add_argument('--optimizer', default='adam', type=str, help='Adam or SGD')
    parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')
    parser.add_argument('--transformer_lr_multiply', default=1.0, type=float, help='multiplier for transformer learning rate (if using Adam)')
    parser.add_argument('--lr_rate', default=0.001, type=float, help='fixed learning rate (if not using warmup)')
    parser.add_argument('--weight_decay', default=0.0, type=float, help='weight L2 regularization')

    parser.add_argument('--load', default=None, type=str, help='path to checkpoint to load model from inside args.save')
    parser.add_argument('--resume', action='store_true', help='whether to resume training with past optimizers')
    parser.add_argument('--seed', default=123, type=int, help='Random seed.')
    parser.add_argument('--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used for training')

    parser.add_argument('--no_commit', action='store_false', dest='commit', help='do not track the git commit associated with this training run')
    parser.add_argument('--exist_ok', action='store_true', help='Ok if the save directory already exists, i.e. overwrite is ok')
    parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='skip existing cached splits and generate new ones')
    parser.add_argument('--use_curriculum', action='store_true', help='Use curriculum learning')
    parser.add_argument('--aux_dataset', default='', type=str, help='path to auxiliary dataset (ignored if curriculum is not used)')
    parser.add_argument('--curriculum_max_frac', default=1.0, type=float, help='max fraction of harder dataset to keep for curriculum')
    parser.add_argument('--curriculum_rate', default=0.1, type=float, help='growth rate for curriculum')
    parser.add_argument('--curriculum_strategy', default='linear', type=str, choices=['linear', 'exp'], help='growth strategy for curriculum')
    parser.add_argument('--question', type=str, help='provide a fixed question')
    parser.add_argument('--use_google_translate', action='store_true', help='use google translate instead of pre-trained machine translator')

    args = parser.parse_args(argv[1:])

    if args.val_task_names is None:
        args.val_task_names = []
        for t in args.train_task_names:
            if t not in args.val_task_names:
                args.val_task_names.append(t)
    if 'imdb' in args.val_task_names:
        args.val_task_names.remove('imdb')

    args.timestamp = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()

    if args.use_google_translate:
        args.data = args.data + '_google_translate'

    if len(args.train_task_names) > 1:
        if args.train_iterations is None:
            args.train_iterations = [1]
    if len(args.train_iterations) < len(args.train_task_names):
        args.train_iterations = len(args.train_task_names) * args.train_iterations
    if len(args.train_batch_tokens) < len(args.train_task_names):
        args.train_batch_tokens = len(args.train_task_names) * args.train_batch_tokens
    if len(args.val_batch_size) < len(args.val_task_names):
        args.val_batch_size = len(args.val_task_names) * args.val_batch_size
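
    # e.g. (hypothetical invocation) with two train tasks and the default
    # --train_batch_tokens of [9000], the singleton list is repeated to
    # [9000, 9000], giving one dynamic-batching token budget per task.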

    # postprocess arguments
    if args.commit:
        args.commit = get_commit()
    else:
        args.commit = ''

    if args.rnn_dimension is None:
        args.rnn_dimension = args.dimension

    args.log_dir = args.save
    if args.tensorboard_dir is None:
        args.tensorboard_dir = args.log_dir
    args.dist_sync_file = os.path.join(args.log_dir, 'distributed_sync_file')

    for x in ['data', 'save', 'embeddings', 'log_dir', 'dist_sync_file']:
        setattr(args, x, os.path.join(args.root, getattr(args, x)))

    save_args(args)

    # create the task objects after we saved the configuration to the JSON file,
    # because tasks are not JSON serializable
    args.train_tasks = get_tasks(args.train_task_names, args)
    args.val_tasks = get_tasks(args.val_task_names, args)

    return args
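

# A minimal usage sketch (an addition, not part of the original module); the
# entry-point script name and flag values below are hypothetical:
#
#     import sys
#     args = parse(sys.argv)  # e.g. sys.argv == ['train.py', '--train_tasks', 'squad']
#     logger.info('training on %s', args.train_task_names)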