diff --git a/pytorch_lightning/root_module/model_saving.py b/pytorch_lightning/root_module/model_saving.py index 142d2b337b..361a08548b 100644 --- a/pytorch_lightning/root_module/model_saving.py +++ b/pytorch_lightning/root_module/model_saving.py @@ -181,6 +181,7 @@ class TrainerIO(object): # call model hook model.on_hpc_load(checkpoint) + self.model = model def max_ckpt_in_folder(self, path): files = os.listdir(path)