diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py index dae284ea3c..0fca9161e6 100644 --- a/pytorch_lightning/root_module/model_saving.py +++ b/pytorch_lightning/root_module/model_saving.py @@ -2,7 +2,7 @@ import torch import os import re import pdb -from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel +from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel class ModelIO(object): @@ -66,7 +66,7 @@ class TrainerIO(object): checkpoint['optimizer_states'] = optimizer_states # request what to save from the model - model = self.model.module if type(self.model) is LightningDataParallel else self.model + model = self.model.module if type(self.model) is LightningDistributedDataParallel else self.model checkpoint_dict = model.get_save_dict() # merge trainer and model saving items