From c6244594a630b033474ecfa9f64fbd977d2010e3 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 23 Oct 2019 11:41:00 -0400
Subject: [PATCH] clear memory cache before train starts (#418)

* clear memory cache before train starts

* clear memory cache before train starts
---
 pytorch_lightning/trainer/trainer.py | 4 ++++
 tests/test_restore_models.py         | 3 ++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 067ffe9db3..19f310344f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -447,6 +447,10 @@ class Trainer(TrainerIOMixin,
             self.evaluate(model, self.get_val_dataloaders(),
                           self.nb_sanity_val_steps, self.testing)
 
+        # clear cache before training
+        if self.on_gpu:
+            torch.cuda.empty_cache()
+
         # CORE TRAINING LOOP
         self.train()
 
diff --git a/tests/test_restore_models.py b/tests/test_restore_models.py
index 1e2938d676..afa21cfde8 100644
--- a/tests/test_restore_models.py
+++ b/tests/test_restore_models.py
@@ -327,7 +327,8 @@ def test_cpu_restore_training():
 
     # set the epoch start hook so we can predict before the model does the full training
     def assert_good_acc():
-        assert trainer.current_epoch == real_global_epoch and trainer.current_epoch > 0
+        assert trainer.current_epoch > 0
+        assert trainer.current_epoch == real_global_epoch
 
         # if model and state loaded correctly, predictions will be good even though we
         # haven't trained with the new loaded model
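
Note: the trainer change above follows a standard PyTorch pattern: calling torch.cuda.empty_cache() after the validation sanity check releases blocks held by the CUDA caching allocator before the main training loop begins. It does not free tensors that are still referenced, only the allocator's unused cache. A minimal standalone sketch of the same idea is below; the run_training helper and its on_gpu argument are hypothetical stand-ins for the Trainer internals, not part of the library API.

    import torch

    def run_training(model, train_fn, on_gpu=None):
        # hypothetical stand-in for the Trainer's fit path: optionally clear the
        # CUDA cache, then hand off to the actual training loop
        if on_gpu is None:
            on_gpu = torch.cuda.is_available()
        if on_gpu:
            # release memory cached by the CUDA allocator (e.g. left over from a
            # validation sanity check) so training starts with more free GPU memory
            torch.cuda.empty_cache()
        train_fn(model)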