diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py
index 1e47e3e751..493aa05bcc 100644
--- a/pytorch_lightning/root_module/model_saving.py
+++ b/pytorch_lightning/root_module/model_saving.py
@@ -144,7 +144,8 @@ class TrainerIO(object):
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)

         # give model a chance to do something on hpc_save
-        self.model.on_hpc_save()
+        model = self.model.module if type(self.model) is LightningDataParallel else self.model
+        model.on_hpc_save()

         # request what to save from the model
         checkpoint_dict = self.dump_checkpoint()
@@ -168,7 +169,7 @@ class TrainerIO(object):
         model.load_model_specific(checkpoint)

         # call model hook
-        self.model.on_hpc_load()
+        model.on_hpc_load()

     def max_ckpt_in_folder(self, path):
         files = os.listdir(path)
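The change unwraps a DataParallel-wrapped model before calling the HPC hooks, since hooks such as on_hpc_save/on_hpc_load live on the underlying LightningModule, not on the wrapper. Below is a minimal sketch of the same unwrapping pattern; it uses plain torch.nn.DataParallel and a hypothetical ToyModel as stand-ins for LightningDataParallel and the real model, purely for illustration.

import torch.nn as nn


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def on_hpc_save(self):
        # hook defined on the underlying model, not on the DataParallel wrapper
        print("on_hpc_save called")


wrapped = nn.DataParallel(ToyModel())

# DataParallel exposes the wrapped model via `.module`; calling the hook
# directly on `wrapped` would fail because the wrapper does not forward
# arbitrary attribute access to the inner module.
model = wrapped.module if isinstance(wrapped, nn.DataParallel) else wrapped
model.on_hpc_save()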