From bf0f5a5cbb3f4d8d38cc2dcfcc77097ea644f663 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 26 Jun 2019 18:12:33 -0400 Subject: [PATCH] removed self.model refs --- pytorch_lightning/root_module/model_saving.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py index fc151b693b..b3b178e0be 100644 --- a/pytorch_lightning/root_module/model_saving.py +++ b/pytorch_lightning/root_module/model_saving.py @@ -2,7 +2,7 @@ import torch import os import re import pdb - +from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel class ModelIO(object): @@ -49,7 +49,8 @@ class TrainerIO(object): checkpoint['optimizer_states'] = optimizer_states # request what to save from the model - checkpoint_dict = self.model.get_save_dict() + model = self.model.module if type(self.model) is LightningDataParallel else self.model + checkpoint_dict = model.get_save_dict() # merge trainer and model saving items checkpoint.update(checkpoint_dict) @@ -130,7 +131,8 @@ class TrainerIO(object): self.restore_training_state(checkpoint) # load model state - self.model.load_model_specific(checkpoint) + model = self.model.module if type(self.model) is LightningDataParallel else self.model + model.load_model_specific(checkpoint) def max_ckpt_in_folder(self, path): files = os.listdir(path)