lightning/pytorch_lightning/loops/batch/training_batch_loop.py

141 lines
6.0 KiB
Python
Raw Normal View History

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2021-11-30 20:28:55 +00:00
from typing import Any, List, Optional, Tuple, Union
from deprecate import void
from torch import Tensor
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
2021-09-10 13:18:24 +00:00
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization
from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
2021-09-10 13:18:24 +00:00
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import _get_active_optimizers
from pytorch_lightning.trainer.supporters import TensorRunningAccum
# Return type of `TrainingBatchLoop`: a list of per-split results, each one being the output of either the
# optimizer loop (automatic optimization) or the manual loop (manual optimization).
_OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]):
    """Runs over a single batch of data.

    The batch is first divided into splits (see :meth:`_tbptt_split_batch`); each call to :meth:`advance`
    consumes one split and delegates it to either the optimizer loop or the manual optimization loop.
    """

    def __init__(self) -> None:
        super().__init__()
        self.accumulated_loss = TensorRunningAccum(window_length=20)
        self.running_loss = TensorRunningAccum(window_length=20)
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx: int = 0
        self.optimizer_loop = OptimizerLoop()
        self.manual_loop = ManualOptimization()

        # results gathered from the splits processed so far
        self._outputs: _OUTPUTS_TYPE = []
        # `(split index, split)` pairs that still have to be processed
        self._remaining_splits: List[Tuple[int, Any]] = []

    @property
    def done(self) -> bool:
        """Returns if all batch splits have been processed already."""
        return not self._remaining_splits

    def connect(  # type: ignore[override]
        self, optimizer_loop: Optional[OptimizerLoop] = None, manual_loop: Optional[ManualOptimization] = None
    ) -> None:
        """Replaces the child loops of this loop.

        Args:
            optimizer_loop: if given, becomes the loop used under automatic optimization
            manual_loop: if given, becomes the loop used under manual optimization
        """
        if optimizer_loop is not None:
            self.optimizer_loop = optimizer_loop
        if manual_loop is not None:
            self.manual_loop = manual_loop

    def reset(self) -> None:
        """Resets the loop state."""
        self._outputs = []

    def on_run_start(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
        """Splits the data into tbptt splits.

        Args:
            batch: the current batch to run the trainstep on
            batch_idx: the index of the current batch
        """
        void(batch_idx)
        self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

    def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
        """Runs the train step together with optimization (if necessary) on the current batch split.

        Args:
            batch: the current batch to run the training on (this is not the split!)
            batch_idx: the index of the current batch
        """
        void(batch)
        # take the next pending split; `split_idx` is kept on the instance
        self.split_idx, split_batch = self._remaining_splits.pop(0)

        self.trainer._logger_connector.on_train_split_start(self.split_idx)

        outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None  # for mypy
        # choose which loop will run the optimization
        if self.trainer.lightning_module.automatic_optimization:
            optimizers = _get_active_optimizers(
                self.trainer.optimizers, self.trainer.optimizer_frequencies, batch_idx
            )
            outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
        else:
            outputs = self.manual_loop.run(split_batch, batch_idx)

        if outputs:
            # automatic: can be empty if all optimizers skip their batches
            # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
            # then `advance` doesn't finish and an empty dict is returned
            self._outputs.append(outputs)

    def on_run_end(self) -> _OUTPUTS_TYPE:
        """Clears the per-batch state and hands back the outputs collected for this batch.

        Returns:
            the list of non-empty outputs appended by :meth:`advance`
        """
        self.optimizer_loop._hiddens = None
        # this is not necessary as the manual loop runs for only 1 iteration, but just in case
        self.manual_loop._hiddens = None
        output, self._outputs = self._outputs, []  # free memory
        self._remaining_splits = []
        return output

    def teardown(self) -> None:
        """Tears down both child loops and calls ``.cpu()`` on the memory held by the loss accumulators."""
        self.optimizer_loop.teardown()
        self.manual_loop.teardown()
        # release memory
        if self.accumulated_loss.memory is not None:
            self.accumulated_loss.memory = self.accumulated_loss.memory.cpu()
        if self.running_loss.memory is not None:
            self.running_loss.memory = self.running_loss.memory.cpu()

    def _tbptt_split_batch(self, batch: Any) -> List[Any]:
        """Splits a single batch into a list of sequence steps for tbptt.

        Args:
            batch: the current batch to split

        Returns:
            ``[batch]`` when ``truncated_bptt_steps`` is 0, otherwise whatever the LightningModule's
            ``tbptt_split_batch`` hook returns
        """
        tbptt_steps = self.trainer.lightning_module.truncated_bptt_steps
        if tbptt_steps == 0:
            # no truncated backprop through time: the whole batch is the only split
            return [batch]

        return self.trainer._call_lightning_module_hook("tbptt_split_batch", batch, tbptt_steps)

    def _update_running_loss(self, current_loss: Tensor) -> None:
        """Updates the running loss value with the current value.

        Args:
            current_loss: the loss to record; it is only appended under automatic optimization
        """
        if self.trainer.lightning_module.automatic_optimization:
            # track total loss for logging (avoid mem leaks)
            self.accumulated_loss.append(current_loss)

        accumulated_loss = self.accumulated_loss.mean()

        if accumulated_loss is not None:
            # calculate running loss for display; reuse the mean computed above instead of recomputing it
            self.running_loss.append(accumulated_loss * self.trainer.accumulate_grad_batches)

        # reset for next set of accumulated grads
        self.accumulated_loss.reset()