Clear reference to training loss at the end of train step (#9336)

Without clearing this reference, the loss tensor stays live through the next training
step. This can be a problem for memory-intensive models that produce very deep backward
graphs, such as neural ODEs. For these models, keeping the previous loss's backward graph
in memory can lead to an OOM error in the next training step, even though that step might
have succeeded if the previous backward graph had been cleared (and thus garbage collected).
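
For context, a minimal sketch in plain PyTorch (not part of this change; the deep graph is simulated with a loop) of the reference chain the fix breaks: as long as any reference to the loss tensor survives, its `grad_fn` keeps every node of the backward graph alive, even after `backward()` has released the saved activations.

import torch

x = torch.randn(10, requires_grad=True)

# simulate the very deep backward graph an unrolled ODE solver produces
y = x
for _ in range(10_000):
    y = y * 1.0001
loss = y.sum()

loss.backward()  # frees the saved activations, but not the graph's node objects

# `loss.grad_fn` still references all ~10k autograd nodes; only dropping the
# last reference to `loss` makes the graph collectable, which is what setting
# `self.trainer._results.minimize = None` does for the stored training loss
loss = None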

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Marten Lienen authored on 2021-09-06 15:37:27 +02:00, committed by GitHub
parent 6e124e7207
commit 98e2f56db0
2 changed files with 17 additions and 1 deletion


@@ -200,15 +200,18 @@ class LoggerConnector:
     """
 
     def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
+        assert self.trainer._results is not None
         # when the user requests `dataloader_iter`, we can't track the batch_size
         # and this is left to user responsibility.
         if isinstance(split_batch, pl.utilities.fetching.DataLoaderIterDataFetcher):
-            assert self.trainer._results is not None
             self.trainer._results.extract_batch_size(split_batch)
 
         self._batch_idx = batch_idx
         self._split_idx = split_idx
 
+        # clear reference to this step's training loss so that it can be garbage collected before the next training step
+        self.trainer._results.minimize = None
+
     def update_train_step_metrics(self) -> None:
         if self.trainer.fit_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
             return


@@ -188,3 +188,16 @@ def test_prepare_outputs(tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
     trainer.fit(model)
     assert model.on_train_batch_end_called == 2
+
+
+def test_batch_loop_releases_loss(tmpdir):
+    """Test that loss/graph is released so that it can be garbage collected before the next training step"""
+
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            assert self.trainer._results.minimize is None
+            return super().training_step(batch, batch_idx)
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+    trainer.fit(model)