diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py index fc151b693b..b3b178e0be 100644 --- a/pytorch_lightning/root_module/model_saving.py +++ b/pytorch_lightning/root_module/model_saving.py @@ -2,7 +2,7 @@ import torch import os import re import pdb - +from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel class ModelIO(object): @@ -49,7 +49,8 @@ class TrainerIO(object): checkpoint['optimizer_states'] = optimizer_states # request what to save from the model - checkpoint_dict = self.model.get_save_dict() + model = self.model.module if isinstance(self.model, LightningDataParallel) else self.model + checkpoint_dict = model.get_save_dict() # merge trainer and model saving items checkpoint.update(checkpoint_dict) @@ -130,7 +131,8 @@ class TrainerIO(object): self.restore_training_state(checkpoint) # load model state - self.model.load_model_specific(checkpoint) + model = self.model.module if isinstance(self.model, LightningDataParallel) else self.model + model.load_model_specific(checkpoint) def max_ckpt_in_folder(self, path): files = os.listdir(path)