move device-specific teardown logic from training loop to accelerator (#5973)
* on train end * switch order
This commit is contained in:
parent
ae4dca9725
commit
aa60c08641
|
@ -27,6 +27,7 @@ class GPUAccelerator(Accelerator):
|
||||||
|
|
||||||
def on_train_end(self):
|
def on_train_end(self):
|
||||||
# clean up memory
|
# clean up memory
|
||||||
|
self.model.cpu()
|
||||||
with torch.cuda.device(self.root_device):
|
with torch.cuda.device(self.root_device):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
|
@ -148,13 +148,7 @@ class TrainLoop:
|
||||||
self.trainer.profiler.describe()
|
self.trainer.profiler.describe()
|
||||||
|
|
||||||
# give accelerators a chance to finish
|
# give accelerators a chance to finish
|
||||||
self.trainer.accelerator_backend.on_train_end()
|
self.trainer.accelerator.on_train_end()
|
||||||
|
|
||||||
# clear mem
|
|
||||||
if self.trainer._device_type == DeviceType.GPU:
|
|
||||||
model = self.trainer.get_model()
|
|
||||||
model.cpu()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
def check_checkpoint_callback(self, should_update, is_last=False):
|
def check_checkpoint_callback(self, should_update, is_last=False):
|
||||||
# TODO bake this logic into the ModelCheckpoint callback
|
# TODO bake this logic into the ModelCheckpoint callback
|
||||||
|
|
Loading…
Reference in New Issue