added on_hpc_load and on_hpc_save hooks
This commit is contained in:
parent
cd11b7de98
commit
f257c080c0
|
@ -21,6 +21,21 @@ class ModelIO(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def on_hpc_save(self):
|
||||
"""
|
||||
Hook to do whatever you need right before Slurm manager saves the model
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_hpc_load(self):
|
||||
"""
|
||||
Hook to do whatever you need right before Slurm manager loads the model
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class TrainerIO(object):
|
||||
|
||||
|
@ -116,6 +131,9 @@ class TrainerIO(object):
|
|||
os.makedirs(folderpath, exist_ok=True)
|
||||
filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
|
||||
|
||||
# give model a chance to do something on hpc_save
|
||||
self.on_hpc_save()
|
||||
|
||||
# request what to save from the model
|
||||
checkpoint_dict = self.dump_checkpoint()
|
||||
|
||||
|
@ -137,6 +155,9 @@ class TrainerIO(object):
|
|||
model = self.model.module if type(self.model) is LightningDataParallel else self.model
|
||||
model.load_model_specific(checkpoint)
|
||||
|
||||
# call model hook
|
||||
self.on_hpc_load()
|
||||
|
||||
def max_ckpt_in_folder(self, path):
|
||||
files = os.listdir(path)
|
||||
files = [x for x in files if 'ckpt_' in x]
|
||||
|
|
Loading…
Reference in New Issue