added on_hpc_load and on_hpc_save hooks

This commit is contained in:
William Falcon 2019-07-02 09:35:15 -04:00
parent cd11b7de98
commit f257c080c0
1 changed files with 21 additions and 0 deletions

View File

@@ -21,6 +21,21 @@ class ModelIO(object):
"""
raise NotImplementedError
def on_hpc_save(self):
    """
    Extension point that runs just ahead of the Slurm manager saving the model.

    The default implementation does nothing; override it to run any custom
    logic that must happen before the checkpoint is written.

    :return: None
    """
    return None
def on_hpc_load(self):
    """
    Hook to do whatever you need right after the Slurm manager has loaded the model.

    The HPC restore code calls this hook once ``load_model_specific(checkpoint)``
    has returned, so it runs after the checkpoint has been applied, not before.
    The default implementation is a no-op; override it to add custom behavior.

    :return: None
    """
    pass
class TrainerIO(object):
@@ -116,6 +131,9 @@ class TrainerIO(object):
os.makedirs(folderpath, exist_ok=True)
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
# give model a chance to do something on hpc_save
self.on_hpc_save()
# request what to save from the model
checkpoint_dict = self.dump_checkpoint()
@@ -137,6 +155,9 @@ class TrainerIO(object):
model = self.model.module if type(self.model) is LightningDataParallel else self.model
model.load_model_specific(checkpoint)
# call model hook
self.on_hpc_load()
def max_ckpt_in_folder(self, path):
files = os.listdir(path)
files = [x for x in files if 'ckpt_' in x]