diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 067ffe9db3..19f310344f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -447,6 +447,10 @@ class Trainer(TrainerIOMixin, self.evaluate(model, self.get_val_dataloaders(), self.nb_sanity_val_steps, self.testing) + # clear cache before training + if self.on_gpu: + torch.cuda.empty_cache() + # CORE TRAINING LOOP self.train() diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py index 1e2938d676..afa21cfde8 100644 --- a/tests/test_restore_models.py +++ b/tests/test_restore_models.py @@ -327,7 +327,8 @@ def test_cpu_restore_training(): # set the epoch start hook so we can predict before the model does the full training def assert_good_acc(): - assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0 + assert trainer.current_epoch > 0 + assert trainer.current_epoch == real_global_epoch # if model and state loaded correctly, predictions will be good even though we # haven't trained with the new loaded model