From 03a699693bd023ed8bd69893fd0c57fdbcdd75b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 19 Jan 2023 18:10:41 +0100 Subject: [PATCH] Remove truncated backpropagation from loops (#16337) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- docs/source-pytorch/api_references.rst | 1 - .../common/lightning_module.rst | 81 +------ docs/source-pytorch/common/trainer.rst | 2 +- docs/source-pytorch/extensions/loops.rst | 18 +- docs/source-pytorch/guides/data.rst | 32 --- examples/pl_loops/yielding_training_step.py | 8 +- src/lightning_app/utilities/introspection.py | 6 - src/pytorch_lightning/CHANGELOG.md | 8 + .../callbacks/progress/base.py | 3 - src/pytorch_lightning/core/module.py | 95 +------- src/pytorch_lightning/loops/__init__.py | 1 - src/pytorch_lightning/loops/batch/__init__.py | 16 -- .../loops/batch/training_batch_loop.py | 140 ------------ .../loops/epoch/training_epoch_loop.py | 74 +++++-- src/pytorch_lightning/loops/fit_loop.py | 13 +- src/pytorch_lightning/loops/loop.py | 3 +- .../loops/optimization/manual_loop.py | 17 +- .../loops/optimization/optimizer_loop.py | 24 +- src/pytorch_lightning/loops/utilities.py | 33 +-- .../trainer/configuration_validator.py | 6 - .../logger_connector/logger_connector.py | 11 +- src/pytorch_lightning/trainer/trainer.py | 2 +- .../tuner/batch_size_scaling.py | 4 +- .../utilities/migration/migration.py | 39 +++- .../utilities/migration/utils.py | 5 +- .../progress/test_tqdm_progress_bar.py | 2 - tests/tests_pytorch/loops/batch/__init__.py | 0 .../loops/batch/test_truncated_bptt.py | 205 ------------------ .../loops/epoch/test_training_epoch_loop.py | 73 ++----- .../loops/test_evaluation_loop_flow.py | 12 +- .../loops/test_loop_state_dict.py | 9 +- tests/tests_pytorch/loops/test_loops.py | 50 ++--- .../tests_pytorch/loops/test_training_loop.py | 2 +- .../loops/test_training_loop_flow_scalar.py | 17 +- tests/tests_pytorch/loops/test_utilities.py | 27 +-- tests/tests_pytorch/trainer/test_trainer.py | 2 +- .../utilities/migration/test_migration.py | 36 +++ .../tests_pytorch/utilities/test_fetching.py | 15 -- 38 files changed, 248 insertions(+), 844 deletions(-) delete mode 100644 src/pytorch_lightning/loops/batch/__init__.py delete mode 100644 src/pytorch_lightning/loops/batch/training_batch_loop.py delete mode 100644 tests/tests_pytorch/loops/batch/__init__.py delete mode 100644 tests/tests_pytorch/loops/batch/test_truncated_bptt.py diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index 828ac7cc79..fe4b967da1 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -126,7 +126,6 @@ Training :nosignatures: :template: classtemplate.rst - ~batch.TrainingBatchLoop ~epoch.TrainingEpochLoop FitLoop ~optimization.ManualOptimization diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index e1b9b38f4f..b365edf819 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -1035,7 +1035,7 @@ global_step ~~~~~~~~~~~ The number of optimizer steps taken (does not reset each epoch). -This includes multiple optimizers and TBPTT steps (if enabled). +This includes multiple optimizers (if enabled). .. code-block:: python @@ -1195,79 +1195,6 @@ Set and access example_input_array, which basically represents a single batch. # generate some images using the example_input_array gen_images = self.generator(self.example_input_array) -truncated_bptt_steps -~~~~~~~~~~~~~~~~~~~~ - -Truncated Backpropagation Through Time (TBPTT) performs perform backpropogation every k steps of -a much longer sequence. This is made possible by passing training batches -split along the time-dimensions into splits of size k to the -``training_step``. In order to keep the same forward propagation behavior, all -hidden states should be kept in-between each time-dimension split. - - -If this is enabled, your batches will automatically get truncated -and the Trainer will apply Truncated Backprop to it. - -(`Williams et al. "An efficient gradient-based algorithm for on-line training of -recurrent network trajectories." -`_) - -`Tutorial `_ - -.. testcode:: python - - from pytorch_lightning import LightningModule - - - class MyModel(LightningModule): - def __init__(self, input_size, hidden_size, num_layers): - super().__init__() - # batch_first has to be set to True - self.lstm = nn.LSTM( - input_size=input_size, - hidden_size=hidden_size, - num_layers=num_layers, - batch_first=True, - ) - - ... - - # Important: This property activates truncated backpropagation through time - # Setting this value to 2 splits the batch into sequences of size 2 - self.truncated_bptt_steps = 2 - - # Truncated back-propagation through time - def training_step(self, batch, batch_idx, hiddens): - x, y = batch - - # the training step must be updated to accept a ``hiddens`` argument - # hiddens are the hiddens from the previous truncated backprop step - out, hiddens = self.lstm(x, hiddens) - - ... - - return {"loss": ..., "hiddens": hiddens} - -Lightning takes care of splitting your batch along the time-dimension. It is -assumed to be the second dimension of your batches. Therefore, in the -example above, we have set ``batch_first=True``. - -.. code-block:: python - - # we use the second as the time dimension - # (batch, time, ...) - sub_batch = batch[0, 0:t, ...] - -To modify how the batch is split, -override the :meth:`pytorch_lightning.core.module.LightningModule.tbptt_split_batch` method: - -.. testcode:: python - - class LitMNIST(LightningModule): - def tbptt_split_batch(self, batch, split_size): - # do your own splitting on the batch - return splits - -------------- .. _lightning_hooks: @@ -1636,12 +1563,6 @@ setup .. automethod:: pytorch_lightning.core.module.LightningModule.setup :noindex: -tbptt_split_batch -~~~~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.module.LightningModule.tbptt_split_batch - :noindex: - teardown ~~~~~~~~ diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index d1d7d51eb3..c4d4ad10f0 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -1418,7 +1418,7 @@ global_step *********** The number of optimizer steps taken (does not reset each epoch). -This includes multiple optimizers and TBPTT steps (if enabled). +This includes multiple optimizers (if enabled). .. code-block:: python diff --git a/docs/source-pytorch/extensions/loops.rst b/docs/source-pytorch/extensions/loops.rst index 740b9ef3de..5e3818c046 100644 --- a/docs/source-pytorch/extensions/loops.rst +++ b/docs/source-pytorch/extensions/loops.rst @@ -306,14 +306,11 @@ Here is what the structure would look like in plain Python: # TrainingEpochLoop for batch_idx, batch in enumerate(train_dataloader): - # TrainingBatchLoop - for split_batch in tbptt_split(batch): + # OptimizerLoop + for optimizer_idx, opt in enumerate(optimizers): - # OptimizerLoop - for optimizer_idx, opt in enumerate(optimizers): - - loss = lightning_module.training_step(batch, batch_idx, optimizer_idx) - ... + loss = lightning_module.training_step(batch, batch_idx, optimizer_idx) + ... # ValidationEpochLoop for batch_idx, batch in enumerate(val_dataloader): @@ -339,13 +336,8 @@ Each of these :code:`for`-loops represents a class implementing the :class:`~pyt The validation is carried out by yet another loop, :class:`~pytorch_lightning.loops.epoch.validation_epoch_loop.ValidationEpochLoop`. In the :code:`run()` method, the training epoch loop could in theory simply call the :code:`LightningModule.training_step` already and perform the optimization. - However, Lightning has built-in support for automatic optimization with multiple optimizers and on top of that also supports :ref:`TBPTT `. + However, Lightning has built-in support for automatic optimization with multiple optimizers. For this reason there are actually two more loops nested under :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`. - * - :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` - - The responsibility of the :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` is to split a batch given by the :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` along the time-dimension and iterate over the list of splits. - It also keeps track of the hidden state *hiddens* returned by the training step. - By default, when truncated back-propagation through time (TBPTT) is turned off, this loop does not do anything except redirect the call to the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop`. - Read more about :ref:`TBPTT `. * - :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` - The :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` iterates over one or multiple optimizers and for each one it calls the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` method with the batch, the current batch index and the optimizer index if multiple optimizers are requested. It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step). diff --git a/docs/source-pytorch/guides/data.rst b/docs/source-pytorch/guides/data.rst index 2ce14fbc34..5f5ca1c9bf 100644 --- a/docs/source-pytorch/guides/data.rst +++ b/docs/source-pytorch/guides/data.rst @@ -343,38 +343,6 @@ When using :class:`~torch.nn.utils.rnn.PackedSequence`, do two things: x = rnn.pack_sequence(batch[0], enforce_sorted=False) y = rnn.pack_sequence(batch[1], enforce_sorted=False) - -Truncated Backpropagation Through Time (TBPTT) -============================================== - -There are times when multiple backwards passes are needed for each batch. -For example, it may save memory to use **Truncated Backpropagation Through Time** when training RNNs. - -Lightning can handle TBPTT automatically via this flag. - -.. testcode:: - - from pytorch_lightning import LightningModule - - - class MyModel(LightningModule): - def __init__(self): - super().__init__() - # Important: This property activates truncated backpropagation through time - # Setting this value to 2 splits the batch into sequences of size 2 - self.truncated_bptt_steps = 2 - - # Truncated back-propagation through time - def training_step(self, batch, batch_idx, hiddens): - # the training step must be updated to accept a ``hiddens`` argument - # hiddens are the hiddens from the previous truncated backprop step - out, hiddens = self.lstm(data, hiddens) - return {"loss": ..., "hiddens": hiddens} - -.. note:: If you need to modify how the batch is split, - override :func:`~pytorch_lightning.core.module.LightningModule.tbptt_split_batch`. - - Iterable Datasets ================= Lightning supports using :class:`~torch.utils.data.IterableDataset` as well as map-style Datasets. IterableDatasets provide a more natural diff --git a/examples/pl_loops/yielding_training_step.py b/examples/pl_loops/yielding_training_step.py index 7fedf72c4e..704a8f8479 100644 --- a/examples/pl_loops/yielding_training_step.py +++ b/examples/pl_loops/yielding_training_step.py @@ -74,7 +74,7 @@ class YieldLoop(OptimizerLoop): return partial(self._training_step, self._generator) def _get_generator(self, kwargs, opt_idx=0): - kwargs = self._build_kwargs(kwargs, opt_idx, hiddens=None) + kwargs = self._build_kwargs(kwargs, opt_idx) # Here we are basically calling `lightning_module.training_step()` # and this returns a generator! The `training_step` is handled by @@ -285,8 +285,8 @@ class GAN(LightningModule): ############################################################################################# # Step 3 / 3: Connect the loop to the Trainer # # # -# Finally, attach the loop to the `Trainer`. Here, we modified the `AutomaticOptimization` # -# loop which is a subloop of the `TrainingBatchLoop`. We use `.connect()` to attach it. # +# Finally, attach the loop to the `Trainer`. Here, we modified the `OptimizerLoop` # +# loop which is a subloop of the `TrainingEpochLoop`. We use `.connect()` to attach it. # ############################################################################################# if __name__ == "__main__": @@ -296,7 +296,7 @@ if __name__ == "__main__": # Connect the new loop # YieldLoop now replaces the previous optimizer loop - trainer.fit_loop.epoch_loop.batch_loop.connect(optimizer_loop=YieldLoop()) + trainer.fit_loop.epoch_loop.connect(optimizer_loop=YieldLoop()) # fit() will now use the new loop! trainer.fit(model, dm) diff --git a/src/lightning_app/utilities/introspection.py b/src/lightning_app/utilities/introspection.py index 19bd72e20c..3da59f488c 100644 --- a/src/lightning_app/utilities/introspection.py +++ b/src/lightning_app/utilities/introspection.py @@ -106,7 +106,6 @@ class LightningModuleVisitor(LightningVisitor): "optimizer_zero_grad", "prepare_data", "setup", - "tbptt_split_batch", "teardown", "train_dataloader", "val_dataloader", @@ -256,10 +255,6 @@ class TorchMetricVisitor(LightningVisitor): class_name = "Metric" -class LightningLiteVisitor(LightningVisitor): # deprecated - class_name = "LightningLite" - - class FabricVisitor(LightningVisitor): class_name = "Fabric" @@ -297,7 +292,6 @@ class Scanner: LightningLoggerVisitor, LightningLoopVisitor, TorchMetricVisitor, - LightningLiteVisitor, # deprecated FabricVisitor, LightningProfilerVisitor, ] diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index f4c6f8b066..6e8d242c8c 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -56,6 +56,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Removed the `Trainer(auto_select_gpus=...)` argument * Removed the `pytorch_lightning.tuner.auto_gpu_select.{pick_single_gpu,pick_multiple_gpus}` functions +- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172)) + * Removed the `LightningModule.truncated_bptt_steps` attribute + * Removed the `LightningModule.tbptt_split_batch` hook + * The `LightningModule.training_step` no longer accepts a `hiddens` argument + * Removed the `pytorch_lightning.loops.batch.TrainingBatchLoop` + * Removed the `FitLoop.split_idx` property + * Removed the `LoggerConnector.on_train_split_start` method + ### Fixed diff --git a/src/pytorch_lightning/callbacks/progress/base.py b/src/pytorch_lightning/callbacks/progress/base.py index 7dc555ee76..b757325143 100644 --- a/src/pytorch_lightning/callbacks/progress/base.py +++ b/src/pytorch_lightning/callbacks/progress/base.py @@ -279,9 +279,6 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") if avg_training_loss is not None: items_dict["loss"] = f"{avg_training_loss:.3g}" - if pl_module.truncated_bptt_steps > 0: - items_dict["split_idx"] = trainer.fit_loop.split_idx - if trainer.loggers: version = _version(trainer.loggers) if version is not None: diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 68e3f50549..f31175e0d2 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -13,7 +13,6 @@ # limitations under the License. """The LightningModule - an nn.Module with many additional features.""" -import collections.abc import logging import numbers import weakref @@ -89,7 +88,6 @@ class LightningModule( "logger", "loggers", "automatic_optimization", - "truncated_bptt_steps", "trainer", "fabric", ] @@ -115,7 +113,6 @@ class LightningModule( self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None self._current_fx_name: Optional[str] = None self._automatic_optimization: bool = True - self._truncated_bptt_steps: int = 0 self._param_requires_grad_state: Dict[str, bool] = {} self._metric_attributes: Optional[Dict[int, str]] = None self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False @@ -275,20 +272,6 @@ class LightningModule( def automatic_optimization(self, automatic_optimization: bool) -> None: self._automatic_optimization = automatic_optimization - @property - def truncated_bptt_steps(self) -> int: - """Enables `Truncated Backpropagation Through Time` in the Trainer when set to a positive integer. - - It represents - the number of times :meth:`training_step` gets called before backpropagation. If this is > 0, the - :meth:`training_step` receives an additional argument ``hiddens`` and is expected to return a hidden state. - """ - return self._truncated_bptt_steps - - @truncated_bptt_steps.setter - def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None: - self._truncated_bptt_steps = truncated_bptt_steps - @property def logger(self) -> Optional[Union[Logger, FabricLogger]]: """Reference to the logger object in the Trainer.""" @@ -683,8 +666,6 @@ class LightningModule( The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. batch_idx (``int``): Integer displaying index of this batch optimizer_idx (``int``): When using multiple optimizers, this argument will also be present. - hiddens (``Any``): Passed in if - :paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0. Return: Any of. @@ -719,19 +700,6 @@ class LightningModule( # do training_step with decoder ... - - If you add truncated back propagation through time you will also get an additional - argument with the hidden states of the previous step. - - .. code-block:: python - - # Truncated back-propagation through time - def training_step(self, batch, batch_idx, hiddens): - # hiddens are the hidden states from the previous truncated backprop step - out, hiddens = self.lstm(data, hiddens) - loss = ... - return {"loss": loss, "hiddens": hiddens} - Note: The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step. @@ -817,9 +785,8 @@ class LightningModule( training_epoch_end(train_outs) Args: - outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers or when - using ``truncated_bptt_steps > 0``, the lists have the dimensions - (n_batches, tbptt_steps, n_optimizers). Dimensions of length 1 are squeezed. + outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers, the lists + have the dimensions (n_batches, n_optimizers). Dimensions of length 1 are squeezed. Return: None @@ -1764,64 +1731,6 @@ class LightningModule( """ optimizer.zero_grad() - def tbptt_split_batch(self, batch: Any, split_size: int) -> List[Any]: - r""" - When using truncated backpropagation through time, each batch must be split along the - time dimension. Lightning handles this by default, but for custom behavior override - this function. - - Args: - batch: Current batch - split_size: The size of the split - - Return: - List of batch splits. Each split will be passed to :meth:`training_step` to enable truncated - back propagation through time. The default implementation splits root level Tensors and - Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length. - - Examples:: - - def tbptt_split_batch(self, batch, split_size): - splits = [] - for t in range(0, time_dims[0], split_size): - batch_split = [] - for i, x in enumerate(batch): - if isinstance(x, torch.Tensor): - split_x = x[:, t:t + split_size] - elif isinstance(x, collections.abc.Sequence): - split_x = [None] * len(x) - for batch_idx in range(len(x)): - split_x[batch_idx] = x[batch_idx][t:t + split_size] - batch_split.append(split_x) - splits.append(batch_split) - return splits - - Note: - Called in the training loop after - :meth:`~pytorch_lightning.callbacks.base.Callback.on_train_batch_start` - if :paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0. - Each returned batch split is passed separately to :meth:`training_step`. - """ - time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.abc.Sequence))] - assert len(time_dims) >= 1, "Unable to determine batch time dimension" - assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous" - - splits = [] - for t in range(0, time_dims[0], split_size): - batch_split = [] - for i, x in enumerate(batch): - split_x: Union[Tensor, List[Tensor]] - if isinstance(x, Tensor): - split_x = x[:, t : t + split_size] - elif isinstance(x, collections.abc.Sequence): - split_x = [x[batch_idx][t : t + split_size] for batch_idx in range(len(x))] - - batch_split.append(split_x) - - splits.append(batch_split) - - return splits - def freeze(self) -> None: r""" Freeze all params for inference. diff --git a/src/pytorch_lightning/loops/__init__.py b/src/pytorch_lightning/loops/__init__.py index 3aa32809b8..5fde69c150 100644 --- a/src/pytorch_lightning/loops/__init__.py +++ b/src/pytorch_lightning/loops/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from pytorch_lightning.loops.loop import Loop # noqa: F401 isort: skip (avoids circular imports) -from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401 from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401 from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401 from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401 diff --git a/src/pytorch_lightning/loops/batch/__init__.py b/src/pytorch_lightning/loops/batch/__init__.py deleted file mode 100644 index e150fb5881..0000000000 --- a/src/pytorch_lightning/loops/batch/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401 -from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401 diff --git a/src/pytorch_lightning/loops/batch/training_batch_loop.py b/src/pytorch_lightning/loops/batch/training_batch_loop.py deleted file mode 100644 index f61e29cc01..0000000000 --- a/src/pytorch_lightning/loops/batch/training_batch_loop.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, List, Optional, Tuple, Union - -from torch import Tensor -from typing_extensions import OrderedDict - -from pytorch_lightning.loops.loop import Loop -from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE -from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization -from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE -from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop -from pytorch_lightning.loops.utilities import _get_active_optimizers -from pytorch_lightning.trainer.supporters import TensorRunningAccum - -_OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] - - -class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]): - """Runs over a single batch of data.""" - - def __init__(self) -> None: - super().__init__() - self.accumulated_loss = TensorRunningAccum(window_length=20) - self.running_loss = TensorRunningAccum(window_length=20) - # the current split index when the batch gets split into chunks in truncated backprop through time - self.split_idx: int = 0 - self.optimizer_loop = OptimizerLoop() - self.manual_loop = ManualOptimization() - - self._outputs: _OUTPUTS_TYPE = [] - self._remaining_splits: List[Tuple[int, Any]] = [] - - @property - def done(self) -> bool: - """Returns if all batch splits have been processed already.""" - return len(self._remaining_splits) == 0 - - def connect( # type: ignore[override] - self, optimizer_loop: Optional[OptimizerLoop] = None, manual_loop: Optional[ManualOptimization] = None - ) -> None: - if optimizer_loop is not None: - self.optimizer_loop = optimizer_loop - if manual_loop is not None: - self.manual_loop = manual_loop - - def reset(self) -> None: - """Resets the loop state.""" - self._outputs = [] - - def on_run_start(self, kwargs: OrderedDict) -> None: - """Splits the data into tbptt splits. - - Args: - kwargs: the kwargs passed down to the hooks. - """ - batch = kwargs["batch"] - self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch))) - - def advance(self, kwargs: OrderedDict) -> None: - """Runs the train step together with optimization (if necessary) on the current batch split. - - Args: - kwargs: the kwargs passed down to the hooks. - """ - # replace the batch with the split batch - self.split_idx, kwargs["batch"] = self._remaining_splits.pop(0) - - self.trainer._logger_connector.on_train_split_start(self.split_idx) - - outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None # for mypy - # choose which loop will run the optimization - if self.trainer.lightning_module.automatic_optimization: - optimizers = _get_active_optimizers( - self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0) - ) - outputs = self.optimizer_loop.run(optimizers, kwargs) - else: - outputs = self.manual_loop.run(kwargs) - if outputs: - # automatic: can be empty if all optimizers skip their batches - # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens, - # then `advance` doesn't finish and an empty dict is returned - self._outputs.append(outputs) - - def on_run_end(self) -> _OUTPUTS_TYPE: - self.optimizer_loop._hiddens = None - # this is not necessary as the manual loop runs for only 1 iteration, but just in case - self.manual_loop._hiddens = None - output, self._outputs = self._outputs, [] # free memory - self._remaining_splits = [] - return output - - def teardown(self) -> None: - self.optimizer_loop.teardown() - self.manual_loop.teardown() - # release memory - if self.accumulated_loss.memory is not None: - self.accumulated_loss.memory = self.accumulated_loss.memory.cpu() - if self.running_loss.memory is not None: - self.running_loss.memory = self.running_loss.memory.cpu() - - def _tbptt_split_batch(self, batch: Any) -> List[Any]: - """Splits a single batch into a list of sequence steps for tbptt. - - Args: - batch: the current batch to split - """ - tbptt_steps = self.trainer.lightning_module.truncated_bptt_steps - if tbptt_steps == 0: - return [batch] - - splits = self.trainer._call_lightning_module_hook("tbptt_split_batch", batch, tbptt_steps) - return splits - - def _update_running_loss(self, current_loss: Tensor) -> None: - """Updates the running loss value with the current value.""" - if self.trainer.lightning_module.automatic_optimization: - # track total loss for logging (avoid mem leaks) - self.accumulated_loss.append(current_loss) - - accumulated_loss = self.accumulated_loss.mean() - - if accumulated_loss is not None: - # calculate running loss for display - self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) - - # reset for next set of accumulated grads - self.accumulated_loss.reset() diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index 5bdec1b552..980d6aa133 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -18,15 +18,17 @@ from typing import Any, DefaultDict, Dict, Generator, List, Optional, overload, import numpy as np import torch from lightning_utilities.core.apply_func import apply_to_collection +from torch import Tensor import pytorch_lightning as pl from pytorch_lightning import loops # import as loops to avoid circular imports -from pytorch_lightning.loops.batch import TrainingBatchLoop -from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE +from pytorch_lightning.loops.optimization import ManualOptimization, OptimizerLoop +from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE +from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress -from pytorch_lightning.trainer.supporters import CombinedLoader +from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher @@ -34,6 +36,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature +_BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] @@ -57,7 +60,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): self.batch_progress = BatchProgress() self.scheduler_progress = SchedulerProgress() - self.batch_loop = TrainingBatchLoop() + self.accumulated_loss = TensorRunningAccum(window_length=20) + self.running_loss = TensorRunningAccum(window_length=20) + self.optimizer_loop = OptimizerLoop() + self.manual_loop = ManualOptimization() + self.val_loop = loops.EvaluationLoop(verbose=False) self._results = _ResultCollection(training=True) @@ -85,8 +92,8 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): def global_step(self) -> int: lightning_module = self.trainer.lightning_module if lightning_module is None or lightning_module.automatic_optimization: - return self.batch_loop.optimizer_loop.optim_progress.optimizer_steps - return self.batch_loop.manual_loop.optim_step_progress.total.completed + return self.optimizer_loop.optim_progress.optimizer_steps + return self.manual_loop.optim_step_progress.total.completed @property def _is_training_done(self) -> bool: @@ -119,12 +126,15 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): def connect( # type: ignore[override] self, - batch_loop: Optional[TrainingBatchLoop] = None, + optimizer_loop: Optional[OptimizerLoop] = None, + manual_loop: Optional[ManualOptimization] = None, val_loop: Optional["loops.EvaluationLoop"] = None, ) -> None: """Optionally connect a custom batch or validation loop to this training epoch loop.""" - if batch_loop is not None: - self.batch_loop = batch_loop + if optimizer_loop is not None: + self.optimizer_loop = optimizer_loop + if manual_loop is not None: + self.manual_loop = manual_loop if val_loop is not None: self.val_loop = val_loop @@ -133,7 +143,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): if self.restarting: self.batch_progress.reset_on_restart() self.scheduler_progress.reset_on_restart() - self.batch_loop.optimizer_loop.optim_progress.reset_on_restart() + self.optimizer_loop.optim_progress.reset_on_restart() trainer = self.trainer if not trainer.state._fault_tolerant_mode.is_enabled and trainer.num_training_batches != float("inf"): @@ -148,7 +158,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): else: self.batch_progress.reset_on_run() self.scheduler_progress.reset_on_run() - self.batch_loop.optimizer_loop.optim_progress.reset_on_run() + self.optimizer_loop.optim_progress.reset_on_run() # when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches # seen per epoch, this is useful for tracking when validation is run multiple times per epoch self.val_loop.epoch_loop.batch_progress.total.reset() @@ -195,9 +205,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): self.trainer._logger_connector.on_batch_start(batch, batch_idx) + batch_output: _BATCH_OUTPUTS_TYPE = None # for mypy if batch is None: self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...") - batch_output = [] else: # hook self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx) @@ -210,7 +220,14 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): self.batch_progress.increment_started() with self.trainer.profiler.profile("run_training_batch"): - batch_output = self.batch_loop.run(kwargs) + # choose which loop will run the optimization + if self.trainer.lightning_module.automatic_optimization: + optimizers = _get_active_optimizers( + self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0) + ) + batch_output = self.optimizer_loop.run(optimizers, kwargs) + else: + batch_output = self.manual_loop.run(kwargs) self.batch_progress.increment_processed() @@ -232,7 +249,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): self.batch_progress.increment_completed() - if is_overridden("training_epoch_end", self.trainer.lightning_module): + if batch_output and is_overridden("training_epoch_end", self.trainer.lightning_module): + # batch_output may be empty + # automatic: can be empty if all optimizers skip their batches + # manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens, + # then `advance` doesn't finish and an empty dict is returned self._outputs.append(batch_output) # ----------------------------------------- @@ -254,7 +275,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): self.update_lr_schedulers("step", update_plateau_schedulers=True) if not self._should_accumulate(): - # this is increased once per batch disregarding multiple optimizers or tbptt on purpose for loggers + # this is increased once per batch disregarding multiple optimizers on purpose for loggers self._batches_that_stepped += 1 # this will save based on the `batches_that_stepped` value self._save_loggers_on_train_batch_end() @@ -271,7 +292,13 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): def teardown(self) -> None: self._results.cpu() - self.batch_loop.teardown() + self.optimizer_loop.teardown() + self.manual_loop.teardown() + # release memory + if self.accumulated_loss.memory is not None: + self.accumulated_loss.memory = self.accumulated_loss.memory.cpu() + if self.running_loss.memory is not None: + self.running_loss.memory = self.running_loss.memory.cpu() self.val_loop.teardown() def on_save_checkpoint(self) -> Dict: @@ -527,6 +554,21 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): kwargs["batch_idx"] = batch_idx return kwargs + def _update_running_loss(self, current_loss: Tensor) -> None: + """Updates the running loss value with the current value.""" + if self.trainer.lightning_module.automatic_optimization: + # track total loss for logging (avoid mem leaks) + self.accumulated_loss.append(current_loss) + + accumulated_loss = self.accumulated_loss.mean() + + if accumulated_loss is not None: + # calculate running loss for display + self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches) + + # reset for next set of accumulated grads + self.accumulated_loss.reset() + def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]: """Converts an optimizer dict to a list in which the key of the dict determines the position of the element. diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index 0c46566e84..c0ac03cc98 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -77,11 +77,6 @@ class FitLoop(Loop[None]): """Returns the current batch index (within this epoch)""" return self.epoch_loop.batch_idx - @property - def split_idx(self) -> int: - """Returns the index of the current batch split (within the current batch) for bptt.""" - return self.epoch_loop.batch_loop.split_idx - @property def min_steps(self) -> Optional[int]: # TODO(@justusschock): Why aren't we using the attribute in this class? @@ -112,7 +107,7 @@ class FitLoop(Loop[None]): @property def running_loss(self) -> TensorRunningAccum: """Returns the running loss.""" - return self.epoch_loop.batch_loop.running_loss + return self.epoch_loop.running_loss @Loop.restarting.setter def restarting(self, restarting: bool) -> None: @@ -131,12 +126,12 @@ class FitLoop(Loop[None]): @property def _skip_backward(self) -> bool: """Determines whether the loop will skip backward during automatic optimization.""" - return self.epoch_loop.batch_loop.optimizer_loop._skip_backward + return self.epoch_loop.optimizer_loop._skip_backward @_skip_backward.setter def _skip_backward(self, value: bool) -> None: """Determines whether the loop will skip backward during automatic optimization.""" - self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value + self.epoch_loop.optimizer_loop._skip_backward = value @property def _results(self) -> _ResultCollection: @@ -239,7 +234,7 @@ class FitLoop(Loop[None]): self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module) # stores accumulated grad fractions per batch - self.epoch_loop.batch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches) + self.epoch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches) self.epoch_progress.increment_ready() diff --git a/src/pytorch_lightning/loops/loop.py b/src/pytorch_lightning/loops/loop.py index 18ad1da939..386dd71d31 100644 --- a/src/pytorch_lightning/loops/loop.py +++ b/src/pytorch_lightning/loops/loop.py @@ -145,7 +145,8 @@ class Loop(ABC, Generic[T]): # connect sub-loops kwargs = {n: lp for n, lp in old_loop.__dict__.items() if isinstance(lp, Loop)} - loop.connect(**kwargs) + if kwargs: + loop.connect(**kwargs) # set the trainer reference loop.trainer = self.trainer diff --git a/src/pytorch_lightning/loops/optimization/manual_loop.py b/src/pytorch_lightning/loops/optimization/manual_loop.py index 28e700fcbb..a9e4a07e06 100644 --- a/src/pytorch_lightning/loops/optimization/manual_loop.py +++ b/src/pytorch_lightning/loops/optimization/manual_loop.py @@ -20,7 +20,7 @@ from torch import Tensor from pytorch_lightning.core.optimizer import do_nothing_closure from pytorch_lightning.loops import Loop from pytorch_lightning.loops.optimization.closure import OutputResult -from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens +from pytorch_lightning.loops.utilities import _build_training_step_kwargs from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -42,7 +42,7 @@ class ManualResult(OutputResult): def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ManualResult": extra = {} if isinstance(training_step_output, dict): - extra = {k: v for k, v in training_step_output.items() if k != "hiddens"} + extra = training_step_output.copy() elif isinstance(training_step_output, Tensor): extra = {"loss": training_step_output} elif training_step_output is not None: @@ -82,7 +82,6 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]): self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker) self._done: bool = False - self._hiddens: Optional[Any] = None self._output: _OUTPUTS_TYPE = {} @property @@ -104,7 +103,7 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]): Args: kwargs: The kwargs passed down to the hooks. """ - kwargs = self._build_kwargs(kwargs, self._hiddens) + kwargs = self._build_kwargs(kwargs) # manually capture logged metrics training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values()) @@ -114,12 +113,11 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]): model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output) strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output) training_step_output = strategy_output if model_output is None else model_output - self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps) result = self.output_result_cls.from_training_step_output(training_step_output) if self.trainer.move_metrics_to_cpu: - # hiddens and the training step output are not moved as they are not considered "metrics" + # training step output does not get moved because it is not considered a "metric" # the user might need them on the correct device for an operation in `training_epoch_end` assert self.trainer._results is not None self.trainer._results.cpu() @@ -144,16 +142,13 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]): self.trainer.profiler.stop("optimizer_step") self.optim_step_progress.increment_completed() - def _build_kwargs(self, kwargs: OrderedDict, hiddens: Optional[Any]) -> OrderedDict: + def _build_kwargs(self, kwargs: OrderedDict) -> OrderedDict: """Helper method to build the arguments for the current step. Args: kwargs: The kwargs passed down to the hooks. - hiddens: the hidden state of the previous RNN iteration. Returns: The kwargs passed down to the hooks. """ - return _build_training_step_kwargs( - kwargs, self.trainer.lightning_module, self.trainer.optimizers, None, hiddens - ) + return _build_training_step_kwargs(kwargs, self.trainer.lightning_module, self.trainer.optimizers, None) diff --git a/src/pytorch_lightning/loops/optimization/optimizer_loop.py b/src/pytorch_lightning/loops/optimization/optimizer_loop.py index 56637b7f68..0158f97f68 100644 --- a/src/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/src/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -24,11 +24,7 @@ from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loops import Loop from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult -from pytorch_lightning.loops.utilities import ( - _block_parallel_sync_behavior, - _build_training_step_kwargs, - _extract_hiddens, -) +from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior, _build_training_step_kwargs from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import WarningCache @@ -72,7 +68,7 @@ class ClosureResult(OutputResult): raise MisconfigurationException( "In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present" ) - extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")} + extra = {k: v for k, v in training_step_output.items() if k != "loss"} elif isinstance(training_step_output, Tensor): closure_loss = training_step_output elif training_step_output is not None: @@ -166,7 +162,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): self._skip_backward: bool = False self._optimizers: Tuple[Optimizer, ...] = tuple() self._indices: Tuple[int, ...] = tuple() - self._hiddens: Optional[Any] = None @property def optimizer_idx(self) -> int: @@ -194,7 +189,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): self.optim_progress.optimizer_position = 0 def advance(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None: - kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens) + kwargs = self._build_kwargs(kwargs, self.optimizer_idx) result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position]) if result.loss is not None: @@ -251,7 +246,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): # if no result, user decided to skip optimization # otherwise update running loss + reset accumulated loss # TODO: find proper way to handle updating running loss - self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss) + self.trainer.fit_loop.epoch_loop._update_running_loss(result.loss) # untoggle model params self._run_optimization_end(opt_idx) @@ -404,30 +399,25 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output) training_step_output = strategy_output if model_output is None else model_output - self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps) - result = self.output_result_cls.from_training_step_output( training_step_output, self.trainer.accumulate_grad_batches ) if self.trainer.move_metrics_to_cpu: - # hiddens and the training step output are not moved as they are not considered "metrics" + # training step output does not get moved because it is not considered a "metric" assert self.trainer._results is not None self.trainer._results.cpu() return result - def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int, hiddens: Optional[Any]) -> OrderedDict: + def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int) -> OrderedDict: """Helper method to build the arguments for the current step. Args: kwargs: The kwargs passed down to the hooks. opt_idx: the index of the current optimizer. - hiddens: the hidden state of the previous RNN iteration. Returns: The kwargs passed down to the hooks. """ - return _build_training_step_kwargs( - kwargs, self.trainer.lightning_module, self.trainer.optimizers, opt_idx, hiddens - ) + return _build_training_step_kwargs(kwargs, self.trainer.lightning_module, self.trainer.optimizers, opt_idx) diff --git a/src/pytorch_lightning/loops/utilities.py b/src/pytorch_lightning/loops/utilities.py index e6d41ab119..342cded638 100644 --- a/src/pytorch_lightning/loops/utilities.py +++ b/src/pytorch_lightning/loops/utilities.py @@ -14,7 +14,7 @@ from collections import OrderedDict from contextlib import contextmanager from functools import lru_cache -from typing import Any, Generator, List, Optional, Sequence, Tuple, Union +from typing import Generator, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -30,11 +30,8 @@ from pytorch_lightning.strategies.parallel import ParallelStrategy from pytorch_lightning.strategies.strategy import Strategy from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.trainer.supporters import CombinedLoader -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import STEP_OUTPUT def check_finite_loss(loss: Optional[Tensor]) -> None: @@ -47,28 +44,6 @@ def check_finite_loss(loss: Optional[Tensor]) -> None: raise ValueError(f"The loss returned in `training_step` is {loss}.") -def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: int) -> Optional[Any]: - """Get the hidden state if present from the training step output. - - Raises: - MisconfigurationException: If :attr:`~pytorch_lightning.core.Lightning.LightningModule.truncated_bptt_steps` is - not enabled and hiddens are returned or vice versa. - """ - if not truncated_bptt_steps: - if isinstance(training_step_output, dict) and "hiddens" in training_step_output: - raise MisconfigurationException( - 'You returned "hiddens" in your `training_step` but `truncated_bptt_steps` is disabled' - ) - return None - if not isinstance(training_step_output, dict) or "hiddens" not in training_step_output: - raise MisconfigurationException( - 'You enabled `truncated_bptt_steps` but did not `return {..., "hiddens": ...}` in your `training_step`' - ) - # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time` - hiddens = recursive_detach(training_step_output["hiddens"]) - return hiddens - - def _parse_loop_limits( min_steps: Optional[int], max_steps: int, @@ -116,7 +91,6 @@ def _build_training_step_kwargs( lightning_module: "pl.LightningModule", optimizers: Sequence[Optimizer], opt_idx: Optional[int], - hiddens: Optional[Any], ) -> OrderedDict: """Builds the keyword arguments for training_step. @@ -125,7 +99,6 @@ def _build_training_step_kwargs( lightning_module: the LightningModule with a `training_step` hook implementation optimizers: the list of optimizers from the Trainer opt_idx: the index of the current optimizer - hiddens: the hidden state of the previous RNN iteration Returns: the keyword arguments for the training step @@ -147,10 +120,6 @@ def _build_training_step_kwargs( " `training_step` is missing the `optimizer_idx` argument." ) - # pass hiddens if using tbptt - if lightning_module.truncated_bptt_steps > 0: - kwargs["hiddens"] = hiddens - return kwargs diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index ac2696ec8a..e3c053b045 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -165,9 +165,3 @@ def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule") " not match with the actual batch index when using a `dataloader_iter`" " argument in your `training_step`." ) - - if model.truncated_bptt_steps > 0: - raise MisconfigurationException( - "The model taking a `dataloader_iter` argument in your `training_step` " - "is incompatible with `truncated_bptt_steps > 0`." - ) diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 86c87db6e6..abd78867ec 100644 --- a/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -37,7 +37,6 @@ class LoggerConnector: self._epoch_end_reached = False self._current_fx: Optional[str] = None self._batch_idx: Optional[int] = None - self._split_idx: Optional[int] = None def on_trainer_init( self, @@ -144,9 +143,6 @@ class LoggerConnector: Train metric updates """ - def on_train_split_start(self, split_idx: int) -> None: - self._split_idx = split_idx - def update_train_step_metrics(self) -> None: if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization: return @@ -185,7 +181,6 @@ class LoggerConnector: def epoch_end_reached(self) -> None: self._epoch_end_reached = True self._batch_idx = None - self._split_idx = None def on_epoch_end(self) -> None: assert self._epoch_end_reached @@ -209,10 +204,7 @@ class LoggerConnector: def should_reset_tensors(self, fx: str) -> bool: is_different_fx = self._current_fx != fx - if self._split_idx is None: - is_first_batch = self._batch_idx in (None, 0) - else: - is_first_batch = bool(self._batch_idx) + self._split_idx == 0 + is_first_batch = self._batch_idx in (None, 0) return is_different_fx and is_first_batch def reset_metrics(self) -> None: @@ -226,7 +218,6 @@ class LoggerConnector: results.reset() self._batch_idx = None - self._split_idx = None self._current_fx = None @property diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 9fe3bb2346..dcaad0f750 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -1944,7 +1944,7 @@ class Trainer: def global_step(self) -> int: """The number of optimizer steps taken (does not reset each epoch). - This includes multiple optimizers and TBPTT steps (if enabled). + This includes multiple optimizers (if enabled). """ return self.fit_loop.epoch_loop.global_step diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index 82ea298a6b..902d6021ce 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -335,8 +335,8 @@ def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: def _reset_progress(trainer: "pl.Trainer") -> None: if trainer.lightning_module.automatic_optimization: - trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.reset() + trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.reset() else: - trainer.fit_loop.epoch_loop.batch_loop.manual_loop.optim_step_progress.reset() + trainer.fit_loop.epoch_loop.manual_loop.optim_step_progress.reset() trainer.fit_loop.epoch_progress.reset() diff --git a/src/pytorch_lightning/utilities/migration/migration.py b/src/pytorch_lightning/utilities/migration/migration.py index eceaa91cbe..e7f3952e05 100644 --- a/src/pytorch_lightning/utilities/migration/migration.py +++ b/src/pytorch_lightning/utilities/migration/migration.py @@ -46,7 +46,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking], "1.6.5": [_migrate_loop_batches_that_stepped], "1.9.0": [_migrate_model_checkpoint_save_on_train_epoch_end_default], - "2.0.0": [_drop_apex_amp_state], + "2.0.0": [_drop_apex_amp_state, _migrate_loop_structure_after_tbptt_removal], } @@ -219,3 +219,40 @@ def _drop_apex_amp_state(checkpoint: _CHECKPOINT) -> _CHECKPOINT: rank_zero_warn("This checkpoint contains apex AMP data, but apex support has been removed in v2.0.0.") del checkpoint[key] return checkpoint + + +def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Adjusts the loop structure since it changed when the support for truncated backpropagation was removed. The + optimizer loop and the manual loop were previously children of the training batch loop. After its removal, they + became the children of the training epoch loop. + + Version: 2.0.0 + Commit: TBD + PR: #16172 + """ + if "loops" not in checkpoint: + return checkpoint + + fit_loop = checkpoint["loops"]["fit_loop"] + + # remap `x.batch_loop.y` to `x.y` + old_key_new_key_mapping = { + "epoch_loop.batch_loop.manual_loop.optim_step_progress": "epoch_loop.manual_loop.optim_step_progress", + "epoch_loop.batch_loop.manual_loop.state_dict": "epoch_loop.manual_loop.state_dict", + "epoch_loop.batch_loop.optimizer_loop.optim_progress": "epoch_loop.optimizer_loop.optim_progress", + "epoch_loop.batch_loop.optimizer_loop.state_dict": "epoch_loop.optimizer_loop.state_dict", + } + for old, new in list(old_key_new_key_mapping.items()): + if old in fit_loop: + fit_loop[new] = fit_loop[old] + del fit_loop[old] + + # We can safely drop this key: our default implementation of `batch_loop` did not have state. + # If there was state from a custom batch loop, we wouldn't be able to load it meaningfully. + # But just in case, we save a copy of it in `epoch_loop.state_dict` in case the user wants to process it after + # loading the checkpoint. + if "epoch_loop.batch_loop.state_dict" in fit_loop and fit_loop["epoch_loop.batch_loop.state_dict"]: + fit_loop["epoch_loop.state_dict"]["old_batch_loop_state_dict"] = fit_loop["epoch_loop.batch_loop.state_dict"] + fit_loop.pop("epoch_loop.batch_loop.state_dict", None) + + return checkpoint diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index e869fa4a7f..5364426436 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -148,5 +148,6 @@ def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None: def _should_upgrade(checkpoint: _CHECKPOINT, target: str, max_version: Optional[str] = None) -> bool: """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" - is_lte_max_version = max_version is None or Version(target) <= Version(max_version) - return Version(_get_version(checkpoint)) < Version(target) and is_lte_max_version + target_version = Version(target) + is_lte_max_version = max_version is None or target_version <= Version(max_version) + return is_lte_max_version and Version(_get_version(checkpoint)) < target_version diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index bdd1c2002f..c0f93d0924 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -659,10 +659,8 @@ def test_get_progress_bar_metrics(tmpdir: str): ) model = BoringModel() trainer.fit(model) - model.truncated_bptt_steps = 2 standard_metrics = progress_bar.get_metrics(trainer, model) assert "loss" in standard_metrics.keys() - assert "split_idx" in standard_metrics.keys() assert "v_num" not in standard_metrics.keys() diff --git a/tests/tests_pytorch/loops/batch/__init__.py b/tests/tests_pytorch/loops/batch/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/tests_pytorch/loops/batch/test_truncated_bptt.py b/tests/tests_pytorch/loops/batch/test_truncated_bptt.py deleted file mode 100644 index a43d15909f..0000000000 --- a/tests/tests_pytorch/loops/batch/test_truncated_bptt.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import math - -import pytest -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader, TensorDataset - -from pytorch_lightning import LightningModule, Trainer - - -class LSTMModel(LightningModule): - """LSTM sequence-to-sequence model for testing TBPTT with automatic optimization.""" - - def __init__(self, truncated_bptt_steps=2, input_size=1, hidden_size=8): - super().__init__() - self.input_size = input_size - self.hidden_size = hidden_size - self.lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True) - self.truncated_bptt_steps = truncated_bptt_steps - self.automatic_optimization = True - - def configure_optimizers(self): - return torch.optim.SGD(self.parameters(), lr=0.01) - - def training_step(self, batch, batch_idx, hiddens): - x, y = batch - pred, hiddens = self.lstm(x, hiddens) - loss = F.mse_loss(pred, y) - return {"loss": loss, "hiddens": hiddens} - - def train_dataloader(self): - dataset = TensorDataset(torch.rand(16, 8, self.input_size), torch.rand(16, 8, self.input_size)) - return DataLoader(dataset=dataset, batch_size=4) - - -class ManualLSTMModel(LSTMModel): - """LSTM sequence-to-sequence model for testing TBPTT with manual optimization.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.automatic_optimization = False - - def training_step(self, batch, batch_idx, hiddens): - out = super().training_step(batch, batch_idx, hiddens) - loss, hiddens = out["loss"], out["hiddens"] - opt = self.optimizers() - opt.zero_grad() - self.manual_backward(loss) - opt.step() - return {"loss": loss, "hiddens": hiddens} - - -@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel)) -def test_persistent_hidden_state_transfer(tmpdir, model_class): - """Test that the hidden state reference gets passed through from one training_step to the next and remains - unmodified apart from detached grad_fn.""" - - class TBPTTModel(model_class): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.test_hidden = None - - def training_step(self, batch, batch_idx, hiddens): - split_idx = self.trainer.fit_loop.split_idx - # the hidden state may only be None for the first split_idx - assert not ((split_idx == 0) ^ (hiddens is None)) - # test_hiddens is None when hiddens is None - assert not ((hiddens is None) ^ (self.test_hidden is None)) - # the states are equal (persistent) - assert hiddens is None or all(torch.equal(h, th) for h, th in zip(hiddens, self.test_hidden)) - # the incoming hidden state never has a grad_fn (gets automatically detached) - assert hiddens is None or all(h.grad_fn is None for h in hiddens) - out = super().training_step(batch, batch_idx, hiddens) - - # store hiddens, assert persistence in next training_step - self.test_hidden = out["hiddens"] - - # hiddens may have grad_fn when returning, gets automatically detached - assert all(h.grad_fn is not None for h in self.test_hidden) - return out - - def on_train_batch_start(self, *_, **__) -> None: - self.test_hidden = None - - model = TBPTTModel(truncated_bptt_steps=2, input_size=1, hidden_size=8) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=2, - enable_model_summary=False, - logger=False, - enable_checkpointing=False, - ) - trainer.fit(model) - - -@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel)) -def test_tbptt_split_shapes(tmpdir, model_class): - """Test that the sequence data gets split correctly and that the outputs are correctly passed from hook to - hook.""" - batch_size = 10 - truncated_bptt_steps = 2 - n, t, f = 32, 15, 1 # (num samples, sequence size, input size) - assert t % truncated_bptt_steps != 0, "test must run with sequence length not divisible by tbptt steps" - - seq2seq_dataset = TensorDataset(torch.rand(n, t, f), torch.rand(n, t, f)) - train_dataloader = DataLoader(dataset=seq2seq_dataset, batch_size=batch_size) - - class TBPTTModel(model_class): - def training_step(self, batch, batch_idx, hiddens): - x, y = batch - if self.trainer.fit_loop.epoch_loop.batch_loop.done: - # last split idx, not aligned - assert x.shape[1] == t % truncated_bptt_steps - assert y.shape[1] == t % truncated_bptt_steps - else: - assert x.shape[1] == truncated_bptt_steps - assert y.shape[1] == truncated_bptt_steps - return super().training_step(batch, batch_idx, hiddens) - - def training_epoch_end(self, training_step_outputs): - training_step_outputs = training_step_outputs[0] - assert len(training_step_outputs) == math.ceil(t / self.truncated_bptt_steps) - assert all(out["loss"].grad_fn is None for out in training_step_outputs) - assert all("hiddens" not in out for out in training_step_outputs) - - model = TBPTTModel(truncated_bptt_steps=truncated_bptt_steps, input_size=f, hidden_size=8) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - enable_model_summary=False, - logger=False, - enable_checkpointing=False, - ) - trainer.fit(model, train_dataloaders=train_dataloader) - - assert trainer.fit_loop.batch_idx == n // batch_size - assert trainer.fit_loop.split_idx == t // truncated_bptt_steps - - -@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel)) -def test_tbptt_logging(tmpdir, model_class): - """Test step-level and epoch-level logging works with TBPTT.""" - - class TBPTTModel(model_class): - def training_step(self, *args, **kwargs): - out = super().training_step(*args, **kwargs) - self.log("loss", out["loss"], on_step=True, on_epoch=True) - return out - - model = TBPTTModel(truncated_bptt_steps=2) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=2, - log_every_n_steps=2, - enable_model_summary=False, - enable_checkpointing=False, - ) - trainer.fit(model) - assert set(trainer.logged_metrics) == {"loss_step", "loss_epoch"} - - -def test_hiddens_multiple_optimizers(tmpdir): - class TBPTTModel(LSTMModel): - # TODO: `optimizer_idx=n` gets the hiddens from `optimizer_idx=n-1` instead of the hidden from - # `optimizer_idx=n`, `split_idx=m-1`. This is unexpected and should be changed - test_hiddens = None - - def training_step(self, batch, batch_idx, optimizer_idx, hiddens): - if hiddens is None: - assert self.test_hiddens is None - else: - assert all(torch.equal(h, th) for h, th in zip(hiddens, self.test_hiddens)) - out = super().training_step(batch, batch_idx, hiddens) - self.test_hiddens = out["hiddens"] - return out - - def configure_optimizers(self): - return [super().configure_optimizers(), super().configure_optimizers()] - - model = TBPTTModel(truncated_bptt_steps=2, input_size=1, hidden_size=1) - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=1, - limit_val_batches=0, - enable_model_summary=False, - logger=False, - enable_checkpointing=False, - enable_progress_bar=False, - ) - trainer.fit(model) - assert trainer.global_step == 8 / 2 * 2 # time_dim_length / tbptt_steps * num_optimizers diff --git a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py index 3aaced7dc1..b6e9757c83 100644 --- a/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py @@ -32,73 +32,56 @@ _out13 = {"loss": 1.3} class TestPrepareOutputs: - def prepare_outputs(self, fn, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization): + def prepare_outputs(self, fn, batch_outputs, num_optimizers, automatic_optimization): lightning_module = LightningModule() lightning_module.automatic_optimization = automatic_optimization - lightning_module.truncated_bptt_steps = tbptt_splits return fn( batch_outputs, lightning_module=lightning_module, num_optimizers=num_optimizers, # does not matter for manual optimization ) - def prepare_outputs_training_epoch_end( - self, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=True - ): + def prepare_outputs_training_epoch_end(self, batch_outputs, num_optimizers, automatic_optimization=True): return self.prepare_outputs( TrainingEpochLoop._prepare_outputs_training_epoch_end, - tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=automatic_optimization, ) - def prepare_outputs_training_batch_end( - self, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=True - ): + def prepare_outputs_training_batch_end(self, batch_outputs, num_optimizers, automatic_optimization=True): return self.prepare_outputs( TrainingEpochLoop._prepare_outputs_training_batch_end, - tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=automatic_optimization, ) @pytest.mark.parametrize( - "num_optimizers,tbptt_splits,batch_outputs,expected", + "num_optimizers,batch_outputs,expected", [ - (1, 0, [], []), - (1, 0, [[]], []), + (1, [], []), + (1, [[]], []), # 1 batch - (1, 0, [[{0: _out00}]], [_out00]), + (1, [[{0: _out00}]], [_out00]), # 2 batches - (1, 0, [[{0: _out00}], [{0: _out01}]], [_out00, _out01]), + (1, [[{0: _out00}], [{0: _out01}]], [_out00, _out01]), # 1 batch, 2 optimizers - (2, 0, [[{0: _out00, 1: _out01}]], [_out00, _out01]), + (2, [[{0: _out00, 1: _out01}]], [_out00, _out01]), # 2 batches, 2 optimizers - (2, 0, [[{0: _out00, 1: _out01}], [{0: _out10, 1: _out11}]], [[_out00, _out01], [_out10, _out11]]), + (2, [[{0: _out00, 1: _out01}], [{0: _out10, 1: _out11}]], [[_out00, _out01], [_out10, _out11]]), # 4 batches, 2 optimizers, different frequency ( 2, - 0, [[{0: _out00}], [{1: _out10}], [{1: _out11}], [{0: _out01}]], [[_out00], [_out10], [_out11], [_out01]], ), - # 1 batch, tbptt with 2 splits (uneven) - (1, 2, [[{0: _out00}, {0: _out01}], [{0: _out03}]], [[_out00, _out01], [_out03]]), - # 3 batches, tbptt with 2 splits, 2 optimizers alternating - ( - 2, - 2, - [[{0: _out00}, {0: _out01}], [{1: _out10}, {1: _out11}], [{0: _out02}, {0: _out03}]], - [[[_out00], [_out01]], [[_out10], [_out11]], [[_out02], [_out03]]], - ), ], ) - def test_prepare_outputs_training_epoch_end_automatic(self, num_optimizers, tbptt_splits, batch_outputs, expected): + def test_prepare_outputs_training_epoch_end_automatic(self, num_optimizers, batch_outputs, expected): """Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook currently expects in the case of automatic optimization.""" - assert self.prepare_outputs_training_epoch_end(tbptt_splits, batch_outputs, num_optimizers) == expected + assert self.prepare_outputs_training_epoch_end(batch_outputs, num_optimizers) == expected @pytest.mark.parametrize( "batch_outputs,expected", @@ -111,37 +94,29 @@ class TestPrepareOutputs: ([[_out00], [_out01]], [_out00, _out01]), # skipped outputs ([[_out00], [], [], [_out03]], [_out00, _out03]), - # tbptt with 2 splits, uneven, skipped output - ([[_out00, _out01], [_out02, _out03], [], [_out10]], [[_out00, _out01], [_out02, _out03], [_out10]]), ], ) def test_prepare_outputs_training_epoch_end_manual(self, batch_outputs, expected): """Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook currently expects in the case of manual optimization.""" - assert self.prepare_outputs_training_epoch_end(0, batch_outputs, -1, automatic_optimization=False) == expected + assert self.prepare_outputs_training_epoch_end(batch_outputs, -1, automatic_optimization=False) == expected @pytest.mark.parametrize( - "num_optimizers,tbptt_splits,batch_end_outputs,expected", + "num_optimizers,batch_end_outputs,expected", [ - (1, 0, [], []), - (1, 0, [[]], []), + (1, [], []), + (1, [[]], []), # 1 optimizer - (1, 0, [{0: _out00}], _out00), + (1, [{0: _out00}], _out00), # 2 optimizers - (2, 0, [{0: _out00, 1: _out01}], [_out00, _out01]), - # tbptt with 2 splits - (1, 2, [{0: _out00}, {0: _out01}], [_out00, _out01]), - # 2 optimizers, tbptt with 2 splits - (2, 2, [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}], [[_out00, _out01], [_out10, _out11]]), + (2, [{0: _out00, 1: _out01}], [_out00, _out01]), ], ) - def test_prepare_outputs_training_batch_end_automatic( - self, num_optimizers, tbptt_splits, batch_end_outputs, expected - ): + def test_prepare_outputs_training_batch_end_automatic(self, num_optimizers, batch_end_outputs, expected): """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook currently expects in the case of automatic optimization.""" - assert self.prepare_outputs_training_batch_end(tbptt_splits, batch_end_outputs, num_optimizers) == expected + assert self.prepare_outputs_training_batch_end(batch_end_outputs, num_optimizers) == expected @pytest.mark.parametrize( "batch_end_outputs,expected", @@ -150,16 +125,12 @@ class TestPrepareOutputs: ([[]], []), # skipped outputs ([_out00, None, _out02], [_out00, _out02]), - # tbptt with 3 splits, skipped output - ([_out00, _out01, None, _out03], [_out00, _out01, _out03]), ], ) def test_prepare_outputs_training_batch_end_manual(self, batch_end_outputs, expected): """Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook currently expects in the case of manual optimization.""" - assert ( - self.prepare_outputs_training_batch_end(0, batch_end_outputs, -1, automatic_optimization=False) == expected - ) + assert self.prepare_outputs_training_batch_end(batch_end_outputs, -1, automatic_optimization=False) == expected def test_no_val_on_train_epoch_loop_restart(tmpdir): @@ -208,7 +179,7 @@ def test_should_stop_early_stopping_conditions_not_met( trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0) trainer.num_training_batches = 10 trainer.should_stop = True - trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step + trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1 diff --git a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py index fb3c73fd56..d244d6e08a 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py @@ -65,15 +65,15 @@ def test__eval_step__flow(tmpdir): # simulate training manually trainer.state.stage = RunningStage.TRAINING kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs) + train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs) assert len(train_step_out) == 1 - train_step_out = train_step_out[0][0] + train_step_out = train_step_out[0] assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) + opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) opt_closure_result = opt_closure() assert opt_closure_result.item() == 171 @@ -126,15 +126,15 @@ def test__eval_step__eval_step_end__flow(tmpdir): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs) + train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs) assert len(train_step_out) == 1 - train_step_out = train_step_out[0][0] + train_step_out = train_step_out[0] assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) + opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) opt_closure_result = opt_closure() assert opt_closure_result.item() == 171 diff --git a/tests/tests_pytorch/loops/test_loop_state_dict.py b/tests/tests_pytorch/loops/test_loop_state_dict.py index 0c5c637eea..72d846e535 100644 --- a/tests/tests_pytorch/loops/test_loop_state_dict.py +++ b/tests/tests_pytorch/loops/test_loop_state_dict.py @@ -52,14 +52,13 @@ def test_loops_state_dict_structure(): "total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}, }, - "epoch_loop.batch_loop.state_dict": {}, - "epoch_loop.batch_loop.manual_loop.state_dict": {}, - "epoch_loop.batch_loop.manual_loop.optim_step_progress": { + "epoch_loop.manual_loop.state_dict": {}, + "epoch_loop.manual_loop.optim_step_progress": { "total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}, }, - "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, - "epoch_loop.batch_loop.optimizer_loop.optim_progress": { + "epoch_loop.optimizer_loop.state_dict": {}, + "epoch_loop.optimizer_loop.optim_progress": { "optimizer": { "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, "zero_grad": { diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index ce9b5ab704..b7084a0c72 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -25,7 +25,7 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoad from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset -from pytorch_lightning.loops import EvaluationLoop, Loop, TrainingBatchLoop, TrainingEpochLoop +from pytorch_lightning.loops import EvaluationLoop, Loop, OptimizerLoop, TrainingEpochLoop from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.runif import RunIf @@ -109,15 +109,15 @@ def test_connect_subloops(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) epoch_loop = trainer.fit_loop.epoch_loop - new_batch_loop = TrainingBatchLoop() - epoch_loop.connect(batch_loop=new_batch_loop) - assert epoch_loop.batch_loop is new_batch_loop + new_optimizer_loop = OptimizerLoop() + epoch_loop.connect(optimizer_loop=new_optimizer_loop) + assert epoch_loop.optimizer_loop is new_optimizer_loop with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"): - _ = new_batch_loop.trainer + _ = new_optimizer_loop.trainer trainer.fit(model) - assert new_batch_loop.trainer is trainer + assert new_optimizer_loop.trainer is trainer def test_replace_loops(): @@ -144,22 +144,22 @@ def test_replace_loops(): assert trainer.fit_loop.epoch_loop is new_loop assert new_loop.min_steps == 123 assert new_loop.max_steps == 321 - assert new_loop.batch_loop is old_loop.batch_loop + assert new_loop.optimizer_loop is old_loop.optimizer_loop assert new_loop.val_loop is old_loop.val_loop assert new_loop.trainer is trainer - class MyBatchLoop(TrainingBatchLoop): + class MyOptimizerLoop(OptimizerLoop): ... class MyEvalLoop(EvaluationLoop): ... # test passing more than one where one is an instance and the other a class - trainer.fit_loop.epoch_loop.replace(batch_loop=MyBatchLoop, val_loop=MyEvalLoop()) - new_batch_loop = trainer.fit_loop.epoch_loop.batch_loop + trainer.fit_loop.epoch_loop.replace(optimizer_loop=MyOptimizerLoop, val_loop=MyEvalLoop()) + new_optimizer_loop = trainer.fit_loop.epoch_loop.optimizer_loop new_val_loop = trainer.fit_loop.epoch_loop.val_loop - assert isinstance(new_batch_loop, MyBatchLoop) + assert isinstance(new_optimizer_loop, MyOptimizerLoop) assert isinstance(new_val_loop, MyEvalLoop) @@ -436,7 +436,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch assert os.path.exists(ckpt_path) checkpoint = torch.load(ckpt_path) - optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress + optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch @@ -510,14 +510,13 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch "total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps}, "current": {"ready": be_sch_steps, "completed": be_sch_steps}, }, - "epoch_loop.batch_loop.state_dict": ANY, - "epoch_loop.batch_loop.manual_loop.state_dict": ANY, - "epoch_loop.batch_loop.manual_loop.optim_step_progress": { + "epoch_loop.manual_loop.state_dict": ANY, + "epoch_loop.manual_loop.optim_step_progress": { "total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}, }, - "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, - "epoch_loop.batch_loop.optimizer_loop.optim_progress": { + "epoch_loop.optimizer_loop.state_dict": {}, + "epoch_loop.optimizer_loop.optim_progress": { "optimizer_position": stop_optimizer, "optimizer": { "step": { @@ -563,8 +562,8 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch # test resetting manually, we expect all `ready` counters to be reset to `completed` trainer.fit_loop.reset() trainer.fit_loop.epoch_loop.reset() - trainer.fit_loop.epoch_loop.batch_loop.reset() - trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.reset() + trainer.fit_loop.epoch_loop.optimizer_loop.reset() + trainer.fit_loop.epoch_loop.manual_loop.reset() epoch_progress = trainer.fit_loop.epoch_progress assert epoch_progress.current.ready == stop_epoch @@ -574,7 +573,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch assert batch_progress.current.ready == be_batches_completed assert batch_progress.current.completed == be_batches_completed - optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress + optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress assert optim_progress.optimizer.step.current.ready == be_total_opt_steps assert optim_progress.optimizer.step.current.completed == be_total_opt_steps assert optim_progress.optimizer.zero_grad.current.ready == be_total_zero_grad @@ -677,14 +676,13 @@ def test_loop_state_on_complete_run(n_optimizers, tmpdir): "total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total}, "current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current}, }, - "epoch_loop.batch_loop.state_dict": ANY, - "epoch_loop.batch_loop.manual_loop.state_dict": ANY, - "epoch_loop.batch_loop.manual_loop.optim_step_progress": { + "epoch_loop.manual_loop.state_dict": ANY, + "epoch_loop.manual_loop.optim_step_progress": { "total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}, }, - "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, - "epoch_loop.batch_loop.optimizer_loop.optim_progress": { + "epoch_loop.optimizer_loop.state_dict": {}, + "epoch_loop.optimizer_loop.optim_progress": { "optimizer_position": n_optimizers, "optimizer": { "step": { @@ -746,7 +744,7 @@ def test_fit_loop_reset(tmpdir): mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=2.ckpt")) fit_loop = trainer.fit_loop epoch_loop = fit_loop.epoch_loop - optimizer_loop = epoch_loop.batch_loop.optimizer_loop + optimizer_loop = epoch_loop.optimizer_loop assert not fit_loop.restarting assert not epoch_loop.restarting assert not optimizer_loop.restarting diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index 95b60f99b7..8cc06315e3 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -217,7 +217,7 @@ def test_should_stop_early_stopping_conditions_met( trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100) trainer.num_training_batches = 10 trainer.should_stop = True - trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = ( + trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = ( current_epoch * trainer.num_training_batches ) trainer.fit_loop.epoch_loop.batch_progress.current.ready = 10 diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py index 087b9f953d..d9dd5fc341 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py @@ -147,15 +147,15 @@ def test__training_step__epoch_end__flow_scalar(tmpdir): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs) + train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs) assert len(train_step_out) == 1 - train_step_out = train_step_out[0][0] + train_step_out = train_step_out[0] assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) + opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) opt_closure_result = opt_closure() assert opt_closure_result.item() == 171 @@ -217,15 +217,15 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir): trainer.state.stage = RunningStage.TRAINING # make sure training outputs what is expected kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0} - train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs) + train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs) assert len(train_step_out) == 1 - train_step_out = train_step_out[0][0] + train_step_out = train_step_out[0] assert isinstance(train_step_out["loss"], Tensor) assert train_step_out["loss"].item() == 171 # make sure the optimizer closure returns the correct things - opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) + opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0]) opt_closure_result = opt_closure() assert opt_closure_result.item() == 171 @@ -301,9 +301,10 @@ def test_training_step_no_return_when_even(tmpdir): # manually check a few batches for batch_idx, batch in enumerate(model.train_dataloader()): - out = trainer.fit_loop.epoch_loop.batch_loop.run({"batch": batch, "batch_idx": batch_idx}) + kwargs = {"batch": batch, "batch_idx": batch_idx} + out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs) if not batch_idx % 2: - assert out == [] + assert out == {} def test_training_step_none_batches(tmpdir): diff --git a/tests/tests_pytorch/loops/test_utilities.py b/tests/tests_pytorch/loops/test_utilities.py index 2bd86d3258..21c00ee6e0 100644 --- a/tests/tests_pytorch/loops/test_utilities.py +++ b/tests/tests_pytorch/loops/test_utilities.py @@ -13,32 +13,7 @@ # limitations under the License. from unittest.mock import Mock -import pytest -import torch - -from pytorch_lightning.loops.utilities import _extract_hiddens, _set_sampler_epoch -from pytorch_lightning.utilities.exceptions import MisconfigurationException - - -def test_extract_hiddens(): - # tbptt not enabled, no hiddens return - training_step_output = 1 # anything - hiddens = _extract_hiddens(training_step_output, 0) - assert hiddens is None - - # tbptt enabled, hiddens return - hiddens = torch.tensor(321.12, requires_grad=True) - training_step_output = {"hiddens": hiddens} - hiddens = _extract_hiddens(training_step_output, 2) - assert "hiddens" in training_step_output - assert not hiddens.requires_grad - - # tbptt not enabled, hiddens return - with pytest.raises(MisconfigurationException, match='returned "hiddens" .* but `truncated_bptt_steps` is disabled'): - _extract_hiddens(training_step_output, 0) - # tbptt enabled, no hiddens return - with pytest.raises(MisconfigurationException, match="enabled `truncated_bptt_steps` but did not `return"): - _extract_hiddens(None, 1) +from pytorch_lightning.loops.utilities import _set_sampler_epoch def test_set_sampler_epoch(): diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 598be1a9a2..ad61d0bbe6 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -353,7 +353,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files) # emulate callback's calls during the training for i, loss in enumerate(losses, 1): # sets `trainer.global_step` - trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i + trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)}) checkpoint_callback.on_validation_end(trainer, trainer.lightning_module) trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch` diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py index f77c4ea3d2..11c804be91 100644 --- a/tests/tests_pytorch/utilities/migration/test_migration.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import MagicMock + import pytest import torch @@ -156,3 +158,37 @@ def test_migrate_dropped_apex_amp_state(monkeypatch): with pytest.warns(UserWarning, match="checkpoint contains apex AMP data"): updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy()) assert "amp_scaling_state" not in updated_checkpoint + + +def test_migrate_loop_structure_after_tbptt_removal(): + """Test the loop state migration after truncated backpropagation support was removed in 2.0.0, and with it the + training batch loop.""" + # automatic- and manual optimization state are combined into a single checkpoint to simplify testing + state_automatic = MagicMock() + state_manual = MagicMock() + optim_progress_automatic = MagicMock() + optim_progress_manual = MagicMock() + old_batch_loop_state = MagicMock() + old_checkpoint = { + "loops": { + "fit_loop": { + "epoch_loop.state_dict": {"any": "state"}, + "epoch_loop.batch_loop.state_dict": old_batch_loop_state, + "epoch_loop.batch_loop.optimizer_loop.state_dict": state_automatic, + "epoch_loop.batch_loop.optimizer_loop.optim_progress": optim_progress_automatic, + "epoch_loop.batch_loop.manual_loop.state_dict": state_manual, + "epoch_loop.batch_loop.manual_loop.optim_step_progress": optim_progress_manual, + } + } + } + _set_version(old_checkpoint, "1.8.0") # pretend a checkpoint prior to 2.0.0 + updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0") + assert updated_checkpoint["loops"] == { + "fit_loop": { + "epoch_loop.state_dict": {"any": "state", "old_batch_loop_state_dict": old_batch_loop_state}, + "epoch_loop.optimizer_loop.state_dict": state_automatic, + "epoch_loop.optimizer_loop.optim_progress": optim_progress_automatic, + "epoch_loop.manual_loop.state_dict": state_manual, + "epoch_loop.manual_loop.optim_step_progress": optim_progress_manual, + } + } diff --git a/tests/tests_pytorch/utilities/test_fetching.py b/tests/tests_pytorch/utilities/test_fetching.py index e5ce8ca2d5..7bebd02a5e 100644 --- a/tests/tests_pytorch/utilities/test_fetching.py +++ b/tests/tests_pytorch/utilities/test_fetching.py @@ -447,21 +447,6 @@ def test_on_train_batch_end_overridden(tmpdir) -> None: trainer.fit(m) -def test_tbptt_split_batch_overridden(tmpdir) -> None: - """Verify that a `MisconfigurationException` is raised when `tbptt_split_batch` is overridden on the - `LightningModule`.""" - - class InvalidModel(AsyncBoringModel): - def __init__(self) -> None: - super().__init__() - self.truncated_bptt_steps = 2 - - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) - m = InvalidModel() - with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."): - trainer.fit(m) - - def test_transfer_hooks_with_unpacking(tmpdir): """This test asserts the `transfer_batch` hooks are called only once per batch."""