move device-specific teardown logic from training loop to accelerator (#5973)

* on train end

* switch order
Adrian Wälchli authored on 2021-02-15 23:38:03 +01:00, committed by GitHub
parent ae4dca9725
commit aa60c08641
2 changed files with 2 additions and 7 deletions


@@ -27,6 +27,7 @@ class GPUAccelerator(Accelerator):
     def on_train_end(self):
         # clean up memory
+        self.model.cpu()
         with torch.cuda.device(self.root_device):
             torch.cuda.empty_cache()
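
Note on the context manager kept from the existing accelerator code: wrapping torch.cuda.empty_cache() in torch.cuda.device(self.root_device) is presumably there so the cache cleanup runs against the device the model actually trained on, rather than implicitly touching (and initializing) the default GPU. A minimal, hypothetical standalone sketch of the same teardown pattern, where model and root_device stand in for the accelerator's attributes:

import torch

model = torch.nn.Linear(4, 4)
root_device = torch.device("cuda", 0)  # assumed single-GPU setup

if torch.cuda.is_available():
    model = model.to(root_device)
    # ... training would happen here ...
    model.cpu()                           # move weights off the GPU
    with torch.cuda.device(root_device):  # stay on the training device
        torch.cuda.empty_cache()          # release cached allocator blocks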


@@ -148,13 +148,7 @@ class TrainLoop:
         self.trainer.profiler.describe()
 
         # give accelerators a chance to finish
-        self.trainer.accelerator_backend.on_train_end()
-
-        # clear mem
-        if self.trainer._device_type == DeviceType.GPU:
-            model = self.trainer.get_model()
-            model.cpu()
-            torch.cuda.empty_cache()
+        self.trainer.accelerator.on_train_end()
 
     def check_checkpoint_callback(self, should_update, is_last=False):
         # TODO bake this logic into the ModelCheckpoint callback
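
Taken together, the two hunks move GPU-specific cleanup out of the generic loop: TrainLoop now only delegates to the accelerator hook, and GPUAccelerator owns the device-specific work, so the loop no longer needs a DeviceType.GPU branch. A simplified, self-contained sketch of the resulting control flow, using stand-in classes rather than the actual Lightning APIs:

import torch


class GPUAcceleratorSketch:
    """Stand-in for Lightning's GPUAccelerator: owns device-specific teardown."""

    def __init__(self, model, root_device):
        self.model = model
        self.root_device = root_device

    def on_train_end(self):
        # clean up memory: move weights to CPU, then release cached GPU blocks
        self.model.cpu()
        with torch.cuda.device(self.root_device):
            torch.cuda.empty_cache()


class TrainLoopSketch:
    """Stand-in for Lightning's TrainLoop: device-agnostic after this change."""

    def __init__(self, accelerator):
        self.accelerator = accelerator

    def on_train_end(self):
        # ... checkpointing, logger finalization, profiler summary ...
        # give the accelerator a chance to finish; no GPU-only branch remains here
        self.accelerator.on_train_end()

The design benefit is that a new accelerator type can add its own teardown by overriding on_train_end, without the training loop growing another device check.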