diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py
index 3d150752ba..1ac91fd308 100644
--- a/pytorch_lightning/root_module/model_saving.py
+++ b/pytorch_lightning/root_module/model_saving.py
@@ -21,6 +21,21 @@ class ModelIO(object):
         """
         raise NotImplementedError
 
 
+    def on_hpc_save(self):
+        """
+        Hook to do whatever you need right before Slurm manager saves the model
+        :return:
+        """
+        pass
+
+    def on_hpc_load(self):
+        """
+        Hook to do whatever you need right before Slurm manager loads the model
+        :return:
+        """
+        pass
+
+
 class TrainerIO(object):
 
@@ -116,6 +131,10 @@ class TrainerIO(object):
         os.makedirs(folderpath, exist_ok=True)
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
 
+        # give model a chance to do something on hpc_save
+        model = self.model.module if type(self.model) is LightningDataParallel else self.model
+        model.on_hpc_save()
+
         # request what to save from the model
         checkpoint_dict = self.dump_checkpoint()
 
@@ -137,6 +156,9 @@ class TrainerIO(object):
         model = self.model.module if type(self.model) is LightningDataParallel else self.model
         model.load_model_specific(checkpoint)
 
+        # call model hook
+        model.on_hpc_load()
+
     def max_ckpt_in_folder(self, path):
         files = os.listdir(path)
         files = [x for x in files if 'ckpt_' in x]