diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 9c4b7d5fe2..ff11150b26 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -610,14 +610,18 @@ class Trainer(TrainerIO):
         if self.proc_rank == 0:
             self.experiment.save()
 
+        # track model now.
+        # if cluster resets state, the model will update with the saved weights
+        self.model = model
+
         # enable cluster checkpointing
+        # also restores training state
         if self.cluster is not None:  # pragma: no cover
             self.enable_auto_hpc_walltime_manager()
 
         # ---------------------------
         # CORE TRAINING LOOP
         # ---------------------------
-        self.model = model
         self.__train()
 
     def __train(self):
diff --git a/tests/test_models.py b/tests/test_models.py
index 5cee2e5ef6..e95870e8c4 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -85,15 +85,22 @@ def test_cpu_slurm_save_load():
         checkpoint_callback=ModelCheckpoint(save_dir)
     )
     trainer = Trainer(**trainer_options)
+    model = LightningTestModel(hparams)
 
-    # test HPC loading
-    trainer.hpc_load(save_dir, on_gpu=False)
-    assert trainer.global_step == real_global_step and trainer.global_step > 0
+    # set the epoch start hook so we can predict before the model does the full training
+    def assert_pred_same():
+        assert trainer.global_step == real_global_step and trainer.global_step > 0
 
-    # predict with loaded model to make sure answers are the same
-    trainer.model.eval()
-    new_pred = trainer.model(x)
-    assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
+        # predict with loaded model to make sure answers are the same
+        trainer.model.eval()
+        new_pred = trainer.model(x)
+        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
+
+    model.on_epoch_start = assert_pred_same
+
+    # by calling fit again, we trigger training, loading weights from the cluster
+    # and our hook to predict using current model before any more weight updates
+    trainer.fit(model)
 
     clear_save_dir()
 
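
For context, a minimal sketch of the pattern the new test relies on: a LightningModule hook is monkeypatched with an assertion closure so the checks run inside trainer.fit(), after the cluster checkpoint has been restored but before any further weight updates. The helper name check_weights_restored and the passed-in fixtures (x, pred_before_saving, real_global_step) are hypothetical stand-ins for values built earlier in the test, not part of this diff.

import torch

def check_weights_restored(trainer, model, x, pred_before_saving, real_global_step):
    # hypothetical helper mirroring the test above, not part of the diff
    def assert_pred_same():
        # training state must come back at exactly the checkpointed step
        assert trainer.global_step == real_global_step and trainer.global_step > 0

        # restored weights must reproduce the pre-save predictions exactly
        trainer.model.eval()
        new_pred = trainer.model(x)
        assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1

    # monkeypatch the epoch-start hook so the assertions fire as soon as
    # training resumes, before any new weight updates can occur
    model.on_epoch_start = assert_pred_same
    trainer.fit(model)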