diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py
index 493aa05bcc..557b5d7d79 100644
--- a/pytorch_lightning/root_module/model_saving.py
+++ b/pytorch_lightning/root_module/model_saving.py
@@ -41,6 +41,11 @@ class ModelIO(object):
 
 
 class TrainerIO(object):
+    def __get_model(self):
+        is_dp_module = type(self.model) is LightningDistributedDataParallel or type(self.model) is LightningDataParallel
+        model = self.model.module if is_dp_module else self.model
+        return model
+
     # --------------------
     # MODEL SAVE CHECKPOINT
     # --------------------
@@ -71,8 +76,7 @@ class TrainerIO(object):
         checkpoint['optimizer_states'] = optimizer_states
 
         # request what to save from the model
-        is_dp_module = type(self.model) is LightningDistributedDataParallel or type(self.model) is LightningDataParallel
-        model = self.model.module if is_dp_module else self.model
+        model = self.__get_model()
         checkpoint_dict = model.get_save_dict()
 
         # merge trainer and model saving items
@@ -144,7 +148,7 @@ class TrainerIO(object):
         filepath = '{}/hpc_ckpt_{}.ckpt'.format(folderpath, ckpt_number)
 
         # give model a chance to do something on hpc_save
-        model = self.model.module if type(self.model) is LightningDataParallel else self.model
+        model = self.__get_model()
         model.on_hpc_save()
 
         # request what to save from the model
@@ -165,7 +169,7 @@ class TrainerIO(object):
         self.restore_training_state(checkpoint)
 
         # load model state
-        model = self.model.module if type(self.model) is LightningDataParallel else self.model
+        model = self.__get_model()
         model.load_model_specific(checkpoint)
 
         # call model hook
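
For context, here is a minimal, self-contained sketch (not part of the patch) of the unwrapping pattern that the new TrainerIO.__get_model() helper consolidates. The two wrapper classes and the DummyModel below are placeholders standing in for pytorch_lightning's LightningDataParallel / LightningDistributedDataParallel, assumed only to expose the wrapped model via a .module attribute as the real wrappers do.

class LightningDataParallel:
    """Placeholder for the real DP wrapper (exposes the wrapped model via .module)."""
    def __init__(self, module):
        self.module = module


class LightningDistributedDataParallel:
    """Placeholder for the real DDP wrapper (exposes the wrapped model via .module)."""
    def __init__(self, module):
        self.module = module


def get_model(model):
    # Same check as TrainerIO.__get_model(): unwrap (D)DP containers,
    # otherwise return the model unchanged.
    is_dp_module = type(model) is LightningDistributedDataParallel \
        or type(model) is LightningDataParallel
    return model.module if is_dp_module else model


if __name__ == '__main__':
    class DummyModel:
        pass

    plain = DummyModel()
    wrapped = LightningDataParallel(plain)
    assert get_model(plain) is plain      # plain models pass through
    assert get_model(wrapped) is plain    # DP-wrapped models are unwrapped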