fixing checkpoint saving procedure
parent 6e8d5798d9
commit 60dbc94028

@@ -18,6 +18,8 @@ models/.DS_Store
multiprocess/.DS_Store
text/.DS_Store
src/
decaNLP_models/
workdir/

# C extensions

@@ -185,7 +185,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
     local_train_metric_dict = {}

     train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]
-    saver = Saver(args.log_dir, args.max_to_keep)
+    saver = Saver(args.log_dir, world_size, args.max_to_keep)

     while True:
         # For some number of rounds, we 'jump start' some subset of the tasks

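The Saver constructor now receives world_size so that checkpoint writes can be synchronized across distributed workers. Not part of the commit, but a minimal sketch of the pattern this enables, assuming torch.distributed has already been initialized by the training launcher (barrier_guarded_save is an illustrative helper, not a repository function):

import torch

# Sketch only: guard a rank-0 checkpoint write with barriers so other ranks
# neither race ahead nor read a half-written file.
def barrier_guarded_save(obj, path, rank, world_size):
    if world_size > 1:
        torch.distributed.barrier()  # every rank reaches the save point
    if rank == 0:
        torch.save(obj, path)        # a single process writes the file
    if world_size > 1:
        torch.distributed.barrier()  # readers proceed only after the write completes
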
@@ -242,30 +242,29 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
             logger.info(f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}val_deca:deca_{deca_score:.2f}')

             # saving
-            if save_every is not None and (iteration % args.save_every == 0 % args.save_every):
+            if world_size > 1:
+                torch.distributed.barrier()
+            if save_every is not None and (iteration % args.save_every == 0):
                 if rank is not None and rank == 0:
                     should_save_best = False
                     if deca_score is not None and (best_decascore is None or best_decascore < deca_score):
                         best_decascore = deca_score
                         should_save_best = True

-                    save_state_dict = {'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()}, 'field': field,
+                    save_model_state_dict = {'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()}, 'field': field,
                                              'best_decascore': best_decascore}

-                    saver.save(save_state_dict, global_step=iteration)
+                    save_opt_state_dict = opt.state_dict()
+                    save_opt_state_dict.update({'start_iteration': iteration})
+
+                    if world_size > 1:
+                        torch.distributed.barrier()
+                    saver.save(save_model_state_dict, save_opt_state_dict, global_step=iteration)
                     if should_save_best:
                         logger.info(f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}found new best model')
-                        torch.save(save_state_dict, os.path.join(args.log_dir, 'best.pth'))
-
-                if world_size > 1:
-                    torch.distributed.barrier()
-                torch.save(opt.state_dict(), os.path.join(args.log_dir, f'iteration_{iteration}_rank_{rank}_optim.pth'))
-                torch.distributed.barrier()
+                        torch.save(save_model_state_dict, os.path.join(args.log_dir, 'best.pth'))
+                        if world_size > 1:
+                            torch.distributed.barrier()
+                        torch.save(save_opt_state_dict, os.path.join(args.log_dir, 'best_optim.pth'))
+                        if world_size > 1:
+                            torch.distributed.barrier()

             # lr update
             lr = opt.param_groups[0]['lr']

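Two notes on this hunk. The old condition `iteration % args.save_every == 0 % args.save_every` already behaved like `iteration % args.save_every == 0` (0 % n is 0), so the rewrite only removes the confusing extra modulo. The substantive change is that each checkpoint is now split into a model dictionary and an optimizer dictionary, with the iteration counter stored in the latter, so every save produces a paired iteration_<N>.pth / iteration_<N>_optim.pth, and the best model additionally gets best.pth / best_optim.pth. Restating the two dictionaries with comments (build_checkpoint_dicts is an illustrative wrapper, not repository code):

# Sketch, names following the diff: build the two dictionaries written at every checkpoint.
def build_checkpoint_dicts(model, field, best_decascore, opt, iteration):
    save_model_state_dict = {
        'model_state_dict': {k: v.cpu() for k, v in model.state_dict().items()},  # CPU copy keeps the file device-agnostic
        'field': field,                                                            # field/vocabulary needed to rebuild inputs
        'best_decascore': best_decascore,                                          # best validation score seen so far
    }
    save_opt_state_dict = opt.state_dict()
    save_opt_state_dict.update({'start_iteration': iteration})  # read back on resume instead of parsing the filename
    return save_model_state_dict, save_opt_state_dict
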
@@ -354,9 +353,12 @@ def run(args, run_args, rank=0, world_size=1):
         save_dict = torch.load(os.path.join(args.save, args.load))
         model.load_state_dict(save_dict['model_state_dict'])
         if args.resume:
-            logger.info(f'Resuming Training from {os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth')
-            opt.load_state_dict(torch.load(os.path.join(args.save, f'{os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth')))
-            start_iteration = int(os.path.splitext(os.path.basename(args.load))[0].split('_')[1])
+            logger.info(f'Resuming Training from {os.path.splitext(args.load)[0]}_optim.pth')
+            opt_state_dict = torch.load(os.path.join(args.save, f'{os.path.splitext(args.load)[0]}_optim.pth'))
+            start_iteration = opt_state_dict.pop('start_iteration')
+            logger.info(f'Starting iteration is {start_iteration}')
+            opt.load_state_dict(opt_state_dict)
+            # start_iteration = int(os.path.splitext(os.path.basename(args.load))[0].split('_')[1])

     logger.info(f'Begin Training')
     train(args, model, opt, train_iters, args.train_iterations, field, val_iters=val_iters,

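Because the optimizer checkpoint is no longer per-rank and now carries its own start_iteration, resuming no longer has to parse the step count out of the filename. A compressed sketch of the new resume path, assuming args.load is something like 'iteration_1000.pth' inside args.save (load_resume_state is an illustrative helper, not a repository function):

import os
import torch

# Sketch: recover the optimizer state and starting iteration from the single
# shared optimizer checkpoint written next to the model checkpoint.
def load_resume_state(save_dir, load_name, opt):
    base = os.path.splitext(load_name)[0]                      # e.g. 'iteration_1000'
    opt_state_dict = torch.load(os.path.join(save_dir, base + '_optim.pth'))
    start_iteration = opt_state_dict.pop('start_iteration')    # stored by the save path above
    opt.load_state_dict(opt_state_dict)                        # remaining keys form a regular optimizer state dict
    return start_iteration
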
@@ -47,9 +47,10 @@ class Saver(object):
     and creating checkpoint files to keep track of which saves are valid and which are not.
     '''

-    def __init__(self, savedir, max_to_keep=5):
+    def __init__(self, savedir, world_size, max_to_keep=5):
         self._savedir = savedir
         self._max_to_keep = max_to_keep
+        self.world_size = world_size
         assert max_to_keep >= 1

         self._loaded_last_checkpoints = False

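For context, the updated constructor restated with comments (comments added here, not verbatim repository code):

class Saver(object):
    def __init__(self, savedir, world_size, max_to_keep=5):
        self._savedir = savedir                  # directory receiving iteration_*.pth files
        self._max_to_keep = max_to_keep          # rotation limit enforced in save()
        self.world_size = world_size             # > 1 means save() interleaves torch.distributed.barrier() calls
        assert max_to_keep >= 1
        self._loaded_last_checkpoints = False    # consumed by _maybe_load_last_checkpoints() in save()
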
@@ -71,20 +72,30 @@ class Saver(object):
         self._all_checkpoints = []
         self._latest_checkpoint = None

-    def save(self, save_dict, global_step):
+    def save(self, save_model_state_dict, save_opt_state_dict, global_step):
         self._maybe_load_last_checkpoints()

-        filename = 'iteration_' + str(global_step) + '.pth'
-        abspath = os.path.join(self._savedir, filename)
-
-        self._latest_checkpoint = filename
-        self._all_checkpoints.append(filename)
+        model_name = 'iteration_' + str(global_step) + '.pth'
+        opt_name = 'iteration_' + str(global_step) + '_optim.pth'
+
+        self._latest_checkpoint = model_name
+        self._all_checkpoints.append(model_name)
         if len(self._all_checkpoints) > self._max_to_keep:
             try:
                 todelete = self._all_checkpoints.pop(0)
                 os.unlink(os.path.join(self._savedir, todelete))
+                opt_todelete = todelete.rsplit('.', 1)[0] + '_optim.' + todelete.rsplit('.', 1)[1]
+                os.unlink(os.path.join(self._savedir, opt_todelete))
             except (OSError, IOError) as e:
-                logging.warn('Failed to delete old checkpoint: %s', e)
-        torch.save(save_dict, abspath)
+                logging.warning('Failed to delete old checkpoint: %s', e)
+        if self.world_size > 1:
+            torch.distributed.barrier()
+        torch.save(save_model_state_dict, os.path.join(self._savedir, model_name))
+        if self.world_size > 1:
+            torch.distributed.barrier()
+        torch.save(save_opt_state_dict, os.path.join(self._savedir, opt_name))
+        if self.world_size > 1:
+            torch.distributed.barrier()
         with open(os.path.join(self._savedir, 'checkpoint.json'), 'w') as fp:
             json.dump(dict(all=self._all_checkpoints, latest=self._latest_checkpoint), fp)

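Putting the pieces together, a stripped-down, single-process sketch of the rotation behaviour after this commit (barriers omitted; TinySaver and its names are illustrative, not the repository class):

import json
import logging
import os
import torch


class TinySaver:
    def __init__(self, savedir, max_to_keep=5):
        self._savedir = savedir
        self._max_to_keep = max_to_keep
        self._all_checkpoints = []
        self._latest_checkpoint = None

    def save(self, model_state, opt_state, global_step):
        model_name = f'iteration_{global_step}.pth'
        opt_name = f'iteration_{global_step}_optim.pth'

        self._latest_checkpoint = model_name
        self._all_checkpoints.append(model_name)
        if len(self._all_checkpoints) > self._max_to_keep:
            try:
                todelete = self._all_checkpoints.pop(0)
                os.unlink(os.path.join(self._savedir, todelete))
                # the optimizer file is rotated together with its model file
                stem, ext = todelete.rsplit('.', 1)
                os.unlink(os.path.join(self._savedir, f'{stem}_optim.{ext}'))
            except (OSError, IOError) as e:
                logging.warning('Failed to delete old checkpoint: %s', e)

        torch.save(model_state, os.path.join(self._savedir, model_name))
        torch.save(opt_state, os.path.join(self._savedir, opt_name))
        # bookkeeping file listing which checkpoints are currently valid
        with open(os.path.join(self._savedir, 'checkpoint.json'), 'w') as fp:
            json.dump(dict(all=self._all_checkpoints, latest=self._latest_checkpoint), fp)

For example, TinySaver('workdir', max_to_keep=2) would keep only the two most recent iteration_*.pth / iteration_*_optim.pth pairs on disk, plus checkpoint.json.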