Remove truncated backpropagation from loops (#16337)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Adrian Wälchli 2023-01-19 18:10:41 +01:00 committed by Luca Antiga
parent 92a922ccd4
commit 03a699693b
38 changed files with 248 additions and 844 deletions

View File

@ -126,7 +126,6 @@ Training
:nosignatures:
:template: classtemplate.rst
~batch.TrainingBatchLoop
~epoch.TrainingEpochLoop
FitLoop
~optimization.ManualOptimization

View File

@ -1035,7 +1035,7 @@ global_step
~~~~~~~~~~~
The number of optimizer steps taken (does not reset each epoch).
This includes multiple optimizers and TBPTT steps (if enabled).
This includes multiple optimizers (if enabled).
.. code-block:: python
@ -1195,79 +1195,6 @@ Set and access example_input_array, which basically represents a single batch.
# generate some images using the example_input_array
gen_images = self.generator(self.example_input_array)
truncated_bptt_steps
~~~~~~~~~~~~~~~~~~~~
Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
a much longer sequence. This is made possible by splitting training batches
along the time dimension into splits of size k and passing them to the
``training_step``. In order to keep the same forward propagation behavior, all
hidden states should be carried over between the time-dimension splits.
If this is enabled, your batches will automatically get truncated
and the Trainer will apply truncated backpropagation to them.
(`Williams et al. "An efficient gradient-based algorithm for on-line training of
recurrent network trajectories."
<https://ieeexplore.ieee.org/document/6797135>`_)
`Tutorial <https://d2l.ai/chapter_recurrent-neural-networks/bptt.html>`_
.. testcode:: python
from pytorch_lightning import LightningModule
class MyModel(LightningModule):
def __init__(self, input_size, hidden_size, num_layers):
super().__init__()
# batch_first has to be set to True
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True,
)
...
# Important: This property activates truncated backpropagation through time
# Setting this value to 2 splits the batch into sequences of size 2
self.truncated_bptt_steps = 2
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
x, y = batch
# the training step must be updated to accept a ``hiddens`` argument
# hiddens are the hiddens from the previous truncated backprop step
out, hiddens = self.lstm(x, hiddens)
...
return {"loss": ..., "hiddens": hiddens}
Lightning takes care of splitting your batch along the time-dimension. It is
assumed to be the second dimension of your batches. Therefore, in the
example above, we have set ``batch_first=True``.
.. code-block:: python
# we use the second dimension as the time dimension
# (batch, time, ...)
sub_batch = batch[0, 0:t, ...]
To modify how the batch is split,
override the :meth:`pytorch_lightning.core.module.LightningModule.tbptt_split_batch` method:
.. testcode:: python
class LitMNIST(LightningModule):
def tbptt_split_batch(self, batch, split_size):
# do your own splitting on the batch
return splits
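For illustration, a minimal sketch of such an override (class name hypothetical), assuming the batch is an ``(x, y)`` tuple of tensors with the time dimension at ``dim=1``, which mirrors the default behavior:
.. code-block:: python
    class LitSeqModel(LightningModule):
        def tbptt_split_batch(self, batch, split_size):
            # assumes `batch` is a tuple of tensors shaped (batch, time, ...)
            x, y = batch
            splits = []
            for t in range(0, x.size(1), split_size):
                # each split keeps the full batch dimension and takes `split_size` time steps
                splits.append((x[:, t : t + split_size], y[:, t : t + split_size]))
            return splits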
--------------
.. _lightning_hooks:
@ -1636,12 +1563,6 @@ setup
.. automethod:: pytorch_lightning.core.module.LightningModule.setup
:noindex:
tbptt_split_batch
~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.core.module.LightningModule.tbptt_split_batch
:noindex:
teardown
~~~~~~~~

View File

@ -1418,7 +1418,7 @@ global_step
***********
The number of optimizer steps taken (does not reset each epoch).
This includes multiple optimizers and TBPTT steps (if enabled).
This includes multiple optimizers (if enabled).
.. code-block:: python

View File

@ -306,14 +306,11 @@ Here is what the structure would look like in plain Python:
# TrainingEpochLoop
for batch_idx, batch in enumerate(train_dataloader):
# TrainingBatchLoop
for split_batch in tbptt_split(batch):
# OptimizerLoop
for optimizer_idx, opt in enumerate(optimizers):
# OptimizerLoop
for optimizer_idx, opt in enumerate(optimizers):
loss = lightning_module.training_step(batch, batch_idx, optimizer_idx)
...
loss = lightning_module.training_step(batch, batch_idx, optimizer_idx)
...
# ValidationEpochLoop
for batch_idx, batch in enumerate(val_dataloader):
@ -339,13 +336,8 @@ Each of these :code:`for`-loops represents a class implementing the :class:`~pyt
The validation is carried out by yet another loop, :class:`~pytorch_lightning.loops.epoch.validation_epoch_loop.ValidationEpochLoop`.
In the :code:`run()` method, the training epoch loop could in theory simply call :code:`LightningModule.training_step` directly and perform the optimization.
However, Lightning has built-in support for automatic optimization with multiple optimizers and on top of that also supports :ref:`TBPTT <sequential-data>`.
However, Lightning has built-in support for automatic optimization with multiple optimizers.
For this reason there are actually two more loops nested under :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`.
* - :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop`
- The responsibility of the :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` is to split a batch given by the :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` along the time-dimension and iterate over the list of splits.
It also keeps track of the hidden state *hiddens* returned by the training step.
By default, when truncated back-propagation through time (TBPTT) is turned off, this loop does not do anything except redirect the call to the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop`.
Read more about :ref:`TBPTT <sequential-data>`.
* - :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop`
- The :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` iterates over one or multiple optimizers and for each one it calls the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` method with the batch, the current batch index and the optimizer index if multiple optimizers are requested.
It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step); a rough sketch follows below.
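A rough plain-Python sketch of what this leaf loop does per optimizer under automatic optimization (simplified; the names are assumed for illustration):
.. code-block:: python
    for optimizer_idx, optimizer in enumerate(optimizers):
        # forward pass and loss computation happen inside ``training_step``
        loss = lightning_module.training_step(batch, batch_idx, optimizer_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()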

View File

@ -343,38 +343,6 @@ When using :class:`~torch.nn.utils.rnn.PackedSequence`, do two things:
x = rnn.pack_sequence(batch[0], enforce_sorted=False)
y = rnn.pack_sequence(batch[1], enforce_sorted=False)
Truncated Backpropagation Through Time (TBPTT)
==============================================
There are times when multiple backward passes are needed for each batch.
For example, it may save memory to use **Truncated Backpropagation Through Time** when training RNNs.
Lightning can handle TBPTT automatically via this flag.
.. testcode::
from pytorch_lightning import LightningModule
class MyModel(LightningModule):
def __init__(self):
super().__init__()
# Important: This property activates truncated backpropagation through time
# Setting this value to 2 splits the batch into sequences of size 2
self.truncated_bptt_steps = 2
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
# the training step must be updated to accept a ``hiddens`` argument
# hiddens are the hiddens from the previous truncated backprop step
out, hiddens = self.lstm(data, hiddens)
return {"loss": ..., "hiddens": hiddens}
.. note:: If you need to modify how the batch is split,
override :func:`~pytorch_lightning.core.module.LightningModule.tbptt_split_batch`.
Iterable Datasets
=================
Lightning supports using :class:`~torch.utils.data.IterableDataset` as well as map-style Datasets. IterableDatasets provide a more natural

View File

@ -74,7 +74,7 @@ class YieldLoop(OptimizerLoop):
return partial(self._training_step, self._generator)
def _get_generator(self, kwargs, opt_idx=0):
kwargs = self._build_kwargs(kwargs, opt_idx, hiddens=None)
kwargs = self._build_kwargs(kwargs, opt_idx)
# Here we are basically calling `lightning_module.training_step()`
# and this returns a generator! The `training_step` is handled by
@ -285,8 +285,8 @@ class GAN(LightningModule):
#############################################################################################
# Step 3 / 3: Connect the loop to the Trainer #
# #
# Finally, attach the loop to the `Trainer`. Here, we modified the `AutomaticOptimization` #
# loop which is a subloop of the `TrainingBatchLoop`. We use `.connect()` to attach it. #
# Finally, attach the loop to the `Trainer`. Here, we modified the `OptimizerLoop` #
# loop which is a subloop of the `TrainingEpochLoop`. We use `.connect()` to attach it. #
#############################################################################################
if __name__ == "__main__":
@ -296,7 +296,7 @@ if __name__ == "__main__":
# Connect the new loop
# YieldLoop now replaces the previous optimizer loop
trainer.fit_loop.epoch_loop.batch_loop.connect(optimizer_loop=YieldLoop())
trainer.fit_loop.epoch_loop.connect(optimizer_loop=YieldLoop())
# fit() will now use the new loop!
trainer.fit(model, dm)

View File

@ -106,7 +106,6 @@ class LightningModuleVisitor(LightningVisitor):
"optimizer_zero_grad",
"prepare_data",
"setup",
"tbptt_split_batch",
"teardown",
"train_dataloader",
"val_dataloader",
@ -256,10 +255,6 @@ class TorchMetricVisitor(LightningVisitor):
class_name = "Metric"
class LightningLiteVisitor(LightningVisitor): # deprecated
class_name = "LightningLite"
class FabricVisitor(LightningVisitor):
class_name = "Fabric"
@ -297,7 +292,6 @@ class Scanner:
LightningLoggerVisitor,
LightningLoopVisitor,
TorchMetricVisitor,
LightningLiteVisitor, # deprecated
FabricVisitor,
LightningProfilerVisitor,
]

View File

@ -56,6 +56,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed the `Trainer(auto_select_gpus=...)` argument
* Removed the `pytorch_lightning.tuner.auto_gpu_select.{pick_single_gpu,pick_multiple_gpus}` functions
- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
* Removed the `LightningModule.truncated_bptt_steps` attribute
* Removed the `LightningModule.tbptt_split_batch` hook
* The `LightningModule.training_step` no longer accepts a `hiddens` argument (see the sketch below)
* Removed the `pytorch_lightning.loops.batch.TrainingBatchLoop`
* Removed the `FitLoop.split_idx` property
* Removed the `LoggerConnector.on_train_split_start` method
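A minimal before/after sketch of the `training_step` signature change (hypothetical model code, not from this PR); users who still rely on TBPTT now need to split their sequences and carry the hidden state themselves, for example with manual optimization:

```python
# Before (with truncated_bptt_steps > 0): Lightning passed the previous hidden state in
def training_step(self, batch, batch_idx, hiddens):
    x, y = batch
    out, hiddens = self.lstm(x, hiddens)
    return {"loss": ..., "hiddens": hiddens}


# After this removal: no `hiddens` argument and no "hiddens" key in the returned dict
def training_step(self, batch, batch_idx):
    x, y = batch
    out, _ = self.lstm(x)
    return {"loss": ...}
```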
### Fixed

View File

@ -279,9 +279,6 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
if avg_training_loss is not None:
items_dict["loss"] = f"{avg_training_loss:.3g}"
if pl_module.truncated_bptt_steps > 0:
items_dict["split_idx"] = trainer.fit_loop.split_idx
if trainer.loggers:
version = _version(trainer.loggers)
if version is not None:

View File

@ -13,7 +13,6 @@
# limitations under the License.
"""The LightningModule - an nn.Module with many additional features."""
import collections.abc
import logging
import numbers
import weakref
@ -89,7 +88,6 @@ class LightningModule(
"logger",
"loggers",
"automatic_optimization",
"truncated_bptt_steps",
"trainer",
"fabric",
]
@ -115,7 +113,6 @@ class LightningModule(
self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None
self._current_fx_name: Optional[str] = None
self._automatic_optimization: bool = True
self._truncated_bptt_steps: int = 0
self._param_requires_grad_state: Dict[str, bool] = {}
self._metric_attributes: Optional[Dict[int, str]] = None
self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
@ -275,20 +272,6 @@ class LightningModule(
def automatic_optimization(self, automatic_optimization: bool) -> None:
self._automatic_optimization = automatic_optimization
@property
def truncated_bptt_steps(self) -> int:
"""Enables `Truncated Backpropagation Through Time` in the Trainer when set to a positive integer.
It represents
the number of times :meth:`training_step` gets called before backpropagation. If this is > 0, the
:meth:`training_step` receives an additional argument ``hiddens`` and is expected to return a hidden state.
"""
return self._truncated_bptt_steps
@truncated_bptt_steps.setter
def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
self._truncated_bptt_steps = truncated_bptt_steps
@property
def logger(self) -> Optional[Union[Logger, FabricLogger]]:
"""Reference to the logger object in the Trainer."""
@ -683,8 +666,6 @@ class LightningModule(
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (``int``): Integer displaying index of this batch
optimizer_idx (``int``): When using multiple optimizers, this argument will also be present.
hiddens (``Any``): Passed in if
:paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0.
Return:
Any of.
@ -719,19 +700,6 @@ class LightningModule(
# do training_step with decoder
...
If you add truncated back propagation through time you will also get an additional
argument with the hidden states of the previous step.
.. code-block:: python
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
# hiddens are the hidden states from the previous truncated backprop step
out, hiddens = self.lstm(data, hiddens)
loss = ...
return {"loss": loss, "hiddens": hiddens}
Note:
The loss value shown in the progress bar is smoothed (averaged) over the last values,
so it differs from the actual loss returned in train/validation step.
@ -817,9 +785,8 @@ class LightningModule(
training_epoch_end(train_outs)
Args:
outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers or when
using ``truncated_bptt_steps > 0``, the lists have the dimensions
(n_batches, tbptt_steps, n_optimizers). Dimensions of length 1 are squeezed.
outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers, the lists
have the dimensions (n_batches, n_optimizers). Dimensions of length 1 are squeezed.
Return:
None
@ -1764,64 +1731,6 @@ class LightningModule(
"""
optimizer.zero_grad()
def tbptt_split_batch(self, batch: Any, split_size: int) -> List[Any]:
r"""
When using truncated backpropagation through time, each batch must be split along the
time dimension. Lightning handles this by default, but for custom behavior override
this function.
Args:
batch: Current batch
split_size: The size of the split
Return:
List of batch splits. Each split will be passed to :meth:`training_step` to enable truncated
back propagation through time. The default implementation splits root level Tensors and
Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.
Examples::
def tbptt_split_batch(self, batch, split_size):
splits = []
for t in range(0, time_dims[0], split_size):
batch_split = []
for i, x in enumerate(batch):
if isinstance(x, torch.Tensor):
split_x = x[:, t:t + split_size]
elif isinstance(x, collections.abc.Sequence):
split_x = [None] * len(x)
for batch_idx in range(len(x)):
split_x[batch_idx] = x[batch_idx][t:t + split_size]
batch_split.append(split_x)
splits.append(batch_split)
return splits
Note:
Called in the training loop after
:meth:`~pytorch_lightning.callbacks.base.Callback.on_train_batch_start`
if :paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0.
Each returned batch split is passed separately to :meth:`training_step`.
"""
time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.abc.Sequence))]
assert len(time_dims) >= 1, "Unable to determine batch time dimension"
assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"
splits = []
for t in range(0, time_dims[0], split_size):
batch_split = []
for i, x in enumerate(batch):
split_x: Union[Tensor, List[Tensor]]
if isinstance(x, Tensor):
split_x = x[:, t : t + split_size]
elif isinstance(x, collections.abc.Sequence):
split_x = [x[batch_idx][t : t + split_size] for batch_idx in range(len(x))]
batch_split.append(split_x)
splits.append(batch_split)
return splits
def freeze(self) -> None:
r"""
Freeze all params for inference.

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.loop import Loop # noqa: F401 isort: skip (avoids circular imports)
from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401

View File

@ -1,16 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401

View File

@ -1,140 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union
from torch import Tensor
from typing_extensions import OrderedDict
from pytorch_lightning.loops.loop import Loop
from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization
from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import _get_active_optimizers
from pytorch_lightning.trainer.supporters import TensorRunningAccum
_OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]):
"""Runs over a single batch of data."""
def __init__(self) -> None:
super().__init__()
self.accumulated_loss = TensorRunningAccum(window_length=20)
self.running_loss = TensorRunningAccum(window_length=20)
# the current split index when the batch gets split into chunks in truncated backprop through time
self.split_idx: int = 0
self.optimizer_loop = OptimizerLoop()
self.manual_loop = ManualOptimization()
self._outputs: _OUTPUTS_TYPE = []
self._remaining_splits: List[Tuple[int, Any]] = []
@property
def done(self) -> bool:
"""Returns if all batch splits have been processed already."""
return len(self._remaining_splits) == 0
def connect( # type: ignore[override]
self, optimizer_loop: Optional[OptimizerLoop] = None, manual_loop: Optional[ManualOptimization] = None
) -> None:
if optimizer_loop is not None:
self.optimizer_loop = optimizer_loop
if manual_loop is not None:
self.manual_loop = manual_loop
def reset(self) -> None:
"""Resets the loop state."""
self._outputs = []
def on_run_start(self, kwargs: OrderedDict) -> None:
"""Splits the data into tbptt splits.
Args:
kwargs: the kwargs passed down to the hooks.
"""
batch = kwargs["batch"]
self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))
def advance(self, kwargs: OrderedDict) -> None:
"""Runs the train step together with optimization (if necessary) on the current batch split.
Args:
kwargs: the kwargs passed down to the hooks.
"""
# replace the batch with the split batch
self.split_idx, kwargs["batch"] = self._remaining_splits.pop(0)
self.trainer._logger_connector.on_train_split_start(self.split_idx)
outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None # for mypy
# choose which loop will run the optimization
if self.trainer.lightning_module.automatic_optimization:
optimizers = _get_active_optimizers(
self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
)
outputs = self.optimizer_loop.run(optimizers, kwargs)
else:
outputs = self.manual_loop.run(kwargs)
if outputs:
# automatic: can be empty if all optimizers skip their batches
# manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
# then `advance` doesn't finish and an empty dict is returned
self._outputs.append(outputs)
def on_run_end(self) -> _OUTPUTS_TYPE:
self.optimizer_loop._hiddens = None
# this is not necessary as the manual loop runs for only 1 iteration, but just in case
self.manual_loop._hiddens = None
output, self._outputs = self._outputs, [] # free memory
self._remaining_splits = []
return output
def teardown(self) -> None:
self.optimizer_loop.teardown()
self.manual_loop.teardown()
# release memory
if self.accumulated_loss.memory is not None:
self.accumulated_loss.memory = self.accumulated_loss.memory.cpu()
if self.running_loss.memory is not None:
self.running_loss.memory = self.running_loss.memory.cpu()
def _tbptt_split_batch(self, batch: Any) -> List[Any]:
"""Splits a single batch into a list of sequence steps for tbptt.
Args:
batch: the current batch to split
"""
tbptt_steps = self.trainer.lightning_module.truncated_bptt_steps
if tbptt_steps == 0:
return [batch]
splits = self.trainer._call_lightning_module_hook("tbptt_split_batch", batch, tbptt_steps)
return splits
def _update_running_loss(self, current_loss: Tensor) -> None:
"""Updates the running loss value with the current value."""
if self.trainer.lightning_module.automatic_optimization:
# track total loss for logging (avoid mem leaks)
self.accumulated_loss.append(current_loss)
accumulated_loss = self.accumulated_loss.mean()
if accumulated_loss is not None:
# calculate running loss for display
self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
# reset for next set of accumulated grads
self.accumulated_loss.reset()

View File

@ -18,15 +18,17 @@ from typing import Any, DefaultDict, Dict, Generator, List, Optional, overload,
import numpy as np
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
import pytorch_lightning as pl
from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE
from pytorch_lightning.loops.optimization import ManualOptimization, OptimizerLoop
from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
@ -34,6 +36,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
_BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
@ -57,7 +60,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
self.batch_progress = BatchProgress()
self.scheduler_progress = SchedulerProgress()
self.batch_loop = TrainingBatchLoop()
self.accumulated_loss = TensorRunningAccum(window_length=20)
self.running_loss = TensorRunningAccum(window_length=20)
self.optimizer_loop = OptimizerLoop()
self.manual_loop = ManualOptimization()
self.val_loop = loops.EvaluationLoop(verbose=False)
self._results = _ResultCollection(training=True)
@ -85,8 +92,8 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
def global_step(self) -> int:
lightning_module = self.trainer.lightning_module
if lightning_module is None or lightning_module.automatic_optimization:
return self.batch_loop.optimizer_loop.optim_progress.optimizer_steps
return self.batch_loop.manual_loop.optim_step_progress.total.completed
return self.optimizer_loop.optim_progress.optimizer_steps
return self.manual_loop.optim_step_progress.total.completed
@property
def _is_training_done(self) -> bool:
@ -119,12 +126,15 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
def connect( # type: ignore[override]
self,
batch_loop: Optional[TrainingBatchLoop] = None,
optimizer_loop: Optional[OptimizerLoop] = None,
manual_loop: Optional[ManualOptimization] = None,
val_loop: Optional["loops.EvaluationLoop"] = None,
) -> None:
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
if batch_loop is not None:
self.batch_loop = batch_loop
if optimizer_loop is not None:
self.optimizer_loop = optimizer_loop
if manual_loop is not None:
self.manual_loop = manual_loop
if val_loop is not None:
self.val_loop = val_loop
@ -133,7 +143,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
if self.restarting:
self.batch_progress.reset_on_restart()
self.scheduler_progress.reset_on_restart()
self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
self.optimizer_loop.optim_progress.reset_on_restart()
trainer = self.trainer
if not trainer.state._fault_tolerant_mode.is_enabled and trainer.num_training_batches != float("inf"):
@ -148,7 +158,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
else:
self.batch_progress.reset_on_run()
self.scheduler_progress.reset_on_run()
self.batch_loop.optimizer_loop.optim_progress.reset_on_run()
self.optimizer_loop.optim_progress.reset_on_run()
# when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches
# seen per epoch, this is useful for tracking when validation is run multiple times per epoch
self.val_loop.epoch_loop.batch_progress.total.reset()
@ -195,9 +205,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
self.trainer._logger_connector.on_batch_start(batch, batch_idx)
batch_output: _BATCH_OUTPUTS_TYPE = None # for mypy
if batch is None:
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
batch_output = []
else:
# hook
self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx)
@ -210,7 +220,14 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
self.batch_progress.increment_started()
with self.trainer.profiler.profile("run_training_batch"):
batch_output = self.batch_loop.run(kwargs)
# choose which loop will run the optimization
if self.trainer.lightning_module.automatic_optimization:
optimizers = _get_active_optimizers(
self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
)
batch_output = self.optimizer_loop.run(optimizers, kwargs)
else:
batch_output = self.manual_loop.run(kwargs)
self.batch_progress.increment_processed()
@ -232,7 +249,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
self.batch_progress.increment_completed()
if is_overridden("training_epoch_end", self.trainer.lightning_module):
if batch_output and is_overridden("training_epoch_end", self.trainer.lightning_module):
# batch_output may be empty
# automatic: can be empty if all optimizers skip their batches
# manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
# then `advance` doesn't finish and an empty dict is returned
self._outputs.append(batch_output)
# -----------------------------------------
@ -254,7 +275,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
self.update_lr_schedulers("step", update_plateau_schedulers=True)
if not self._should_accumulate():
# this is increased once per batch disregarding multiple optimizers or tbptt on purpose for loggers
# this is increased once per batch disregarding multiple optimizers on purpose for loggers
self._batches_that_stepped += 1
# this will save based on the `batches_that_stepped` value
self._save_loggers_on_train_batch_end()
@ -271,7 +292,13 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
def teardown(self) -> None:
self._results.cpu()
self.batch_loop.teardown()
self.optimizer_loop.teardown()
self.manual_loop.teardown()
# release memory
if self.accumulated_loss.memory is not None:
self.accumulated_loss.memory = self.accumulated_loss.memory.cpu()
if self.running_loss.memory is not None:
self.running_loss.memory = self.running_loss.memory.cpu()
self.val_loop.teardown()
def on_save_checkpoint(self) -> Dict:
@ -527,6 +554,21 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
kwargs["batch_idx"] = batch_idx
return kwargs
def _update_running_loss(self, current_loss: Tensor) -> None:
"""Updates the running loss value with the current value."""
if self.trainer.lightning_module.automatic_optimization:
# track total loss for logging (avoid mem leaks)
self.accumulated_loss.append(current_loss)
accumulated_loss = self.accumulated_loss.mean()
if accumulated_loss is not None:
# calculate running loss for display
self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
# reset for next set of accumulated grads
self.accumulated_loss.reset()
def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]:
"""Converts an optimizer dict to a list in which the key of the dict determines the position of the element.

View File

@ -77,11 +77,6 @@ class FitLoop(Loop[None]):
"""Returns the current batch index (within this epoch)"""
return self.epoch_loop.batch_idx
@property
def split_idx(self) -> int:
"""Returns the index of the current batch split (within the current batch) for bptt."""
return self.epoch_loop.batch_loop.split_idx
@property
def min_steps(self) -> Optional[int]:
# TODO(@justusschock): Why aren't we using the attribute in this class?
@ -112,7 +107,7 @@ class FitLoop(Loop[None]):
@property
def running_loss(self) -> TensorRunningAccum:
"""Returns the running loss."""
return self.epoch_loop.batch_loop.running_loss
return self.epoch_loop.running_loss
@Loop.restarting.setter
def restarting(self, restarting: bool) -> None:
@ -131,12 +126,12 @@ class FitLoop(Loop[None]):
@property
def _skip_backward(self) -> bool:
"""Determines whether the loop will skip backward during automatic optimization."""
return self.epoch_loop.batch_loop.optimizer_loop._skip_backward
return self.epoch_loop.optimizer_loop._skip_backward
@_skip_backward.setter
def _skip_backward(self, value: bool) -> None:
"""Determines whether the loop will skip backward during automatic optimization."""
self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value
self.epoch_loop.optimizer_loop._skip_backward = value
@property
def _results(self) -> _ResultCollection:
@ -239,7 +234,7 @@ class FitLoop(Loop[None]):
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
# stores accumulated grad fractions per batch
self.epoch_loop.batch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)
self.epoch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)
self.epoch_progress.increment_ready()

View File

@ -145,7 +145,8 @@ class Loop(ABC, Generic[T]):
# connect sub-loops
kwargs = {n: lp for n, lp in old_loop.__dict__.items() if isinstance(lp, Loop)}
loop.connect(**kwargs)
if kwargs:
loop.connect(**kwargs)
# set the trainer reference
loop.trainer = self.trainer

View File

@ -20,7 +20,7 @@ from torch import Tensor
from pytorch_lightning.core.optimizer import do_nothing_closure
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.optimization.closure import OutputResult
from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens
from pytorch_lightning.loops.utilities import _build_training_step_kwargs
from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT
@ -42,7 +42,7 @@ class ManualResult(OutputResult):
def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ManualResult":
extra = {}
if isinstance(training_step_output, dict):
extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
extra = training_step_output.copy()
elif isinstance(training_step_output, Tensor):
extra = {"loss": training_step_output}
elif training_step_output is not None:
@ -82,7 +82,6 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)
self._done: bool = False
self._hiddens: Optional[Any] = None
self._output: _OUTPUTS_TYPE = {}
@property
@ -104,7 +103,7 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
Args:
kwargs: The kwargs passed down to the hooks.
"""
kwargs = self._build_kwargs(kwargs, self._hiddens)
kwargs = self._build_kwargs(kwargs)
# manually capture logged metrics
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
@ -114,12 +113,11 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
training_step_output = strategy_output if model_output is None else model_output
self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)
result = self.output_result_cls.from_training_step_output(training_step_output)
if self.trainer.move_metrics_to_cpu:
# hiddens and the training step output are not moved as they are not considered "metrics"
# training step output does not get moved because it is not considered a "metric"
# the user might need them on the correct device for an operation in `training_epoch_end`
assert self.trainer._results is not None
self.trainer._results.cpu()
@ -144,16 +142,13 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
self.trainer.profiler.stop("optimizer_step")
self.optim_step_progress.increment_completed()
def _build_kwargs(self, kwargs: OrderedDict, hiddens: Optional[Any]) -> OrderedDict:
def _build_kwargs(self, kwargs: OrderedDict) -> OrderedDict:
"""Helper method to build the arguments for the current step.
Args:
kwargs: The kwargs passed down to the hooks.
hiddens: the hidden state of the previous RNN iteration.
Returns:
The kwargs passed down to the hooks.
"""
return _build_training_step_kwargs(
kwargs, self.trainer.lightning_module, self.trainer.optimizers, None, hiddens
)
return _build_training_step_kwargs(kwargs, self.trainer.lightning_module, self.trainer.optimizers, None)

View File

@ -24,11 +24,7 @@ from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult
from pytorch_lightning.loops.utilities import (
_block_parallel_sync_behavior,
_build_training_step_kwargs,
_extract_hiddens,
)
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior, _build_training_step_kwargs
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import WarningCache
@ -72,7 +68,7 @@ class ClosureResult(OutputResult):
raise MisconfigurationException(
"In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present"
)
extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
extra = {k: v for k, v in training_step_output.items() if k != "loss"}
elif isinstance(training_step_output, Tensor):
closure_loss = training_step_output
elif training_step_output is not None:
@ -166,7 +162,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
self._skip_backward: bool = False
self._optimizers: Tuple[Optimizer, ...] = tuple()
self._indices: Tuple[int, ...] = tuple()
self._hiddens: Optional[Any] = None
@property
def optimizer_idx(self) -> int:
@ -194,7 +189,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
self.optim_progress.optimizer_position = 0
def advance(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:
kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens)
kwargs = self._build_kwargs(kwargs, self.optimizer_idx)
result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
if result.loss is not None:
@ -251,7 +246,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
# if no result, user decided to skip optimization
# otherwise update running loss + reset accumulated loss
# TODO: find proper way to handle updating running loss
self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss)
self.trainer.fit_loop.epoch_loop._update_running_loss(result.loss)
# untoggle model params
self._run_optimization_end(opt_idx)
@ -404,30 +399,25 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
training_step_output = strategy_output if model_output is None else model_output
self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)
result = self.output_result_cls.from_training_step_output(
training_step_output, self.trainer.accumulate_grad_batches
)
if self.trainer.move_metrics_to_cpu:
# hiddens and the training step output are not moved as they are not considered "metrics"
# training step output does not get moved because it is not considered a "metric"
assert self.trainer._results is not None
self.trainer._results.cpu()
return result
def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int, hiddens: Optional[Any]) -> OrderedDict:
def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int) -> OrderedDict:
"""Helper method to build the arguments for the current step.
Args:
kwargs: The kwargs passed down to the hooks.
opt_idx: the index of the current optimizer.
hiddens: the hidden state of the previous RNN iteration.
Returns:
The kwargs passed down to the hooks.
"""
return _build_training_step_kwargs(
kwargs, self.trainer.lightning_module, self.trainer.optimizers, opt_idx, hiddens
)
return _build_training_step_kwargs(kwargs, self.trainer.lightning_module, self.trainer.optimizers, opt_idx)

View File

@ -14,7 +14,7 @@
from collections import OrderedDict
from contextlib import contextmanager
from functools import lru_cache
from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
from typing import Generator, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@ -30,11 +30,8 @@ from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT
def check_finite_loss(loss: Optional[Tensor]) -> None:
@ -47,28 +44,6 @@ def check_finite_loss(loss: Optional[Tensor]) -> None:
raise ValueError(f"The loss returned in `training_step` is {loss}.")
def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: int) -> Optional[Any]:
"""Get the hidden state if present from the training step output.
Raises:
MisconfigurationException: If :attr:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` is
not enabled and hiddens are returned or vice versa.
"""
if not truncated_bptt_steps:
if isinstance(training_step_output, dict) and "hiddens" in training_step_output:
raise MisconfigurationException(
'You returned "hiddens" in your `training_step` but `truncated_bptt_steps` is disabled'
)
return None
if not isinstance(training_step_output, dict) or "hiddens" not in training_step_output:
raise MisconfigurationException(
'You enabled `truncated_bptt_steps` but did not `return {..., "hiddens": ...}` in your `training_step`'
)
# detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
hiddens = recursive_detach(training_step_output["hiddens"])
return hiddens
def _parse_loop_limits(
min_steps: Optional[int],
max_steps: int,
@ -116,7 +91,6 @@ def _build_training_step_kwargs(
lightning_module: "pl.LightningModule",
optimizers: Sequence[Optimizer],
opt_idx: Optional[int],
hiddens: Optional[Any],
) -> OrderedDict:
"""Builds the keyword arguments for training_step.
@ -125,7 +99,6 @@ def _build_training_step_kwargs(
lightning_module: the LightningModule with a `training_step` hook implementation
optimizers: the list of optimizers from the Trainer
opt_idx: the index of the current optimizer
hiddens: the hidden state of the previous RNN iteration
Returns:
the keyword arguments for the training step
@ -147,10 +120,6 @@ def _build_training_step_kwargs(
" `training_step` is missing the `optimizer_idx` argument."
)
# pass hiddens if using tbptt
if lightning_module.truncated_bptt_steps > 0:
kwargs["hiddens"] = hiddens
return kwargs

View File

@ -165,9 +165,3 @@ def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule")
" not match with the actual batch index when using a `dataloader_iter`"
" argument in your `training_step`."
)
if model.truncated_bptt_steps > 0:
raise MisconfigurationException(
"The model taking a `dataloader_iter` argument in your `training_step` "
"is incompatible with `truncated_bptt_steps > 0`."
)

View File

@ -37,7 +37,6 @@ class LoggerConnector:
self._epoch_end_reached = False
self._current_fx: Optional[str] = None
self._batch_idx: Optional[int] = None
self._split_idx: Optional[int] = None
def on_trainer_init(
self,
@ -144,9 +143,6 @@ class LoggerConnector:
Train metric updates
"""
def on_train_split_start(self, split_idx: int) -> None:
self._split_idx = split_idx
def update_train_step_metrics(self) -> None:
if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
return
@ -185,7 +181,6 @@ class LoggerConnector:
def epoch_end_reached(self) -> None:
self._epoch_end_reached = True
self._batch_idx = None
self._split_idx = None
def on_epoch_end(self) -> None:
assert self._epoch_end_reached
@ -209,10 +204,7 @@ class LoggerConnector:
def should_reset_tensors(self, fx: str) -> bool:
is_different_fx = self._current_fx != fx
if self._split_idx is None:
is_first_batch = self._batch_idx in (None, 0)
else:
is_first_batch = bool(self._batch_idx) + self._split_idx == 0
is_first_batch = self._batch_idx in (None, 0)
return is_different_fx and is_first_batch
def reset_metrics(self) -> None:
@ -226,7 +218,6 @@ class LoggerConnector:
results.reset()
self._batch_idx = None
self._split_idx = None
self._current_fx = None
@property

View File

@ -1944,7 +1944,7 @@ class Trainer:
def global_step(self) -> int:
"""The number of optimizer steps taken (does not reset each epoch).
This includes multiple optimizers and TBPTT steps (if enabled).
This includes multiple optimizers (if enabled).
"""
return self.fit_loop.epoch_loop.global_step

View File

@ -335,8 +335,8 @@ def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
def _reset_progress(trainer: "pl.Trainer") -> None:
if trainer.lightning_module.automatic_optimization:
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.reset()
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.reset()
else:
trainer.fit_loop.epoch_loop.batch_loop.manual_loop.optim_step_progress.reset()
trainer.fit_loop.epoch_loop.manual_loop.optim_step_progress.reset()
trainer.fit_loop.epoch_progress.reset()

View File

@ -46,7 +46,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
"1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking],
"1.6.5": [_migrate_loop_batches_that_stepped],
"1.9.0": [_migrate_model_checkpoint_save_on_train_epoch_end_default],
"2.0.0": [_drop_apex_amp_state],
"2.0.0": [_drop_apex_amp_state, _migrate_loop_structure_after_tbptt_removal],
}
@ -219,3 +219,40 @@ def _drop_apex_amp_state(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
rank_zero_warn("This checkpoint contains apex AMP data, but apex support has been removed in v2.0.0.")
del checkpoint[key]
return checkpoint
def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""Adjusts the loop structure since it changed when the support for truncated backpropagation was removed. The
optimizer loop and the manual loop were previously children of the training batch loop. After its removal, they
became the children of the training epoch loop.
Version: 2.0.0
Commit: TBD
PR: #16172
"""
if "loops" not in checkpoint:
return checkpoint
fit_loop = checkpoint["loops"]["fit_loop"]
# remap `x.batch_loop.y` to `x.y`
old_key_new_key_mapping = {
"epoch_loop.batch_loop.manual_loop.optim_step_progress": "epoch_loop.manual_loop.optim_step_progress",
"epoch_loop.batch_loop.manual_loop.state_dict": "epoch_loop.manual_loop.state_dict",
"epoch_loop.batch_loop.optimizer_loop.optim_progress": "epoch_loop.optimizer_loop.optim_progress",
"epoch_loop.batch_loop.optimizer_loop.state_dict": "epoch_loop.optimizer_loop.state_dict",
}
for old, new in list(old_key_new_key_mapping.items()):
if old in fit_loop:
fit_loop[new] = fit_loop[old]
del fit_loop[old]
# We can safely drop this key: our default implementation of `batch_loop` did not have state.
# If there was state from a custom batch loop, we wouldn't be able to load it meaningfully.
# But just in case, we save a copy of it in `epoch_loop.state_dict` so the user can still process it after
# loading the checkpoint.
if "epoch_loop.batch_loop.state_dict" in fit_loop and fit_loop["epoch_loop.batch_loop.state_dict"]:
fit_loop["epoch_loop.state_dict"]["old_batch_loop_state_dict"] = fit_loop["epoch_loop.batch_loop.state_dict"]
fit_loop.pop("epoch_loop.batch_loop.state_dict", None)
return checkpoint
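# A rough illustration of the remapping above (hypothetical checkpoint contents, for illustration only):
# optimizer-loop state stored under the old batch-loop keys ends up directly under the epoch loop.
_example_ckpt = {
    "loops": {
        "fit_loop": {
            "epoch_loop.state_dict": {},
            "epoch_loop.batch_loop.state_dict": {},
            "epoch_loop.batch_loop.optimizer_loop.state_dict": {"dummy": 1},
        }
    }
}
_example_ckpt = _migrate_loop_structure_after_tbptt_removal(_example_ckpt)
assert "epoch_loop.optimizer_loop.state_dict" in _example_ckpt["loops"]["fit_loop"]
assert "epoch_loop.batch_loop.optimizer_loop.state_dict" not in _example_ckpt["loops"]["fit_loop"]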

View File

@ -148,5 +148,6 @@ def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None:
def _should_upgrade(checkpoint: _CHECKPOINT, target: str, max_version: Optional[str] = None) -> bool:
"""Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target."""
is_lte_max_version = max_version is None or Version(target) <= Version(max_version)
return Version(_get_version(checkpoint)) < Version(target) and is_lte_max_version
target_version = Version(target)
is_lte_max_version = max_version is None or target_version <= Version(max_version)
return is_lte_max_version and Version(_get_version(checkpoint)) < target_version
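# A rough illustration of this check, assuming `_get_version` reads the checkpoint's
# `pytorch-lightning_version` entry (version values below are hypothetical):
assert _should_upgrade({"pytorch-lightning_version": "1.9.0"}, target="2.0.0")  # older than target
assert not _should_upgrade({"pytorch-lightning_version": "2.0.0"}, target="2.0.0")  # already at target
assert not _should_upgrade({"pytorch-lightning_version": "1.9.0"}, target="2.0.0", max_version="1.9.0")  # target above max_version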

View File

@ -659,10 +659,8 @@ def test_get_progress_bar_metrics(tmpdir: str):
)
model = BoringModel()
trainer.fit(model)
model.truncated_bptt_steps = 2
standard_metrics = progress_bar.get_metrics(trainer, model)
assert "loss" in standard_metrics.keys()
assert "split_idx" in standard_metrics.keys()
assert "v_num" not in standard_metrics.keys()

View File

@ -1,205 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pytest
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
class LSTMModel(LightningModule):
"""LSTM sequence-to-sequence model for testing TBPTT with automatic optimization."""
def __init__(self, truncated_bptt_steps=2, input_size=1, hidden_size=8):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
self.truncated_bptt_steps = truncated_bptt_steps
self.automatic_optimization = True
def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=0.01)
def training_step(self, batch, batch_idx, hiddens):
x, y = batch
pred, hiddens = self.lstm(x, hiddens)
loss = F.mse_loss(pred, y)
return {"loss": loss, "hiddens": hiddens}
def train_dataloader(self):
dataset = TensorDataset(torch.rand(16, 8, self.input_size), torch.rand(16, 8, self.input_size))
return DataLoader(dataset=dataset, batch_size=4)
class ManualLSTMModel(LSTMModel):
"""LSTM sequence-to-sequence model for testing TBPTT with manual optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
def training_step(self, batch, batch_idx, hiddens):
out = super().training_step(batch, batch_idx, hiddens)
loss, hiddens = out["loss"], out["hiddens"]
opt = self.optimizers()
opt.zero_grad()
self.manual_backward(loss)
opt.step()
return {"loss": loss, "hiddens": hiddens}
@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel))
def test_persistent_hidden_state_transfer(tmpdir, model_class):
"""Test that the hidden state reference gets passed through from one training_step to the next and remains
unmodified apart from detached grad_fn."""
class TBPTTModel(model_class):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.test_hidden = None
def training_step(self, batch, batch_idx, hiddens):
split_idx = self.trainer.fit_loop.split_idx
# the hidden state may only be None for the first split_idx
assert not ((split_idx == 0) ^ (hiddens is None))
# test_hiddens is None when hiddens is None
assert not ((hiddens is None) ^ (self.test_hidden is None))
# the states are equal (persistent)
assert hiddens is None or all(torch.equal(h, th) for h, th in zip(hiddens, self.test_hidden))
# the incoming hidden state never has a grad_fn (gets automatically detached)
assert hiddens is None or all(h.grad_fn is None for h in hiddens)
out = super().training_step(batch, batch_idx, hiddens)
# store hiddens, assert persistence in next training_step
self.test_hidden = out["hiddens"]
# hiddens may have grad_fn when returning, gets automatically detached
assert all(h.grad_fn is not None for h in self.test_hidden)
return out
def on_train_batch_start(self, *_, **__) -> None:
self.test_hidden = None
model = TBPTTModel(truncated_bptt_steps=2, input_size=1, hidden_size=8)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
enable_model_summary=False,
logger=False,
enable_checkpointing=False,
)
trainer.fit(model)
@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel))
def test_tbptt_split_shapes(tmpdir, model_class):
"""Test that the sequence data gets split correctly and that the outputs are correctly passed from hook to
hook."""
batch_size = 10
truncated_bptt_steps = 2
n, t, f = 32, 15, 1 # (num samples, sequence size, input size)
assert t % truncated_bptt_steps != 0, "test must run with sequence length not divisible by tbptt steps"
seq2seq_dataset = TensorDataset(torch.rand(n, t, f), torch.rand(n, t, f))
train_dataloader = DataLoader(dataset=seq2seq_dataset, batch_size=batch_size)
class TBPTTModel(model_class):
def training_step(self, batch, batch_idx, hiddens):
x, y = batch
if self.trainer.fit_loop.epoch_loop.batch_loop.done:
# last split idx, not aligned
assert x.shape[1] == t % truncated_bptt_steps
assert y.shape[1] == t % truncated_bptt_steps
else:
assert x.shape[1] == truncated_bptt_steps
assert y.shape[1] == truncated_bptt_steps
return super().training_step(batch, batch_idx, hiddens)
def training_epoch_end(self, training_step_outputs):
training_step_outputs = training_step_outputs[0]
assert len(training_step_outputs) == math.ceil(t / self.truncated_bptt_steps)
assert all(out["loss"].grad_fn is None for out in training_step_outputs)
assert all("hiddens" not in out for out in training_step_outputs)
model = TBPTTModel(truncated_bptt_steps=truncated_bptt_steps, input_size=f, hidden_size=8)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
enable_model_summary=False,
logger=False,
enable_checkpointing=False,
)
trainer.fit(model, train_dataloaders=train_dataloader)
assert trainer.fit_loop.batch_idx == n // batch_size
assert trainer.fit_loop.split_idx == t // truncated_bptt_steps
@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel))
def test_tbptt_logging(tmpdir, model_class):
"""Test step-level and epoch-level logging works with TBPTT."""
class TBPTTModel(model_class):
def training_step(self, *args, **kwargs):
out = super().training_step(*args, **kwargs)
self.log("loss", out["loss"], on_step=True, on_epoch=True)
return out
model = TBPTTModel(truncated_bptt_steps=2)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
log_every_n_steps=2,
enable_model_summary=False,
enable_checkpointing=False,
)
trainer.fit(model)
assert set(trainer.logged_metrics) == {"loss_step", "loss_epoch"}
def test_hiddens_multiple_optimizers(tmpdir):
class TBPTTModel(LSTMModel):
# TODO: `optimizer_idx=n` gets the hiddens from `optimizer_idx=n-1` instead of the hidden from
# `optimizer_idx=n`, `split_idx=m-1`. This is unexpected and should be changed
test_hiddens = None
def training_step(self, batch, batch_idx, optimizer_idx, hiddens):
if hiddens is None:
assert self.test_hiddens is None
else:
assert all(torch.equal(h, th) for h, th in zip(hiddens, self.test_hiddens))
out = super().training_step(batch, batch_idx, hiddens)
self.test_hiddens = out["hiddens"]
return out
def configure_optimizers(self):
return [super().configure_optimizers(), super().configure_optimizers()]
model = TBPTTModel(truncated_bptt_steps=2, input_size=1, hidden_size=1)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=1,
limit_val_batches=0,
enable_model_summary=False,
logger=False,
enable_checkpointing=False,
enable_progress_bar=False,
)
trainer.fit(model)
assert trainer.global_step == 8 / 2 * 2 # time_dim_length / tbptt_steps * num_optimizers

View File

@ -32,73 +32,56 @@ _out13 = {"loss": 1.3}
class TestPrepareOutputs:
def prepare_outputs(self, fn, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization):
def prepare_outputs(self, fn, batch_outputs, num_optimizers, automatic_optimization):
lightning_module = LightningModule()
lightning_module.automatic_optimization = automatic_optimization
lightning_module.truncated_bptt_steps = tbptt_splits
return fn(
batch_outputs,
lightning_module=lightning_module,
num_optimizers=num_optimizers, # does not matter for manual optimization
)
def prepare_outputs_training_epoch_end(
self, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=True
):
def prepare_outputs_training_epoch_end(self, batch_outputs, num_optimizers, automatic_optimization=True):
return self.prepare_outputs(
TrainingEpochLoop._prepare_outputs_training_epoch_end,
tbptt_splits,
batch_outputs,
num_optimizers,
automatic_optimization=automatic_optimization,
)
def prepare_outputs_training_batch_end(
self, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=True
):
def prepare_outputs_training_batch_end(self, batch_outputs, num_optimizers, automatic_optimization=True):
return self.prepare_outputs(
TrainingEpochLoop._prepare_outputs_training_batch_end,
tbptt_splits,
batch_outputs,
num_optimizers,
automatic_optimization=automatic_optimization,
)
@pytest.mark.parametrize(
"num_optimizers,tbptt_splits,batch_outputs,expected",
"num_optimizers,batch_outputs,expected",
[
(1, 0, [], []),
(1, 0, [[]], []),
(1, [], []),
(1, [[]], []),
# 1 batch
(1, 0, [[{0: _out00}]], [_out00]),
(1, [[{0: _out00}]], [_out00]),
# 2 batches
(1, 0, [[{0: _out00}], [{0: _out01}]], [_out00, _out01]),
(1, [[{0: _out00}], [{0: _out01}]], [_out00, _out01]),
# 1 batch, 2 optimizers
(2, 0, [[{0: _out00, 1: _out01}]], [_out00, _out01]),
(2, [[{0: _out00, 1: _out01}]], [_out00, _out01]),
# 2 batches, 2 optimizers
(2, 0, [[{0: _out00, 1: _out01}], [{0: _out10, 1: _out11}]], [[_out00, _out01], [_out10, _out11]]),
(2, [[{0: _out00, 1: _out01}], [{0: _out10, 1: _out11}]], [[_out00, _out01], [_out10, _out11]]),
# 4 batches, 2 optimizers, different frequency
(
2,
0,
[[{0: _out00}], [{1: _out10}], [{1: _out11}], [{0: _out01}]],
[[_out00], [_out10], [_out11], [_out01]],
),
# 1 batch, tbptt with 2 splits (uneven)
(1, 2, [[{0: _out00}, {0: _out01}], [{0: _out03}]], [[_out00, _out01], [_out03]]),
# 3 batches, tbptt with 2 splits, 2 optimizers alternating
(
2,
2,
[[{0: _out00}, {0: _out01}], [{1: _out10}, {1: _out11}], [{0: _out02}, {0: _out03}]],
[[[_out00], [_out01]], [[_out10], [_out11]], [[_out02], [_out03]]],
),
],
)
def test_prepare_outputs_training_epoch_end_automatic(self, num_optimizers, tbptt_splits, batch_outputs, expected):
def test_prepare_outputs_training_epoch_end_automatic(self, num_optimizers, batch_outputs, expected):
"""Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
currently expects in the case of automatic optimization."""
assert self.prepare_outputs_training_epoch_end(tbptt_splits, batch_outputs, num_optimizers) == expected
assert self.prepare_outputs_training_epoch_end(batch_outputs, num_optimizers) == expected
@pytest.mark.parametrize(
"batch_outputs,expected",
@ -111,37 +94,29 @@ class TestPrepareOutputs:
([[_out00], [_out01]], [_out00, _out01]),
# skipped outputs
([[_out00], [], [], [_out03]], [_out00, _out03]),
# tbptt with 2 splits, uneven, skipped output
([[_out00, _out01], [_out02, _out03], [], [_out10]], [[_out00, _out01], [_out02, _out03], [_out10]]),
],
)
def test_prepare_outputs_training_epoch_end_manual(self, batch_outputs, expected):
"""Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
currently expects in the case of manual optimization."""
assert self.prepare_outputs_training_epoch_end(0, batch_outputs, -1, automatic_optimization=False) == expected
assert self.prepare_outputs_training_epoch_end(batch_outputs, -1, automatic_optimization=False) == expected
@pytest.mark.parametrize(
"num_optimizers,tbptt_splits,batch_end_outputs,expected",
"num_optimizers,batch_end_outputs,expected",
[
(1, 0, [], []),
(1, 0, [[]], []),
(1, [], []),
(1, [[]], []),
# 1 optimizer
(1, 0, [{0: _out00}], _out00),
(1, [{0: _out00}], _out00),
# 2 optimizers
(2, 0, [{0: _out00, 1: _out01}], [_out00, _out01]),
# tbptt with 2 splits
(1, 2, [{0: _out00}, {0: _out01}], [_out00, _out01]),
# 2 optimizers, tbptt with 2 splits
(2, 2, [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}], [[_out00, _out01], [_out10, _out11]]),
(2, [{0: _out00, 1: _out01}], [_out00, _out01]),
],
)
def test_prepare_outputs_training_batch_end_automatic(
self, num_optimizers, tbptt_splits, batch_end_outputs, expected
):
def test_prepare_outputs_training_batch_end_automatic(self, num_optimizers, batch_end_outputs, expected):
"""Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
currently expects in the case of automatic optimization."""
assert self.prepare_outputs_training_batch_end(tbptt_splits, batch_end_outputs, num_optimizers) == expected
assert self.prepare_outputs_training_batch_end(batch_end_outputs, num_optimizers) == expected
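As a rough sketch only, consistent with the parametrized cases above but not the loop's actual `_prepare_outputs_training_batch_end` implementation, the automatic-optimization flattening amounts to:

from typing import Any, List, Union


def _flatten_batch_end_outputs(batch_end_outputs: List[Any], num_optimizers: int) -> Union[Any, List[Any]]:
    """Hypothetical helper: drop skipped outputs and keep one entry per optimizer that ran."""
    processed: List[Any] = []
    for opt_outputs in batch_end_outputs:
        if not opt_outputs:  # a skipped closure leaves an empty placeholder
            continue
        processed.extend(opt_outputs[idx] for idx in sorted(opt_outputs))
    if not processed:
        return []
    # a single optimizer yields the bare output, multiple optimizers yield a list
    return processed[0] if num_optimizers == 1 else processed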
@pytest.mark.parametrize(
"batch_end_outputs,expected",
@ -150,16 +125,12 @@ class TestPrepareOutputs:
([[]], []),
# skipped outputs
([_out00, None, _out02], [_out00, _out02]),
# tbptt with 3 splits, skipped output
([_out00, _out01, None, _out03], [_out00, _out01, _out03]),
],
)
def test_prepare_outputs_training_batch_end_manual(self, batch_end_outputs, expected):
"""Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
currently expects in the case of manual optimization."""
assert (
self.prepare_outputs_training_batch_end(0, batch_end_outputs, -1, automatic_optimization=False) == expected
)
assert self.prepare_outputs_training_batch_end(batch_end_outputs, -1, automatic_optimization=False) == expected
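Under the same caveat, the manual-optimization cases above reduce to filtering out skipped outputs:

def _flatten_manual_batch_end_outputs(batch_end_outputs):
    """Hypothetical helper: manual optimization yields one output (possibly None or empty) per step."""
    return [out for out in batch_end_outputs if out]


assert _flatten_manual_batch_end_outputs([{"loss": 0.0}, None, {"loss": 0.2}]) == [{"loss": 0.0}, {"loss": 0.2}]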
def test_no_val_on_train_epoch_loop_restart(tmpdir):
@ -208,7 +179,7 @@ def test_should_stop_early_stopping_conditions_not_met(
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
trainer.num_training_batches = 10
trainer.should_stop = True
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step
trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1

View File

@ -65,15 +65,15 @@ def test__eval_step__flow(tmpdir):
# simulate training manually
trainer.state.stage = RunningStage.TRAINING
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs)
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
train_step_out = train_step_out[0]
assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171
@ -126,15 +126,15 @@ def test__eval_step__eval_step_end__flow(tmpdir):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs)
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
train_step_out = train_step_out[0]
assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171

View File

@ -52,14 +52,13 @@ def test_loops_state_dict_structure():
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.batch_loop.state_dict": {},
"epoch_loop.batch_loop.manual_loop.state_dict": {},
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_loop.state_dict": {},
"epoch_loop.manual_loop.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"optimizer": {
"step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"zero_grad": {

View File

@ -25,7 +25,7 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoad
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.loops import EvaluationLoop, Loop, TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops import EvaluationLoop, Loop, OptimizerLoop, TrainingEpochLoop
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf
@ -109,15 +109,15 @@ def test_connect_subloops(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
epoch_loop = trainer.fit_loop.epoch_loop
new_batch_loop = TrainingBatchLoop()
epoch_loop.connect(batch_loop=new_batch_loop)
assert epoch_loop.batch_loop is new_batch_loop
new_optimizer_loop = OptimizerLoop()
epoch_loop.connect(optimizer_loop=new_optimizer_loop)
assert epoch_loop.optimizer_loop is new_optimizer_loop
with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"):
_ = new_batch_loop.trainer
_ = new_optimizer_loop.trainer
trainer.fit(model)
assert new_batch_loop.trainer is trainer
assert new_optimizer_loop.trainer is trainer
def test_replace_loops():
@ -144,22 +144,22 @@ def test_replace_loops():
assert trainer.fit_loop.epoch_loop is new_loop
assert new_loop.min_steps == 123
assert new_loop.max_steps == 321
assert new_loop.batch_loop is old_loop.batch_loop
assert new_loop.optimizer_loop is old_loop.optimizer_loop
assert new_loop.val_loop is old_loop.val_loop
assert new_loop.trainer is trainer
class MyBatchLoop(TrainingBatchLoop):
class MyOptimizerLoop(OptimizerLoop):
...
class MyEvalLoop(EvaluationLoop):
...
# test passing more than one where one is an instance and the other a class
trainer.fit_loop.epoch_loop.replace(batch_loop=MyBatchLoop, val_loop=MyEvalLoop())
new_batch_loop = trainer.fit_loop.epoch_loop.batch_loop
trainer.fit_loop.epoch_loop.replace(optimizer_loop=MyOptimizerLoop, val_loop=MyEvalLoop())
new_optimizer_loop = trainer.fit_loop.epoch_loop.optimizer_loop
new_val_loop = trainer.fit_loop.epoch_loop.val_loop
assert isinstance(new_batch_loop, MyBatchLoop)
assert isinstance(new_optimizer_loop, MyOptimizerLoop)
assert isinstance(new_val_loop, MyEvalLoop)
@ -436,7 +436,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
assert os.path.exists(ckpt_path)
checkpoint = torch.load(ckpt_path)
optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress
optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress
# `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
@ -510,14 +510,13 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
"total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
"current": {"ready": be_sch_steps, "completed": be_sch_steps},
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_loop.state_dict": ANY,
"epoch_loop.manual_loop.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"optimizer_position": stop_optimizer,
"optimizer": {
"step": {
@ -563,8 +562,8 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
# test resetting manually, we expect all `ready` counters to be reset to `completed`
trainer.fit_loop.reset()
trainer.fit_loop.epoch_loop.reset()
trainer.fit_loop.epoch_loop.batch_loop.reset()
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.reset()
trainer.fit_loop.epoch_loop.optimizer_loop.reset()
trainer.fit_loop.epoch_loop.manual_loop.reset()
epoch_progress = trainer.fit_loop.epoch_progress
assert epoch_progress.current.ready == stop_epoch
@ -574,7 +573,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
assert batch_progress.current.ready == be_batches_completed
assert batch_progress.current.completed == be_batches_completed
optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress
optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
assert optim_progress.optimizer.step.current.ready == be_total_opt_steps
assert optim_progress.optimizer.step.current.completed == be_total_opt_steps
assert optim_progress.optimizer.zero_grad.current.ready == be_total_zero_grad
@ -677,14 +676,13 @@ def test_loop_state_on_complete_run(n_optimizers, tmpdir):
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
"current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
"epoch_loop.manual_loop.state_dict": ANY,
"epoch_loop.manual_loop.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"epoch_loop.optimizer_loop.state_dict": {},
"epoch_loop.optimizer_loop.optim_progress": {
"optimizer_position": n_optimizers,
"optimizer": {
"step": {
@ -746,7 +744,7 @@ def test_fit_loop_reset(tmpdir):
mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=2.ckpt"))
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
optimizer_loop = epoch_loop.batch_loop.optimizer_loop
optimizer_loop = epoch_loop.optimizer_loop
assert not fit_loop.restarting
assert not epoch_loop.restarting
assert not optimizer_loop.restarting

View File

@ -217,7 +217,7 @@ def test_should_stop_early_stopping_conditions_met(
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100)
trainer.num_training_batches = 10
trainer.should_stop = True
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
current_epoch * trainer.num_training_batches
)
trainer.fit_loop.epoch_loop.batch_progress.current.ready = 10

View File

@ -147,15 +147,15 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs)
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
train_step_out = train_step_out[0]
assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171
@ -217,15 +217,15 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
trainer.state.stage = RunningStage.TRAINING
# make sure training outputs what is expected
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs)
assert len(train_step_out) == 1
train_step_out = train_step_out[0][0]
train_step_out = train_step_out[0]
assert isinstance(train_step_out["loss"], Tensor)
assert train_step_out["loss"].item() == 171
# make sure the optimizer closure returns the correct things
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
opt_closure_result = opt_closure()
assert opt_closure_result.item() == 171
@ -301,9 +301,10 @@ def test_training_step_no_return_when_even(tmpdir):
# manually check a few batches
for batch_idx, batch in enumerate(model.train_dataloader()):
out = trainer.fit_loop.epoch_loop.batch_loop.run({"batch": batch, "batch_idx": batch_idx})
kwargs = {"batch": batch, "batch_idx": batch_idx}
out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs)
if not batch_idx % 2:
assert out == []
assert out == {}
def test_training_step_none_batches(tmpdir):

View File

@ -13,32 +13,7 @@
# limitations under the License.
from unittest.mock import Mock
import pytest
import torch
from pytorch_lightning.loops.utilities import _extract_hiddens, _set_sampler_epoch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
def test_extract_hiddens():
# tbptt not enabled, no hiddens returned
training_step_output = 1 # anything
hiddens = _extract_hiddens(training_step_output, 0)
assert hiddens is None
# tbptt enabled, hiddens returned
hiddens = torch.tensor(321.12, requires_grad=True)
training_step_output = {"hiddens": hiddens}
hiddens = _extract_hiddens(training_step_output, 2)
assert "hiddens" in training_step_output
assert not hiddens.requires_grad
# tbptt not enabled, hiddens returned
with pytest.raises(MisconfigurationException, match='returned "hiddens" .* but `truncated_bptt_steps` is disabled'):
_extract_hiddens(training_step_output, 0)
# tbptt enabled, no hiddens returned
with pytest.raises(MisconfigurationException, match="enabled `truncated_bptt_steps` but did not `return"):
_extract_hiddens(None, 1)
from pytorch_lightning.loops.utilities import _set_sampler_epoch
def test_set_sampler_epoch():

View File

@ -353,7 +353,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files)
# emulate callback's calls during the training
for i, loss in enumerate(losses, 1):
# sets `trainer.global_step`
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i
trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch`

View File

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
import pytest
import torch
@ -156,3 +158,37 @@ def test_migrate_dropped_apex_amp_state(monkeypatch):
with pytest.warns(UserWarning, match="checkpoint contains apex AMP data"):
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy())
assert "amp_scaling_state" not in updated_checkpoint
def test_migrate_loop_structure_after_tbptt_removal():
"""Test the loop state migration after truncated backpropagation support was removed in 2.0.0, and with it the
training batch loop."""
# automatic- and manual optimization state are combined into a single checkpoint to simplify testing
state_automatic = MagicMock()
state_manual = MagicMock()
optim_progress_automatic = MagicMock()
optim_progress_manual = MagicMock()
old_batch_loop_state = MagicMock()
old_checkpoint = {
"loops": {
"fit_loop": {
"epoch_loop.state_dict": {"any": "state"},
"epoch_loop.batch_loop.state_dict": old_batch_loop_state,
"epoch_loop.batch_loop.optimizer_loop.state_dict": state_automatic,
"epoch_loop.batch_loop.optimizer_loop.optim_progress": optim_progress_automatic,
"epoch_loop.batch_loop.manual_loop.state_dict": state_manual,
"epoch_loop.batch_loop.manual_loop.optim_step_progress": optim_progress_manual,
}
}
}
_set_version(old_checkpoint, "1.8.0") # pretend a checkpoint prior to 2.0.0
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")
assert updated_checkpoint["loops"] == {
"fit_loop": {
"epoch_loop.state_dict": {"any": "state", "old_batch_loop_state_dict": old_batch_loop_state},
"epoch_loop.optimizer_loop.state_dict": state_automatic,
"epoch_loop.optimizer_loop.optim_progress": optim_progress_automatic,
"epoch_loop.manual_loop.state_dict": state_manual,
"epoch_loop.manual_loop.optim_step_progress": optim_progress_manual,
}
}
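The asserted structure can be reproduced by a small key-rename pass over the fit loop state; the sketch below is hypothetical and only mirrors what this test expects from the real `migrate_checkpoint` utility, whose implementation is not part of this diff:

def _migrate_fit_loop_keys_after_tbptt_removal(fit_loop_state: dict) -> dict:
    """Hypothetical sketch: move the removed batch loop's children up one level and fold its own
    state into the epoch loop's state dict."""
    prefix = "epoch_loop.batch_loop."
    migrated = {}
    for key, value in fit_loop_state.items():
        if key == prefix + "state_dict":
            continue  # folded into "epoch_loop.state_dict" below
        if key.startswith(prefix):
            key = "epoch_loop." + key[len(prefix):]  # e.g. optimizer_loop.* and manual_loop.* move up
        migrated[key] = value
    if prefix + "state_dict" in fit_loop_state:
        epoch_state = dict(migrated.get("epoch_loop.state_dict", {}))
        epoch_state["old_batch_loop_state_dict"] = fit_loop_state[prefix + "state_dict"]
        migrated["epoch_loop.state_dict"] = epoch_state
    return migrated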

View File

@ -447,21 +447,6 @@ def test_on_train_batch_end_overridden(tmpdir) -> None:
trainer.fit(m)
def test_tbptt_split_batch_overridden(tmpdir) -> None:
"""Verify that a `MisconfigurationException` is raised when `tbptt_split_batch` is overridden on the
`LightningModule`."""
class InvalidModel(AsyncBoringModel):
def __init__(self) -> None:
super().__init__()
self.truncated_bptt_steps = 2
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
m = InvalidModel()
with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
trainer.fit(m)
def test_transfer_hooks_with_unpacking(tmpdir):
"""This test asserts the `transfer_batch` hooks are called only once per batch."""