From e0e04702ccf4b2c5af5504e0a37bbf66a1ed6754 Mon Sep 17 00:00:00 2001 From: Bryan Marcus McCann Date: Wed, 27 Jun 2018 21:18:45 +0000 Subject: [PATCH] rm dependency on .git; multi-gpu WIP; fine-grained time --- arguments.py | 13 ++++++++++--- train.py | 9 ++++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/arguments.py b/arguments.py index 49c96d91..1c4069fd 100644 --- a/arguments.py +++ b/arguments.py @@ -71,9 +71,10 @@ def parse(): parser.add_argument('--resume', action='store_true', help='whether to resume training with past optimizers') parser.add_argument('--seed', default=123, type=int, help='Random seed.') - parser.add_argument('--gpus', nargs='+', type=int, help='gpus to use') + parser.add_argument('--gpus', nargs='+', type=int, help='a list of gpus that can be used for training (multi-gpu currently WIP)') parser.add_argument('--backend', default='gloo', type=str, help='backend for distributed training') + parser.add_argument('--no_commit', action='store_false', dest='commit', help='do not track the git commit associated with this training run') parser.add_argument('--exist_ok', action='store_true', help='Ok if the save directory already exists, i.e. overwrite is ok') parser.add_argument('--token_testing', action='store_true', help='if true, sorts all iterators') parser.add_argument('--reverse', action='store_true', help='if token_testing and true, sorts all iterators in reverse') @@ -86,7 +87,10 @@ def parse(): if 'imdb' in args.val_tasks: args.val_tasks.remove('imdb') args.world_size = len(args.gpus) if args.gpus[0] > -1 else -1 - args.timestamp = '-'.join(datetime.datetime.now(tz=tz.tzoffset(None, -8*60*60)).strftime("%y/%m/%d/%H/%M").split()) + if args.world_size > 1: + print('multi-gpu training is currently a work in progress') + return + args.timestamp = '-'.join(datetime.datetime.now(tz=tz.tzoffset(None, -8*60*60)).strftime("%y/%m/%d/%H/%M/%S.%f").split()) if len(args.train_tasks) > 1: if args.train_iterations is None: @@ -99,7 +103,10 @@ def parse(): args.val_batch_size = len(args.val_tasks) * args.val_batch_size # postprocess arguments - args.commit = get_commit() + if args.commit: + args.commit = get_commit() + else: + args.commit = '' train_out = f'{",".join(args.train_tasks)}' if len(args.train_tasks) > 1: train_out += f'{"-".join([str(x) for x in args.train_iterations])}' diff --git a/train.py b/train.py index b6e88c73..9dcb767d 100644 --- a/train.py +++ b/train.py @@ -336,8 +336,10 @@ def init_opt(args, model): return opt -if __name__ == '__main__': +def main(): args = arguments.parse() + if args is None: + return set_seed(args) logger = initialize_logger(args) logger.info(f'Arguments:\n{pformat(vars(args))}') @@ -357,3 +359,8 @@ if __name__ == '__main__': else: logger.info(f'Processing') run(args, run_args, world_size=args.world_size) + + +if __name__ == '__main__': + main() +