Clear reference to training loss at the end of train step (#9336)
Without clearing this reference, the loss tensor stays live through the next training step. This can be a problem for memory intensive models that produce very deep backward graphs such as neural ODEs. For these models, keeping the backward graph of the previous loss in memory can lead to OOM errors in the next training step even though the step might have succeeded if we had cleared (and thus GC'd) the previous backward graph. Co-authored-by: tchaton <thomas@grid.ai> Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
6e124e7207
commit
98e2f56db0
|
@ -200,15 +200,18 @@ class LoggerConnector:
|
|||
"""
|
||||
|
||||
def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
|
||||
assert self.trainer._results is not None
|
||||
# when the user requests `dataloader_iter`, we can't track the batch_size
|
||||
# and this is left to user responsibility.
|
||||
if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
|
||||
assert self.trainer._results is not None
|
||||
self.trainer._results.extract_batch_size(split_batch)
|
||||
|
||||
self._batch_idx = batch_idx
|
||||
self._split_idx = split_idx
|
||||
|
||||
# clear reference to this step's training loss so that it can be garbage collected before the next training step
|
||||
self.trainer._results.minimize = None
|
||||
|
||||
def update_train_step_metrics(self) -> None:
|
||||
if self.trainer.fit_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
|
||||
return
|
||||
|
|
|
@ -188,3 +188,16 @@ def test_prepare_outputs(tmpdir):
|
|||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
|
||||
trainer.fit(model)
|
||||
assert model.on_train_batch_end_called == 2
|
||||
|
||||
|
||||
def test_batch_loop_releases_loss(tmpdir):
|
||||
"""Test that loss/graph is released so that it can be garbage collected before the next training step"""
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
assert self.trainer._results.minimize is None
|
||||
return super().training_step(batch, batch_idx)
|
||||
|
||||
model = TestModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue