From a63f74281a12ebd947e5161b8d599c14840884e6 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 24 Jul 2019 18:03:19 -0400
Subject: [PATCH] fixed correct module on hpc save

---
 pytorch_lightning/root_module/model_saving.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py
index 493aa05bcc..557b5d7d79 100644
--- a/pytorch_lightning/root_module/model_saving.py
+++ b/pytorch_lightning/root_module/model_saving.py
@@ -41,6 +41,11 @@ class ModelIO(object):
 
 class TrainerIO(object):
 
+    def __get_model(self):
+        is_dp_module = type(self.model) is LightningDistributedDataParallel or type(self.model) is LightningDataParallel
+        model = self.model.module if is_dp_module else self.model
+        return model
+
     # --------------------
     # MODEL SAVE CHECKPOINT
     # --------------------
@@ -71,8 +76,7 @@ class TrainerIO(object):
         checkpoint['optimizer_states'] = optimizer_states
 
         # request what to save from the model
-        is_dp_module = type(self.model) is LightningDistributedDataParallel or type(self.model) is LightningDataParallel
-        model = self.model.module if is_dp_module else self.model
+        model = self.__get_model()
         checkpoint_dict = model.get_save_dict()
 
         # merge trainer and model saving items
@@ -144,7 +148,7 @@ class TrainerIO(object):
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
 
         # give model a chance to do something on hpc_save
-        model = self.model.module if type(self.model) is LightningDataParallel else self.model
+        model = self.__get_model()
        model.on_hpc_save()
 
         # request what to save from the model
@@ -165,7 +169,7 @@ class TrainerIO(object):
         self.restore_training_state(checkpoint)
 
         # load model state
-        model = self.model.module if type(self.model) is LightningDataParallel else self.model
+        model = self.__get_model()
         model.load_model_specific(checkpoint)
 
         # call model hook
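
A minimal standalone sketch of the unwrap pattern this patch factors into __get_model(): before the fix, the hpc_save and load paths only checked for LightningDataParallel, so a model wrapped in LightningDistributedDataParallel was never unwrapped before its hooks were called. The WrappedTrainer class and the nn.Linear toy model below are hypothetical, and torch's stock wrapper classes plus an isinstance check stand in for Lightning's LightningDataParallel / LightningDistributedDataParallel subclasses and the patch's exact `type(...) is` checks:

import torch.nn as nn

class WrappedTrainer:
    """Hypothetical stand-in for TrainerIO, sketching the helper added in this patch."""

    def __init__(self, model):
        self.model = model

    def get_model(self):
        # Both DataParallel and DistributedDataParallel wrap the user's model
        # and expose it under `.module`; unwrap so hooks such as on_hpc_save()
        # and load_model_specific() run on the underlying model, not the wrapper.
        is_dp_module = isinstance(
            self.model, (nn.DataParallel, nn.parallel.DistributedDataParallel)
        )
        return self.model.module if is_dp_module else self.model

# Usage sketch: the helper returns the same object whether or not it is wrapped.
net = nn.Linear(4, 2)
assert WrappedTrainer(net).get_model() is net
assert WrappedTrainer(nn.DataParallel(net)).get_model() is net

Centralizing the check in one helper, as the patch does, keeps the save and restore paths from drifting apart again when a new wrapper type is added.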