From 60dbc940281fbdf38b557e7eb7fa1913720edafe Mon Sep 17 00:00:00 2001
From: mehrad
Date: Fri, 8 Mar 2019 18:52:46 -0800
Subject: [PATCH] fixing checkpoint saving procedure

---
 .gitignore             |  2 ++
 decanlp/train.py       | 40 +++++++++++++++++++++-------------------
 decanlp/utils/saver.py | 31 +++++++++++++++++++++----------
 3 files changed, 44 insertions(+), 29 deletions(-)

diff --git a/.gitignore b/.gitignore
index c4452af0..0a79e959 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,6 +18,8 @@
 models/.DS_Store
 multiprocess/.DS_Store
 text/.DS_Store
 src/
+decaNLP_models/
+workdir/
 
 # C extensions

diff --git a/decanlp/train.py b/decanlp/train.py
index 6061d6cf..f6c0798b 100644
--- a/decanlp/train.py
+++ b/decanlp/train.py
@@ -185,7 +185,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
     local_train_metric_dict = {}
 
     train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]
-    saver = Saver(args.log_dir, args.max_to_keep)
+    saver = Saver(args.log_dir, world_size, args.max_to_keep)
 
     while True:
         # For some number of rounds, we 'jump start' some subset of the tasks
@@ -242,30 +242,29 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
                 logger.info(f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}val_deca:deca_{deca_score:.2f}')
 
             # saving
-            if save_every is not None and (iteration % args.save_every == 0 % args.save_every):
-
-                if world_size > 1:
-                    torch.distributed.barrier()
+            if save_every is not None and (iteration % args.save_every == 0):
 
                 if rank is not None and rank == 0:
                     should_save_best = False
                     if deca_score is not None and (best_decascore is None or best_decascore < deca_score):
                         best_decascore = deca_score
                         should_save_best = True
-                    save_state_dict = {'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()}, 'field': field,
+                    save_model_state_dict = {'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()}, 'field': field,
                                        'best_decascore': best_decascore}
-
-                    saver.save(save_state_dict, global_step=iteration)
+                    save_opt_state_dict = opt.state_dict()
+                    save_opt_state_dict.update({'start_iteration': iteration})
+
+                    if world_size > 1:
+                        torch.distributed.barrier()
+                    saver.save(save_model_state_dict, save_opt_state_dict, global_step=iteration)
                     if should_save_best:
                         logger.info(f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}found new best model')
-                        torch.save(save_state_dict, os.path.join(args.log_dir, 'best.pth'))
-
-
-                if world_size > 1:
-                    torch.distributed.barrier()
-                torch.save(opt.state_dict(), os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth'))
-                torch.distributed.barrier()
-
+                        torch.save(save_model_state_dict, os.path.join(args.log_dir, 'best.pth'))
+                        if world_size > 1:
+                            torch.distributed.barrier()
+                        torch.save(save_opt_state_dict, os.path.join(args.log_dir, 'best_optim.pth'))
+                        if world_size > 1:
+                            torch.distributed.barrier()
 
             # lr update
             lr = opt.param_groups[0]['lr']
@@ -354,9 +353,12 @@ def run(args, run_args, rank=0, world_size=1):
         save_dict = torch.load(os.path.join(args.save, args.load))
         model.load_state_dict(save_dict['model_state_dict'])
         if args.resume:
-            logger.info(f'Resuming Training from {os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth')
-            opt.load_state_dict(torch.load(os.path.join(args.save, f'{os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth')))
-            start_iteration = int(os.path.splitext(os.path.basename(args.load))[0].split('_')[1])
+            logger.info(f'Resuming Training from {os.path.splitext(args.load)[0]}_optim.pth')
+            opt_state_dict = torch.load(os.path.join(args.save, f'{os.path.splitext(args.load)[0]}_optim.pth'))
+            start_iteration = opt_state_dict.pop('start_iteration')
+            logger.info(f'Starting iteration is {start_iteration}')
+            opt.load_state_dict(opt_state_dict)
+            # start_iteration = int(os.path.splitext(os.path.basename(args.load))[0].split('_')[1])
 
     logger.info(f'Begin Training')
     train(args, model, opt, train_iters, args.train_iterations, field, val_iters=val_iters,

diff --git a/decanlp/utils/saver.py b/decanlp/utils/saver.py
index 297debc2..252fbea6 100644
--- a/decanlp/utils/saver.py
+++ b/decanlp/utils/saver.py
@@ -47,9 +47,10 @@ class Saver(object):
     and creating checkpoint files to keep track of which saves
     are valid and which are not.
     '''
-    def __init__(self, savedir, max_to_keep=5):
+    def __init__(self, savedir, world_size, max_to_keep=5):
         self._savedir = savedir
         self._max_to_keep = max_to_keep
+        self.world_size = world_size
         assert max_to_keep >= 1
 
         self._loaded_last_checkpoints = False
@@ -71,20 +72,30 @@ class Saver(object):
         self._all_checkpoints = []
         self._latest_checkpoint = None
 
-    def save(self, save_dict, global_step):
+    def save(self, save_model_state_dict, save_opt_state_dict, global_step):
         self._maybe_load_last_checkpoints()
-
-        filename = 'iteration_' + str(global_step) + '.pth'
-        abspath = os.path.join(self._savedir, filename)
-
-        self._latest_checkpoint = filename
-        self._all_checkpoints.append(filename)
+
+        model_name = 'iteration_' + str(global_step) + '.pth'
+        opt_name = 'iteration_' + str(global_step) + '_optim.pth'
+
+
+        self._latest_checkpoint = model_name
+        self._all_checkpoints.append(model_name)
         if len(self._all_checkpoints) > self._max_to_keep:
             try:
                 todelete = self._all_checkpoints.pop(0)
                 os.unlink(os.path.join(self._savedir, todelete))
+                opt_todelete = todelete.rsplit('.', 1)[0] + '_optim.' + todelete.rsplit('.', 1)[1]
+                os.unlink(os.path.join(self._savedir, opt_todelete))
             except (OSError, IOError) as e:
-                logging.warn('Failed to delete old checkpoint: %s', e)
-        torch.save(save_dict, abspath)
+                logging.warning('Failed to delete old checkpoint: %s', e)
+        if self.world_size > 1:
+            torch.distributed.barrier()
+        torch.save(save_model_state_dict, os.path.join(self._savedir, model_name))
+        if self.world_size > 1:
+            torch.distributed.barrier()
+        torch.save(save_opt_state_dict, os.path.join(self._savedir, opt_name))
+        if self.world_size > 1:
+            torch.distributed.barrier()
         with open(os.path.join(self._savedir, 'checkpoint.json'), 'w') as fp:
             json.dump(dict(all=self._all_checkpoints, latest=self._latest_checkpoint), fp)
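
Usage sketch (not part of the patch): a minimal, hypothetical example of how the revised Saver API and the resume path fit together after this change. It assumes a scratch directory `workdir/sketch`, a single-process run (world_size=1, so no distributed barriers fire), and a toy `torch.nn.Linear` standing in for the real model; the `field` and `best_decascore` entries are stubbed with None.

    import os
    import torch
    from decanlp.utils.saver import Saver

    log_dir = 'workdir/sketch'          # assumed scratch directory, not from the patch
    os.makedirs(log_dir, exist_ok=True)

    model = torch.nn.Linear(4, 2)       # toy stand-in for the real model
    opt = torch.optim.Adam(model.parameters())
    iteration = 1000

    # Save: Saver.save() now takes separate model and optimizer dicts and writes
    # paired files, iteration_1000.pth and iteration_1000_optim.pth.
    saver = Saver(log_dir, world_size=1, max_to_keep=5)
    save_model_state_dict = {'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
                             'field': None, 'best_decascore': None}
    save_opt_state_dict = opt.state_dict()
    save_opt_state_dict.update({'start_iteration': iteration})
    saver.save(save_model_state_dict, save_opt_state_dict, global_step=iteration)

    # Resume: start_iteration now comes out of the optimizer checkpoint instead of
    # being parsed from the checkpoint filename.
    load = f'iteration_{iteration}.pth'
    model.load_state_dict(torch.load(os.path.join(log_dir, load))['model_state_dict'])
    opt_state_dict = torch.load(os.path.join(log_dir, f'{os.path.splitext(load)[0]}_optim.pth'))
    start_iteration = opt_state_dict.pop('start_iteration')
    opt.load_state_dict(opt_state_dict)
    print(f'resuming from iteration {start_iteration}')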