added auto restore

This commit is contained in:
William Falcon 2019-08-07 06:55:05 -04:00
parent a79de1ec8e
commit d3f19c8321
2 changed files with 46 additions and 3 deletions

View File

@ -244,6 +244,30 @@ class Trainer(TrainerIO):
'''
raise ModuleNotFoundError(msg)
# restore training and model
self.restore_state_if_existing_checkpoint()
def restore_state_if_existing_checkpoint(self):
# restore trainer state and model if there is a weight for this experiment
last_epoch = -1
last_ckpt_name = None
# find last epoch
checkpoints = os.listdir(self.checkpoint_callback.filepath)
for name in checkpoints:
if '.ckpt' in name:
epoch = name.split('epoch_')[1]
epoch = re.sub('[^0-9]', '' ,epoch)
if epoch > last_epoch:
last_epoch = epoch
last_ckpt_name = name
# restore last checkpoint
last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
self.restore(last_ckpt_path, self.on_gpu)
print(f'model and trainer restored from checkpoint: {last_ckpt_path}')
@property
def data_parallel(self):
return self.use_dp or self.use_ddp

View File

@ -58,6 +58,25 @@ class TrainerIO(object):
# do the actual save
torch.save(checkpoint, filepath)
def restore(self, checkpoint_path, on_gpu):
if on_gpu:
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
# load training state (affects trainer only)
self.restore_training_state(checkpoint)
# load model state
model = self.__get_model()
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])
# call model hook
model.on_hpc_load(checkpoint)
def dump_checkpoint(self):
checkpoint = {
@ -198,15 +217,15 @@ class TrainerIO(object):
# call model hook
model.on_hpc_load(checkpoint)
def max_ckpt_in_folder(self, path):
def max_ckpt_in_folder(self, path, name_key='ckpt_'):
files = os.listdir(path)
files = [x for x in files if 'ckpt_' in x]
files = [x for x in files if name_key in x]
if len(files) == 0:
return 0
ckpt_vs = []
for name in files:
name = name.split('ckpt_')[-1]
name = name.split(name_key)[-1]
name = re.sub('[^0-9]', '', name)
ckpt_vs.append(int(name))