diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index 9ec6ad5cde..53f9388d83 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -27,6 +27,7 @@ class GPUAccelerator(Accelerator):
 
     def on_train_end(self):
         # clean up memory
+        self.model.cpu()
         with torch.cuda.device(self.root_device):
             torch.cuda.empty_cache()
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index f727a15310..f7f44625a3 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -148,13 +148,7 @@ class TrainLoop:
         self.trainer.profiler.describe()
 
         # give accelerators a chance to finish
-        self.trainer.accelerator_backend.on_train_end()
-
-        # clear mem
-        if self.trainer._device_type == DeviceType.GPU:
-            model = self.trainer.get_model()
-            model.cpu()
-            torch.cuda.empty_cache()
+        self.trainer.accelerator.on_train_end()
 
     def check_checkpoint_callback(self, should_update, is_last=False):
         # TODO bake this logic into the ModelCheckpoint callback
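
For reference, below is a minimal standalone sketch (not Lightning code) of the teardown behavior this diff consolidates into GPUAccelerator.on_train_end: parameters are moved back to host memory before the CUDA caching allocator is flushed on the device the model trained on. The cleanup_after_training helper and the toy nn.Linear model are illustrative names, not part of the Lightning codebase.

    # Minimal sketch of the cleanup now owned by the GPU accelerator:
    # move weights off the GPU, then flush the CUDA caching allocator.
    import torch
    from torch import nn


    def cleanup_after_training(model: nn.Module, root_device: torch.device) -> None:
        # Move parameters/buffers back to host memory so their GPU storage
        # can actually be released by the caching allocator.
        model.cpu()
        if torch.cuda.is_available():
            # empty_cache() acts on the current device, so select the device
            # the model was trained on before flushing.
            with torch.cuda.device(root_device):
                torch.cuda.empty_cache()


    if __name__ == "__main__":
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = nn.Linear(8, 1).to(device)
        cleanup_after_training(model, device)

The net effect of the diff is that this device-specific teardown lives in the GPU accelerator's on_train_end hook, so TrainLoop.on_train_end no longer needs the DeviceType.GPU branch and simply delegates to self.trainer.accelerator.on_train_end().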