move device-specific teardown logic from training loop to accelerator (#5973)

* on train end
* switch order
parent ae4dca9725
commit aa60c08641
@@ -27,6 +27,7 @@ class GPUAccelerator(Accelerator):
     def on_train_end(self):
         # clean up memory
+        self.model.cpu()
         with torch.cuda.device(self.root_device):
             torch.cuda.empty_cache()
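For context, here is a minimal standalone sketch of the GPU teardown pattern in the hunk above. It is illustrative only: the helper name, the toy nn.Linear model, and the CUDA-availability guard are assumptions for the example, not part of the commit.

# Sketch of the accelerator-side GPU teardown (assumed names, not the
# actual GPUAccelerator implementation).
import torch
import torch.nn as nn

def gpu_teardown(model: nn.Module, root_device: torch.device) -> None:
    # Move parameters off the GPU first so their memory is no longer held ...
    model.cpu()
    # ... then release cached blocks back to the driver. The device context
    # mirrors the diff: empty_cache() acts on the current CUDA device.
    with torch.cuda.device(root_device):
        torch.cuda.empty_cache()

if torch.cuda.is_available():
    device = torch.device("cuda", 0)
    model = nn.Linear(8, 2).to(device)
    gpu_teardown(model, device)

Calling model.cpu() before torch.cuda.empty_cache() matters: empty_cache() can only return cached blocks that no live GPU tensor still occupies, which is presumably what the "switch order" note in the commit message refers to.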
@@ -148,13 +148,7 @@ class TrainLoop:
         self.trainer.profiler.describe()

         # give accelerators a chance to finish
-        self.trainer.accelerator_backend.on_train_end()
-
-        # clear mem
-        if self.trainer._device_type == DeviceType.GPU:
-            model = self.trainer.get_model()
-            model.cpu()
-            torch.cuda.empty_cache()
+        self.trainer.accelerator.on_train_end()

     def check_checkpoint_callback(self, should_update, is_last=False):
         # TODO bake this logic into the ModelCheckpoint callback
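The net effect of the TrainLoop hunk is that the loop no longer branches on DeviceType: it calls a single accelerator hook and each accelerator owns its own teardown. Below is a small self-contained sketch of that delegation, using stand-in classes rather than the real Lightning ones; the class layout and attribute names are assumptions made for illustration.

# Sketch of the delegation pattern this commit introduces: the training loop
# calls one hook instead of checking the device type. Stand-in names only.
import torch
import torch.nn as nn

class Accelerator:
    def __init__(self, model: nn.Module):
        self.model = model

    def on_train_end(self) -> None:
        # Default: nothing device-specific to clean up.
        pass

class GPUAccelerator(Accelerator):
    def __init__(self, model: nn.Module, root_device: torch.device):
        super().__init__(model)
        self.root_device = root_device

    def on_train_end(self) -> None:
        # Device-specific teardown lives with the device-specific code.
        self.model.cpu()
        with torch.cuda.device(self.root_device):
            torch.cuda.empty_cache()

class TrainLoop:
    def __init__(self, accelerator: Accelerator):
        self.accelerator = accelerator

    def on_train_end(self) -> None:
        # No "if device_type == GPU" branch here any more; the loop just
        # gives the accelerator a chance to finish.
        self.accelerator.on_train_end()

model = nn.Linear(8, 2)
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
    accelerator: Accelerator = GPUAccelerator(model.to(device), device)
else:
    accelerator = Accelerator(model)

TrainLoop(accelerator).on_train_end()

With the hook owned by the accelerator, supporting another device type only requires a new accelerator subclass; the training loop itself stays unchanged.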