Remove truncated backpropagation from loops (#16337)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent 92a922ccd4
commit 03a699693b
|
@ -126,7 +126,6 @@ Training
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
~batch.TrainingBatchLoop
|
||||
~epoch.TrainingEpochLoop
|
||||
FitLoop
|
||||
~optimization.ManualOptimization
|
||||
|
|
|
@ -1035,7 +1035,7 @@ global_step
|
|||
~~~~~~~~~~~
|
||||
|
||||
The number of optimizer steps taken (does not reset each epoch).
|
||||
This includes multiple optimizers and TBPTT steps (if enabled).
|
||||
This includes multiple optimizers (if enabled).
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@ -1195,79 +1195,6 @@ Set and access example_input_array, which basically represents a single batch.
|
|||
# generate some images using the example_input_array
|
||||
gen_images = self.generator(self.example_input_array)
|
||||
|
||||
truncated_bptt_steps
|
||||
~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
|
||||
a much longer sequence. This is made possible by passing training batches
|
||||
split along the time dimension into splits of size k to the
|
||||
``training_step``. In order to keep the same forward propagation behavior, all
|
||||
hidden states should be kept in-between each time-dimension split.
|
||||
|
||||
|
||||
If this is enabled, your batches will automatically get truncated
|
||||
and the Trainer will apply Truncated Backprop to them.
|
||||
|
||||
(`Williams et al. "An efficient gradient-based algorithm for on-line training of
|
||||
recurrent network trajectories."
|
||||
<https://ieeexplore.ieee.org/document/6797135>`_)
|
||||
|
||||
`Tutorial <https://d2l.ai/chapter_recurrent-neural-networks/bptt.html>`_
|
||||
|
||||
.. testcode:: python
|
||||
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
|
||||
class MyModel(LightningModule):
|
||||
def __init__(self, input_size, hidden_size, num_layers):
|
||||
super().__init__()
|
||||
# batch_first has to be set to True
|
||||
self.lstm = nn.LSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
...
|
||||
|
||||
# Important: This property activates truncated backpropagation through time
|
||||
# Setting this value to 2 splits the batch into sequences of size 2
|
||||
self.truncated_bptt_steps = 2
|
||||
|
||||
# Truncated back-propagation through time
|
||||
def training_step(self, batch, batch_idx, hiddens):
|
||||
x, y = batch
|
||||
|
||||
# the training step must be updated to accept a ``hiddens`` argument
|
||||
# hiddens are the hiddens from the previous truncated backprop step
|
||||
out, hiddens = self.lstm(x, hiddens)
|
||||
|
||||
...
|
||||
|
||||
return {"loss": ..., "hiddens": hiddens}
|
||||
|
||||
Lightning takes care of splitting your batch along the time-dimension. It is
|
||||
assumed to be the second dimension of your batches. Therefore, in the
|
||||
example above, we have set ``batch_first=True``.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# we use the second as the time dimension
|
||||
# (batch, time, ...)
|
||||
sub_batch = batch[0, 0:t, ...]
|
||||
|
||||
To modify how the batch is split,
|
||||
override the :meth:`pytorch_lightning.core.module.LightningModule.tbptt_split_batch` method:
|
||||
|
||||
.. testcode:: python
|
||||
|
||||
class LitMNIST(LightningModule):
|
||||
def tbptt_split_batch(self, batch, split_size):
|
||||
# do your own splitting on the batch
|
||||
return splits
|
||||
|
||||
--------------
|
||||
|
||||
.. _lightning_hooks:
|
||||
|
@ -1636,12 +1563,6 @@ setup
|
|||
.. automethod:: pytorch_lightning.core.module.LightningModule.setup
|
||||
:noindex:
|
||||
|
||||
tbptt_split_batch
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.core.module.LightningModule.tbptt_split_batch
|
||||
:noindex:
|
||||
|
||||
teardown
|
||||
~~~~~~~~
|
||||
|
||||
|
|
|
@ -1418,7 +1418,7 @@ global_step
|
|||
***********
|
||||
|
||||
The number of optimizer steps taken (does not reset each epoch).
|
||||
This includes multiple optimizers and TBPTT steps (if enabled).
|
||||
This includes multiple optimizers (if enabled).
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
|
|
@ -306,14 +306,11 @@ Here is what the structure would look like in plain Python:
|
|||
# TrainingEpochLoop
|
||||
for batch_idx, batch in enumerate(train_dataloader):
|
||||
|
||||
# TrainingBatchLoop
|
||||
for split_batch in tbptt_split(batch):
|
||||
# OptimizerLoop
|
||||
for optimizer_idx, opt in enumerate(optimizers):
|
||||
|
||||
# OptimizerLoop
|
||||
for optimizer_idx, opt in enumerate(optimizers):
|
||||
|
||||
loss = lightning_module.training_step(batch, batch_idx, optimizer_idx)
|
||||
...
|
||||
loss = lightning_module.training_step(batch, batch_idx, optimizer_idx)
|
||||
...
|
||||
|
||||
# ValidationEpochLoop
|
||||
for batch_idx, batch in enumerate(val_dataloader):
|
||||
|
@ -339,13 +336,8 @@ Each of these :code:`for`-loops represents a class implementing the :class:`~pyt
|
|||
The validation is carried out by yet another loop, :class:`~pytorch_lightning.loops.epoch.validation_epoch_loop.ValidationEpochLoop`.
|
||||
|
||||
In the :code:`run()` method, the training epoch loop could in theory simply call :code:`LightningModule.training_step` directly and perform the optimization.
|
||||
However, Lightning has built-in support for automatic optimization with multiple optimizers and on top of that also supports :ref:`TBPTT <sequential-data>`.
|
||||
However, Lightning has built-in support for automatic optimization with multiple optimizers.
|
||||
For this reason there are actually two more loops nested under :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop`.
|
||||
* - :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop`
|
||||
- The responsibility of the :class:`~pytorch_lightning.loops.batch.training_batch_loop.TrainingBatchLoop` is to split a batch given by the :class:`~pytorch_lightning.loops.epoch.training_epoch_loop.TrainingEpochLoop` along the time-dimension and iterate over the list of splits.
|
||||
It also keeps track of the hidden state *hiddens* returned by the training step.
|
||||
By default, when truncated back-propagation through time (TBPTT) is turned off, this loop does not do anything except redirect the call to the :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop`.
|
||||
Read more about :ref:`TBPTT <sequential-data>`.
|
||||
* - :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop`
|
||||
- The :class:`~pytorch_lightning.loops.optimization.optimizer_loop.OptimizerLoop` iterates over one or multiple optimizers and for each one it calls the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` method with the batch, the current batch index and the optimizer index if multiple optimizers are requested.
|
||||
It is the leaf node in the tree of loops and performs the actual optimization (forward, zero grad, backward, optimizer step).
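As a hedged orientation for readers who access these loops programmatically: after this commit the ``TrainingBatchLoop`` level disappears, so the ``OptimizerLoop`` and ``ManualOptimization`` loops hang directly off the ``TrainingEpochLoop``. The sketch below only restates attribute paths that appear in this diff; the helper function itself is hypothetical.

.. code-block:: python

    # Old path (before this commit): trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop
    # New path (after this commit):  trainer.fit_loop.epoch_loop.optimizer_loop


    def count_optimizer_steps(trainer):
        # hypothetical helper mirroring how `global_step` is computed after the change
        epoch_loop = trainer.fit_loop.epoch_loop
        if trainer.lightning_module.automatic_optimization:
            return epoch_loop.optimizer_loop.optim_progress.optimizer_steps
        return epoch_loop.manual_loop.optim_step_progress.total.completed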
|
||||
|
|
|
@ -343,38 +343,6 @@ When using :class:`~torch.nn.utils.rnn.PackedSequence`, do two things:
|
|||
x = rnn.pack_sequence(batch[0], enforce_sorted=False)
|
||||
y = rnn.pack_sequence(batch[1], enforce_sorted=False)
|
||||
|
||||
|
||||
Truncated Backpropagation Through Time (TBPTT)
|
||||
==============================================
|
||||
|
||||
There are times when multiple backward passes are needed for each batch.
|
||||
For example, it may save memory to use **Truncated Backpropagation Through Time** when training RNNs.
|
||||
|
||||
Lightning can handle TBPTT automatically via this flag.
|
||||
|
||||
.. testcode::
|
||||
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
|
||||
class MyModel(LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Important: This property activates truncated backpropagation through time
|
||||
# Setting this value to 2 splits the batch into sequences of size 2
|
||||
self.truncated_bptt_steps = 2
|
||||
|
||||
# Truncated back-propagation through time
|
||||
def training_step(self, batch, batch_idx, hiddens):
|
||||
# the training step must be updated to accept a ``hiddens`` argument
|
||||
# hiddens are the hiddens from the previous truncated backprop step
|
||||
out, hiddens = self.lstm(data, hiddens)
|
||||
return {"loss": ..., "hiddens": hiddens}
|
||||
|
||||
.. note:: If you need to modify how the batch is split,
|
||||
override :func:`~pytorch_lightning.core.module.LightningModule.tbptt_split_batch`.
|
||||
|
||||
|
||||
Iterable Datasets
|
||||
=================
|
||||
Lightning supports using :class:`~torch.utils.data.IterableDataset` as well as map-style Datasets. IterableDatasets provide a more natural
|
||||
|
|
|
@ -74,7 +74,7 @@ class YieldLoop(OptimizerLoop):
|
|||
return partial(self._training_step, self._generator)
|
||||
|
||||
def _get_generator(self, kwargs, opt_idx=0):
|
||||
kwargs = self._build_kwargs(kwargs, opt_idx, hiddens=None)
|
||||
kwargs = self._build_kwargs(kwargs, opt_idx)
|
||||
|
||||
# Here we are basically calling `lightning_module.training_step()`
|
||||
# and this returns a generator! The `training_step` is handled by
|
||||
|
@ -285,8 +285,8 @@ class GAN(LightningModule):
|
|||
#############################################################################################
|
||||
# Step 3 / 3: Connect the loop to the Trainer #
|
||||
# #
|
||||
# Finally, attach the loop to the `Trainer`. Here, we modified the `AutomaticOptimization` #
|
||||
# loop which is a subloop of the `TrainingBatchLoop`. We use `.connect()` to attach it. #
|
||||
# Finally, attach the loop to the `Trainer`. Here, we modified the `OptimizerLoop` #
|
||||
# loop which is a subloop of the `TrainingEpochLoop`. We use `.connect()` to attach it. #
|
||||
#############################################################################################
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -296,7 +296,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Connect the new loop
|
||||
# YieldLoop now replaces the previous optimizer loop
|
||||
trainer.fit_loop.epoch_loop.batch_loop.connect(optimizer_loop=YieldLoop())
|
||||
trainer.fit_loop.epoch_loop.connect(optimizer_loop=YieldLoop())
|
||||
|
||||
# fit() will now use the new loop!
|
||||
trainer.fit(model, dm)
|
||||
|
|
|
@ -106,7 +106,6 @@ class LightningModuleVisitor(LightningVisitor):
|
|||
"optimizer_zero_grad",
|
||||
"prepare_data",
|
||||
"setup",
|
||||
"tbptt_split_batch",
|
||||
"teardown",
|
||||
"train_dataloader",
|
||||
"val_dataloader",
|
||||
|
@ -256,10 +255,6 @@ class TorchMetricVisitor(LightningVisitor):
|
|||
class_name = "Metric"
|
||||
|
||||
|
||||
class LightningLiteVisitor(LightningVisitor): # deprecated
|
||||
class_name = "LightningLite"
|
||||
|
||||
|
||||
class FabricVisitor(LightningVisitor):
|
||||
class_name = "Fabric"
|
||||
|
||||
|
@ -297,7 +292,6 @@ class Scanner:
|
|||
LightningLoggerVisitor,
|
||||
LightningLoopVisitor,
|
||||
TorchMetricVisitor,
|
||||
LightningLiteVisitor, # deprecated
|
||||
FabricVisitor,
|
||||
LightningProfilerVisitor,
|
||||
]
|
||||
|
|
|
@ -56,6 +56,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
* Removed the `Trainer(auto_select_gpus=...)` argument
|
||||
* Removed the `pytorch_lightning.tuner.auto_gpu_select.{pick_single_gpu,pick_multiple_gpus}` functions
|
||||
|
||||
- Removed special support for truncated backpropagation through time (TBPTT) ([#16172](https://github.com/Lightning-AI/lightning/pull/16172))
|
||||
* Removed the `LightningModule.truncated_bptt_steps` attribute
|
||||
* Removed the `LightningModule.tbptt_split_batch` hook
|
||||
* The `LightningModule.training_step` no longer accepts a `hiddens` argument
|
||||
* Removed the `pytorch_lightning.loops.batch.TrainingBatchLoop`
|
||||
* Removed the `FitLoop.split_idx` property
|
||||
* Removed the `LoggerConnector.on_train_split_start` method
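For users who relied on the removed `truncated_bptt_steps`, here is a minimal migration sketch (a hypothetical module, not part of this commit): perform the time-dimension splitting and the per-split backward pass yourself under manual optimization.

```python
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule


class ManualTBPTT(LightningModule):
    """Hypothetical example: truncated backpropagation through time implemented by hand."""

    def __init__(self, input_size=1, hidden_size=8, split_size=2):
        super().__init__()
        self.automatic_optimization = False  # we call backward/step ourselves
        self.split_size = split_size
        self.lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True)
        self.head = torch.nn.Linear(hidden_size, input_size)

    def training_step(self, batch, batch_idx):
        x, y = batch  # shaped (batch, time, features)
        opt = self.optimizers()
        hiddens = None
        for t in range(0, x.size(1), self.split_size):
            x_t, y_t = x[:, t : t + self.split_size], y[:, t : t + self.split_size]
            out, hiddens = self.lstm(x_t, hiddens)
            loss = F.mse_loss(self.head(out), y_t)
            opt.zero_grad()
            self.manual_backward(loss)
            opt.step()
            # detach so the next split does not backpropagate through this one
            hiddens = tuple(h.detach() for h in hiddens)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)
```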
|
||||
|
||||
|
||||
### Fixed
|
||||
|
||||
|
|
|
@ -279,9 +279,6 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
|
|||
if avg_training_loss is not None:
|
||||
items_dict["loss"] = f"{avg_training_loss:.3g}"
|
||||
|
||||
if pl_module.truncated_bptt_steps > 0:
|
||||
items_dict["split_idx"] = trainer.fit_loop.split_idx
|
||||
|
||||
if trainer.loggers:
|
||||
version = _version(trainer.loggers)
|
||||
if version is not None:
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
"""The LightningModule - an nn.Module with many additional features."""
|
||||
|
||||
import collections.abc
|
||||
import logging
|
||||
import numbers
|
||||
import weakref
|
||||
|
@ -89,7 +88,6 @@ class LightningModule(
|
|||
"logger",
|
||||
"loggers",
|
||||
"automatic_optimization",
|
||||
"truncated_bptt_steps",
|
||||
"trainer",
|
||||
"fabric",
|
||||
]
|
||||
|
@ -115,7 +113,6 @@ class LightningModule(
|
|||
self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None
|
||||
self._current_fx_name: Optional[str] = None
|
||||
self._automatic_optimization: bool = True
|
||||
self._truncated_bptt_steps: int = 0
|
||||
self._param_requires_grad_state: Dict[str, bool] = {}
|
||||
self._metric_attributes: Optional[Dict[int, str]] = None
|
||||
self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
|
||||
|
@ -275,20 +272,6 @@ class LightningModule(
|
|||
def automatic_optimization(self, automatic_optimization: bool) -> None:
|
||||
self._automatic_optimization = automatic_optimization
|
||||
|
||||
@property
|
||||
def truncated_bptt_steps(self) -> int:
|
||||
"""Enables `Truncated Backpropagation Through Time` in the Trainer when set to a positive integer.
|
||||
|
||||
It represents
|
||||
the number of times :meth:`training_step` gets called before backpropagation. If this is > 0, the
|
||||
:meth:`training_step` receives an additional argument ``hiddens`` and is expected to return a hidden state.
|
||||
"""
|
||||
return self._truncated_bptt_steps
|
||||
|
||||
@truncated_bptt_steps.setter
|
||||
def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
|
||||
self._truncated_bptt_steps = truncated_bptt_steps
|
||||
|
||||
@property
|
||||
def logger(self) -> Optional[Union[Logger, FabricLogger]]:
|
||||
"""Reference to the logger object in the Trainer."""
|
||||
|
@ -683,8 +666,6 @@ class LightningModule(
|
|||
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
|
||||
batch_idx (``int``): Integer displaying index of this batch
|
||||
optimizer_idx (``int``): When using multiple optimizers, this argument will also be present.
|
||||
hiddens (``Any``): Passed in if
|
||||
:paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0.
|
||||
|
||||
Return:
|
||||
Any of the following:
|
||||
|
@ -719,19 +700,6 @@ class LightningModule(
|
|||
# do training_step with decoder
|
||||
...
|
||||
|
||||
|
||||
If you add truncated back propagation through time you will also get an additional
|
||||
argument with the hidden states of the previous step.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Truncated back-propagation through time
|
||||
def training_step(self, batch, batch_idx, hiddens):
|
||||
# hiddens are the hidden states from the previous truncated backprop step
|
||||
out, hiddens = self.lstm(data, hiddens)
|
||||
loss = ...
|
||||
return {"loss": loss, "hiddens": hiddens}
|
||||
|
||||
Note:
|
||||
The loss value shown in the progress bar is smoothed (averaged) over the last values,
|
||||
so it differs from the actual loss returned in train/validation step.
|
||||
|
@ -817,9 +785,8 @@ class LightningModule(
|
|||
training_epoch_end(train_outs)
|
||||
|
||||
Args:
|
||||
outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers or when
|
||||
using ``truncated_bptt_steps > 0``, the lists have the dimensions
|
||||
(n_batches, tbptt_steps, n_optimizers). Dimensions of length 1 are squeezed.
|
||||
outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers, the lists
|
||||
have the dimensions (n_batches, n_optimizers). Dimensions of length 1 are squeezed.
|
||||
|
||||
Return:
|
||||
None
|
||||
|
@ -1764,64 +1731,6 @@ class LightningModule(
|
|||
"""
|
||||
optimizer.zero_grad()
|
||||
|
||||
def tbptt_split_batch(self, batch: Any, split_size: int) -> List[Any]:
|
||||
r"""
|
||||
When using truncated backpropagation through time, each batch must be split along the
|
||||
time dimension. Lightning handles this by default, but for custom behavior override
|
||||
this function.
|
||||
|
||||
Args:
|
||||
batch: Current batch
|
||||
split_size: The size of the split
|
||||
|
||||
Return:
|
||||
List of batch splits. Each split will be passed to :meth:`training_step` to enable truncated
|
||||
back propagation through time. The default implementation splits root level Tensors and
|
||||
Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.
|
||||
|
||||
Examples::
|
||||
|
||||
def tbptt_split_batch(self, batch, split_size):
|
||||
splits = []
|
||||
for t in range(0, time_dims[0], split_size):
|
||||
batch_split = []
|
||||
for i, x in enumerate(batch):
|
||||
if isinstance(x, torch.Tensor):
|
||||
split_x = x[:, t:t + split_size]
|
||||
elif isinstance(x, collections.abc.Sequence):
|
||||
split_x = [None] * len(x)
|
||||
for batch_idx in range(len(x)):
|
||||
split_x[batch_idx] = x[batch_idx][t:t + split_size]
|
||||
batch_split.append(split_x)
|
||||
splits.append(batch_split)
|
||||
return splits
|
||||
|
||||
Note:
|
||||
Called in the training loop after
|
||||
:meth:`~pytorch_lightning.callbacks.base.Callback.on_train_batch_start`
|
||||
if :paramref:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` > 0.
|
||||
Each returned batch split is passed separately to :meth:`training_step`.
|
||||
"""
|
||||
time_dims = [len(x[0]) for x in batch if isinstance(x, (Tensor, collections.abc.Sequence))]
|
||||
assert len(time_dims) >= 1, "Unable to determine batch time dimension"
|
||||
assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"
|
||||
|
||||
splits = []
|
||||
for t in range(0, time_dims[0], split_size):
|
||||
batch_split = []
|
||||
for i, x in enumerate(batch):
|
||||
split_x: Union[Tensor, List[Tensor]]
|
||||
if isinstance(x, Tensor):
|
||||
split_x = x[:, t : t + split_size]
|
||||
elif isinstance(x, collections.abc.Sequence):
|
||||
split_x = [x[batch_idx][t : t + split_size] for batch_idx in range(len(x))]
|
||||
|
||||
batch_split.append(split_x)
|
||||
|
||||
splits.append(batch_split)
|
||||
|
||||
return splits
|
||||
|
||||
def freeze(self) -> None:
|
||||
r"""
|
||||
Freeze all params for inference.
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pytorch_lightning.loops.loop import Loop # noqa: F401 isort: skip (avoids circular imports)
|
||||
from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
|
||||
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
|
||||
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
|
||||
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401
|
||||
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401
|
|
@ -1,140 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
from torch import Tensor
|
||||
from typing_extensions import OrderedDict
|
||||
|
||||
from pytorch_lightning.loops.loop import Loop
|
||||
from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
|
||||
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization
|
||||
from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
|
||||
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
|
||||
from pytorch_lightning.loops.utilities import _get_active_optimizers
|
||||
from pytorch_lightning.trainer.supporters import TensorRunningAccum
|
||||
|
||||
_OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
|
||||
|
||||
|
||||
class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]):
|
||||
"""Runs over a single batch of data."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.accumulated_loss = TensorRunningAccum(window_length=20)
|
||||
self.running_loss = TensorRunningAccum(window_length=20)
|
||||
# the current split index when the batch gets split into chunks in truncated backprop through time
|
||||
self.split_idx: int = 0
|
||||
self.optimizer_loop = OptimizerLoop()
|
||||
self.manual_loop = ManualOptimization()
|
||||
|
||||
self._outputs: _OUTPUTS_TYPE = []
|
||||
self._remaining_splits: List[Tuple[int, Any]] = []
|
||||
|
||||
@property
|
||||
def done(self) -> bool:
|
||||
"""Returns if all batch splits have been processed already."""
|
||||
return len(self._remaining_splits) == 0
|
||||
|
||||
def connect( # type: ignore[override]
|
||||
self, optimizer_loop: Optional[OptimizerLoop] = None, manual_loop: Optional[ManualOptimization] = None
|
||||
) -> None:
|
||||
if optimizer_loop is not None:
|
||||
self.optimizer_loop = optimizer_loop
|
||||
if manual_loop is not None:
|
||||
self.manual_loop = manual_loop
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Resets the loop state."""
|
||||
self._outputs = []
|
||||
|
||||
def on_run_start(self, kwargs: OrderedDict) -> None:
|
||||
"""Splits the data into tbptt splits.
|
||||
|
||||
Args:
|
||||
kwargs: the kwargs passed down to the hooks.
|
||||
"""
|
||||
batch = kwargs["batch"]
|
||||
self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))
|
||||
|
||||
def advance(self, kwargs: OrderedDict) -> None:
|
||||
"""Runs the train step together with optimization (if necessary) on the current batch split.
|
||||
|
||||
Args:
|
||||
kwargs: the kwargs passed down to the hooks.
|
||||
"""
|
||||
# replace the batch with the split batch
|
||||
self.split_idx, kwargs["batch"] = self._remaining_splits.pop(0)
|
||||
|
||||
self.trainer._logger_connector.on_train_split_start(self.split_idx)
|
||||
|
||||
outputs: Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] = None # for mypy
|
||||
# choose which loop will run the optimization
|
||||
if self.trainer.lightning_module.automatic_optimization:
|
||||
optimizers = _get_active_optimizers(
|
||||
self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
|
||||
)
|
||||
outputs = self.optimizer_loop.run(optimizers, kwargs)
|
||||
else:
|
||||
outputs = self.manual_loop.run(kwargs)
|
||||
if outputs:
|
||||
# automatic: can be empty if all optimizers skip their batches
|
||||
# manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
|
||||
# then `advance` doesn't finish and an empty dict is returned
|
||||
self._outputs.append(outputs)
|
||||
|
||||
def on_run_end(self) -> _OUTPUTS_TYPE:
|
||||
self.optimizer_loop._hiddens = None
|
||||
# this is not necessary as the manual loop runs for only 1 iteration, but just in case
|
||||
self.manual_loop._hiddens = None
|
||||
output, self._outputs = self._outputs, [] # free memory
|
||||
self._remaining_splits = []
|
||||
return output
|
||||
|
||||
def teardown(self) -> None:
|
||||
self.optimizer_loop.teardown()
|
||||
self.manual_loop.teardown()
|
||||
# release memory
|
||||
if self.accumulated_loss.memory is not None:
|
||||
self.accumulated_loss.memory = self.accumulated_loss.memory.cpu()
|
||||
if self.running_loss.memory is not None:
|
||||
self.running_loss.memory = self.running_loss.memory.cpu()
|
||||
|
||||
def _tbptt_split_batch(self, batch: Any) -> List[Any]:
|
||||
"""Splits a single batch into a list of sequence steps for tbptt.
|
||||
|
||||
Args:
|
||||
batch: the current batch to split
|
||||
"""
|
||||
tbptt_steps = self.trainer.lightning_module.truncated_bptt_steps
|
||||
if tbptt_steps == 0:
|
||||
return [batch]
|
||||
|
||||
splits = self.trainer._call_lightning_module_hook("tbptt_split_batch", batch, tbptt_steps)
|
||||
return splits
|
||||
|
||||
def _update_running_loss(self, current_loss: Tensor) -> None:
|
||||
"""Updates the running loss value with the current value."""
|
||||
if self.trainer.lightning_module.automatic_optimization:
|
||||
# track total loss for logging (avoid mem leaks)
|
||||
self.accumulated_loss.append(current_loss)
|
||||
|
||||
accumulated_loss = self.accumulated_loss.mean()
|
||||
|
||||
if accumulated_loss is not None:
|
||||
# calculate running loss for display
|
||||
self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
|
||||
|
||||
# reset for next set of accumulated grads
|
||||
self.accumulated_loss.reset()
|
|
@ -18,15 +18,17 @@ from typing import Any, DefaultDict, Dict, Generator, List, Optional, overload,
|
|||
import numpy as np
|
||||
import torch
|
||||
from lightning_utilities.core.apply_func import apply_to_collection
|
||||
from torch import Tensor
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import loops # import as loops to avoid circular imports
|
||||
from pytorch_lightning.loops.batch import TrainingBatchLoop
|
||||
from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE
|
||||
from pytorch_lightning.loops.optimization import ManualOptimization, OptimizerLoop
|
||||
from pytorch_lightning.loops.optimization.manual_loop import _OUTPUTS_TYPE as _MANUAL_LOOP_OUTPUTS_TYPE
|
||||
from pytorch_lightning.loops.optimization.optimizer_loop import _OUTPUTS_TYPE as _OPTIMIZER_LOOP_OUTPUTS_TYPE
|
||||
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
|
||||
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
|
||||
from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
|
||||
|
@ -34,6 +36,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden
|
|||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn, WarningCache
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
|
||||
_BATCH_OUTPUTS_TYPE = Optional[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]
|
||||
_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
|
||||
|
||||
|
||||
|
@ -57,7 +60,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
self.batch_progress = BatchProgress()
|
||||
self.scheduler_progress = SchedulerProgress()
|
||||
|
||||
self.batch_loop = TrainingBatchLoop()
|
||||
self.accumulated_loss = TensorRunningAccum(window_length=20)
|
||||
self.running_loss = TensorRunningAccum(window_length=20)
|
||||
self.optimizer_loop = OptimizerLoop()
|
||||
self.manual_loop = ManualOptimization()
|
||||
|
||||
self.val_loop = loops.EvaluationLoop(verbose=False)
|
||||
|
||||
self._results = _ResultCollection(training=True)
|
||||
|
@ -85,8 +92,8 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
def global_step(self) -> int:
|
||||
lightning_module = self.trainer.lightning_module
|
||||
if lightning_module is None or lightning_module.automatic_optimization:
|
||||
return self.batch_loop.optimizer_loop.optim_progress.optimizer_steps
|
||||
return self.batch_loop.manual_loop.optim_step_progress.total.completed
|
||||
return self.optimizer_loop.optim_progress.optimizer_steps
|
||||
return self.manual_loop.optim_step_progress.total.completed
|
||||
|
||||
@property
|
||||
def _is_training_done(self) -> bool:
|
||||
|
@ -119,12 +126,15 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
|
||||
def connect( # type: ignore[override]
|
||||
self,
|
||||
batch_loop: Optional[TrainingBatchLoop] = None,
|
||||
optimizer_loop: Optional[OptimizerLoop] = None,
|
||||
manual_loop: Optional[ManualOptimization] = None,
|
||||
val_loop: Optional["loops.EvaluationLoop"] = None,
|
||||
) -> None:
|
||||
"""Optionally connect a custom batch or validation loop to this training epoch loop."""
|
||||
if batch_loop is not None:
|
||||
self.batch_loop = batch_loop
|
||||
if optimizer_loop is not None:
|
||||
self.optimizer_loop = optimizer_loop
|
||||
if manual_loop is not None:
|
||||
self.manual_loop = manual_loop
|
||||
if val_loop is not None:
|
||||
self.val_loop = val_loop
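A hedged usage sketch of the updated ``connect`` signature above, assuming an existing ``trainer``; ``MyOptimizerLoop`` is a hypothetical subclass:

```python
from pytorch_lightning.loops.optimization import OptimizerLoop


class MyOptimizerLoop(OptimizerLoop):
    """Hypothetical custom optimization loop."""


# After this change, a custom optimizer loop attaches to the epoch loop directly,
# replacing the old `trainer.fit_loop.epoch_loop.batch_loop.connect(...)` path.
trainer.fit_loop.epoch_loop.connect(optimizer_loop=MyOptimizerLoop())
```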
|
||||
|
||||
|
@ -133,7 +143,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
if self.restarting:
|
||||
self.batch_progress.reset_on_restart()
|
||||
self.scheduler_progress.reset_on_restart()
|
||||
self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
|
||||
self.optimizer_loop.optim_progress.reset_on_restart()
|
||||
|
||||
trainer = self.trainer
|
||||
if not trainer.state._fault_tolerant_mode.is_enabled and trainer.num_training_batches != float("inf"):
|
||||
|
@ -148,7 +158,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
else:
|
||||
self.batch_progress.reset_on_run()
|
||||
self.scheduler_progress.reset_on_run()
|
||||
self.batch_loop.optimizer_loop.optim_progress.reset_on_run()
|
||||
self.optimizer_loop.optim_progress.reset_on_run()
|
||||
# when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches
|
||||
# seen per epoch, this is useful for tracking when validation is run multiple times per epoch
|
||||
self.val_loop.epoch_loop.batch_progress.total.reset()
|
||||
|
@ -195,9 +205,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
|
||||
self.trainer._logger_connector.on_batch_start(batch, batch_idx)
|
||||
|
||||
batch_output: _BATCH_OUTPUTS_TYPE = None # for mypy
|
||||
if batch is None:
|
||||
self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
|
||||
batch_output = []
|
||||
else:
|
||||
# hook
|
||||
self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx)
|
||||
|
@ -210,7 +220,14 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
self.batch_progress.increment_started()
|
||||
|
||||
with self.trainer.profiler.profile("run_training_batch"):
|
||||
batch_output = self.batch_loop.run(kwargs)
|
||||
# choose which loop will run the optimization
|
||||
if self.trainer.lightning_module.automatic_optimization:
|
||||
optimizers = _get_active_optimizers(
|
||||
self.trainer.optimizers, self.trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
|
||||
)
|
||||
batch_output = self.optimizer_loop.run(optimizers, kwargs)
|
||||
else:
|
||||
batch_output = self.manual_loop.run(kwargs)
|
||||
|
||||
self.batch_progress.increment_processed()
|
||||
|
||||
|
@ -232,7 +249,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
|
||||
self.batch_progress.increment_completed()
|
||||
|
||||
if is_overridden("training_epoch_end", self.trainer.lightning_module):
|
||||
if batch_output and is_overridden("training_epoch_end", self.trainer.lightning_module):
|
||||
# batch_output may be empty
|
||||
# automatic: can be empty if all optimizers skip their batches
|
||||
# manual: #9052 added support for raising `StopIteration` in the `training_step`. If that happens,
|
||||
# then `advance` doesn't finish and an empty dict is returned
|
||||
self._outputs.append(batch_output)
|
||||
|
||||
# -----------------------------------------
|
||||
|
@ -254,7 +275,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
self.update_lr_schedulers("step", update_plateau_schedulers=True)
|
||||
|
||||
if not self._should_accumulate():
|
||||
# this is increased once per batch disregarding multiple optimizers or tbptt on purpose for loggers
|
||||
# this is increased once per batch disregarding multiple optimizers on purpose for loggers
|
||||
self._batches_that_stepped += 1
|
||||
# this will save based on the `batches_that_stepped` value
|
||||
self._save_loggers_on_train_batch_end()
|
||||
|
@ -271,7 +292,13 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
|
||||
def teardown(self) -> None:
|
||||
self._results.cpu()
|
||||
self.batch_loop.teardown()
|
||||
self.optimizer_loop.teardown()
|
||||
self.manual_loop.teardown()
|
||||
# release memory
|
||||
if self.accumulated_loss.memory is not None:
|
||||
self.accumulated_loss.memory = self.accumulated_loss.memory.cpu()
|
||||
if self.running_loss.memory is not None:
|
||||
self.running_loss.memory = self.running_loss.memory.cpu()
|
||||
self.val_loop.teardown()
|
||||
|
||||
def on_save_checkpoint(self) -> Dict:
|
||||
|
@ -527,6 +554,21 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
kwargs["batch_idx"] = batch_idx
|
||||
return kwargs
|
||||
|
||||
def _update_running_loss(self, current_loss: Tensor) -> None:
|
||||
"""Updates the running loss value with the current value."""
|
||||
if self.trainer.lightning_module.automatic_optimization:
|
||||
# track total loss for logging (avoid mem leaks)
|
||||
self.accumulated_loss.append(current_loss)
|
||||
|
||||
accumulated_loss = self.accumulated_loss.mean()
|
||||
|
||||
if accumulated_loss is not None:
|
||||
# calculate running loss for display
|
||||
self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
|
||||
|
||||
# reset for next set of accumulated grads
|
||||
self.accumulated_loss.reset()
|
||||
|
||||
|
||||
def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]:
|
||||
"""Converts an optimizer dict to a list in which the key of the dict determines the position of the element.
|
||||
|
|
|
@ -77,11 +77,6 @@ class FitLoop(Loop[None]):
|
|||
"""Returns the current batch index (within this epoch)"""
|
||||
return self.epoch_loop.batch_idx
|
||||
|
||||
@property
|
||||
def split_idx(self) -> int:
|
||||
"""Returns the index of the current batch split (within the current batch) for bptt."""
|
||||
return self.epoch_loop.batch_loop.split_idx
|
||||
|
||||
@property
|
||||
def min_steps(self) -> Optional[int]:
|
||||
# TODO(@justusschock): Why aren't we using the attribute in this class?
|
||||
|
@ -112,7 +107,7 @@ class FitLoop(Loop[None]):
|
|||
@property
|
||||
def running_loss(self) -> TensorRunningAccum:
|
||||
"""Returns the running loss."""
|
||||
return self.epoch_loop.batch_loop.running_loss
|
||||
return self.epoch_loop.running_loss
|
||||
|
||||
@Loop.restarting.setter
|
||||
def restarting(self, restarting: bool) -> None:
|
||||
|
@ -131,12 +126,12 @@ class FitLoop(Loop[None]):
|
|||
@property
|
||||
def _skip_backward(self) -> bool:
|
||||
"""Determines whether the loop will skip backward during automatic optimization."""
|
||||
return self.epoch_loop.batch_loop.optimizer_loop._skip_backward
|
||||
return self.epoch_loop.optimizer_loop._skip_backward
|
||||
|
||||
@_skip_backward.setter
|
||||
def _skip_backward(self, value: bool) -> None:
|
||||
"""Determines whether the loop will skip backward during automatic optimization."""
|
||||
self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value
|
||||
self.epoch_loop.optimizer_loop._skip_backward = value
|
||||
|
||||
@property
|
||||
def _results(self) -> _ResultCollection:
|
||||
|
@ -239,7 +234,7 @@ class FitLoop(Loop[None]):
|
|||
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
|
||||
|
||||
# stores accumulated grad fractions per batch
|
||||
self.epoch_loop.batch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)
|
||||
self.epoch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)
|
||||
|
||||
self.epoch_progress.increment_ready()
|
||||
|
||||
|
|
|
@ -145,7 +145,8 @@ class Loop(ABC, Generic[T]):
|
|||
|
||||
# connect sub-loops
|
||||
kwargs = {n: lp for n, lp in old_loop.__dict__.items() if isinstance(lp, Loop)}
|
||||
loop.connect(**kwargs)
|
||||
if kwargs:
|
||||
loop.connect(**kwargs)
|
||||
# set the trainer reference
|
||||
loop.trainer = self.trainer
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from torch import Tensor
|
|||
from pytorch_lightning.core.optimizer import do_nothing_closure
|
||||
from pytorch_lightning.loops import Loop
|
||||
from pytorch_lightning.loops.optimization.closure import OutputResult
|
||||
from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens
|
||||
from pytorch_lightning.loops.utilities import _build_training_step_kwargs
|
||||
from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
@ -42,7 +42,7 @@ class ManualResult(OutputResult):
|
|||
def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ManualResult":
|
||||
extra = {}
|
||||
if isinstance(training_step_output, dict):
|
||||
extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
|
||||
extra = training_step_output.copy()
|
||||
elif isinstance(training_step_output, Tensor):
|
||||
extra = {"loss": training_step_output}
|
||||
elif training_step_output is not None:
|
||||
|
@ -82,7 +82,6 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
|
|||
self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)
|
||||
|
||||
self._done: bool = False
|
||||
self._hiddens: Optional[Any] = None
|
||||
self._output: _OUTPUTS_TYPE = {}
|
||||
|
||||
@property
|
||||
|
@ -104,7 +103,7 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
|
|||
Args:
|
||||
kwargs: The kwargs passed down to the hooks.
|
||||
"""
|
||||
kwargs = self._build_kwargs(kwargs, self._hiddens)
|
||||
kwargs = self._build_kwargs(kwargs)
|
||||
|
||||
# manually capture logged metrics
|
||||
training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
|
||||
|
@ -114,12 +113,11 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
|
|||
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
|
||||
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
|
||||
training_step_output = strategy_output if model_output is None else model_output
|
||||
self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)
|
||||
|
||||
result = self.output_result_cls.from_training_step_output(training_step_output)
|
||||
|
||||
if self.trainer.move_metrics_to_cpu:
|
||||
# hiddens and the training step output are not moved as they are not considered "metrics"
|
||||
# training step output does not get moved because it is not considered a "metric"
|
||||
# the user might need them on the correct device for an operation in `training_epoch_end`
|
||||
assert self.trainer._results is not None
|
||||
self.trainer._results.cpu()
|
||||
|
@ -144,16 +142,13 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
|
|||
self.trainer.profiler.stop("optimizer_step")
|
||||
self.optim_step_progress.increment_completed()
|
||||
|
||||
def _build_kwargs(self, kwargs: OrderedDict, hiddens: Optional[Any]) -> OrderedDict:
|
||||
def _build_kwargs(self, kwargs: OrderedDict) -> OrderedDict:
|
||||
"""Helper method to build the arguments for the current step.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs passed down to the hooks.
|
||||
hiddens: the hidden state of the previous RNN iteration.
|
||||
|
||||
Returns:
|
||||
The kwargs passed down to the hooks.
|
||||
"""
|
||||
return _build_training_step_kwargs(
|
||||
kwargs, self.trainer.lightning_module, self.trainer.optimizers, None, hiddens
|
||||
)
|
||||
return _build_training_step_kwargs(kwargs, self.trainer.lightning_module, self.trainer.optimizers, None)
|
||||
|
|
|
@ -24,11 +24,7 @@ from pytorch_lightning.accelerators import TPUAccelerator
|
|||
from pytorch_lightning.core.optimizer import LightningOptimizer
|
||||
from pytorch_lightning.loops import Loop
|
||||
from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult
|
||||
from pytorch_lightning.loops.utilities import (
|
||||
_block_parallel_sync_behavior,
|
||||
_build_training_step_kwargs,
|
||||
_extract_hiddens,
|
||||
)
|
||||
from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior, _build_training_step_kwargs
|
||||
from pytorch_lightning.trainer.progress import OptimizationProgress
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.rank_zero import WarningCache
|
||||
|
@ -72,7 +68,7 @@ class ClosureResult(OutputResult):
|
|||
raise MisconfigurationException(
|
||||
"In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present"
|
||||
)
|
||||
extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
|
||||
extra = {k: v for k, v in training_step_output.items() if k != "loss"}
|
||||
elif isinstance(training_step_output, Tensor):
|
||||
closure_loss = training_step_output
|
||||
elif training_step_output is not None:
|
||||
|
@ -166,7 +162,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
self._skip_backward: bool = False
|
||||
self._optimizers: Tuple[Optimizer, ...] = tuple()
|
||||
self._indices: Tuple[int, ...] = tuple()
|
||||
self._hiddens: Optional[Any] = None
|
||||
|
||||
@property
|
||||
def optimizer_idx(self) -> int:
|
||||
|
@ -194,7 +189,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
self.optim_progress.optimizer_position = 0
|
||||
|
||||
def advance(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:
|
||||
kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens)
|
||||
kwargs = self._build_kwargs(kwargs, self.optimizer_idx)
|
||||
|
||||
result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
|
||||
if result.loss is not None:
|
||||
|
@ -251,7 +246,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
# if no result, user decided to skip optimization
|
||||
# otherwise update running loss + reset accumulated loss
|
||||
# TODO: find proper way to handle updating running loss
|
||||
self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss)
|
||||
self.trainer.fit_loop.epoch_loop._update_running_loss(result.loss)
|
||||
|
||||
# untoggle model params
|
||||
self._run_optimization_end(opt_idx)
|
||||
|
@ -404,30 +399,25 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
|
||||
training_step_output = strategy_output if model_output is None else model_output
|
||||
|
||||
self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)
|
||||
|
||||
result = self.output_result_cls.from_training_step_output(
|
||||
training_step_output, self.trainer.accumulate_grad_batches
|
||||
)
|
||||
|
||||
if self.trainer.move_metrics_to_cpu:
|
||||
# hiddens and the training step output are not moved as they are not considered "metrics"
|
||||
# training step output does not get moved because it is not considered a "metric"
|
||||
assert self.trainer._results is not None
|
||||
self.trainer._results.cpu()
|
||||
|
||||
return result
|
||||
|
||||
def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int, hiddens: Optional[Any]) -> OrderedDict:
|
||||
def _build_kwargs(self, kwargs: OrderedDict, opt_idx: int) -> OrderedDict:
|
||||
"""Helper method to build the arguments for the current step.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs passed down to the hooks.
|
||||
opt_idx: the index of the current optimizer.
|
||||
hiddens: the hidden state of the previous RNN iteration.
|
||||
|
||||
Returns:
|
||||
The kwargs passed down to the hooks.
|
||||
"""
|
||||
return _build_training_step_kwargs(
|
||||
kwargs, self.trainer.lightning_module, self.trainer.optimizers, opt_idx, hiddens
|
||||
)
|
||||
return _build_training_step_kwargs(kwargs, self.trainer.lightning_module, self.trainer.optimizers, opt_idx)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Generator, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -30,11 +30,8 @@ from pytorch_lightning.strategies.parallel import ParallelStrategy
|
|||
from pytorch_lightning.strategies.strategy import Strategy
|
||||
from pytorch_lightning.trainer.progress import BaseProgress
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.memory import recursive_detach
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
||||
|
||||
|
||||
def check_finite_loss(loss: Optional[Tensor]) -> None:
|
||||
|
@ -47,28 +44,6 @@ def check_finite_loss(loss: Optional[Tensor]) -> None:
|
|||
raise ValueError(f"The loss returned in `training_step` is {loss}.")
|
||||
|
||||
|
||||
def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: int) -> Optional[Any]:
|
||||
"""Get the hidden state if present from the training step output.
|
||||
|
||||
Raises:
|
||||
MisconfigurationException: If :attr:`~pytorch_lightning.core.module.LightningModule.truncated_bptt_steps` is
|
||||
not enabled and hiddens are returned or vice versa.
|
||||
"""
|
||||
if not truncated_bptt_steps:
|
||||
if isinstance(training_step_output, dict) and "hiddens" in training_step_output:
|
||||
raise MisconfigurationException(
|
||||
'You returned "hiddens" in your `training_step` but `truncated_bptt_steps` is disabled'
|
||||
)
|
||||
return None
|
||||
if not isinstance(training_step_output, dict) or "hiddens" not in training_step_output:
|
||||
raise MisconfigurationException(
|
||||
'You enabled `truncated_bptt_steps` but did not `return {..., "hiddens": ...}` in your `training_step`'
|
||||
)
|
||||
# detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
|
||||
hiddens = recursive_detach(training_step_output["hiddens"])
|
||||
return hiddens
|
||||
|
||||
|
||||
def _parse_loop_limits(
|
||||
min_steps: Optional[int],
|
||||
max_steps: int,
|
||||
|
@ -116,7 +91,6 @@ def _build_training_step_kwargs(
|
|||
lightning_module: "pl.LightningModule",
|
||||
optimizers: Sequence[Optimizer],
|
||||
opt_idx: Optional[int],
|
||||
hiddens: Optional[Any],
|
||||
) -> OrderedDict:
|
||||
"""Builds the keyword arguments for training_step.
|
||||
|
||||
|
@ -125,7 +99,6 @@ def _build_training_step_kwargs(
|
|||
lightning_module: the LightningModule with a `training_step` hook implementation
|
||||
optimizers: the list of optimizers from the Trainer
|
||||
opt_idx: the index of the current optimizer
|
||||
hiddens: the hidden state of the previous RNN iteration
|
||||
|
||||
Returns:
|
||||
the keyword arguments for the training step
|
||||
|
@ -147,10 +120,6 @@ def _build_training_step_kwargs(
|
|||
" `training_step` is missing the `optimizer_idx` argument."
|
||||
)
|
||||
|
||||
# pass hiddens if using tbptt
|
||||
if lightning_module.truncated_bptt_steps > 0:
|
||||
kwargs["hiddens"] = hiddens
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
|
|
|
@ -165,9 +165,3 @@ def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule")
|
|||
" not match with the actual batch index when using a `dataloader_iter`"
|
||||
" argument in your `training_step`."
|
||||
)
|
||||
|
||||
if model.truncated_bptt_steps > 0:
|
||||
raise MisconfigurationException(
|
||||
"The model taking a `dataloader_iter` argument in your `training_step` "
|
||||
"is incompatible with `truncated_bptt_steps > 0`."
|
||||
)
|
||||
|
|
|
@ -37,7 +37,6 @@ class LoggerConnector:
|
|||
self._epoch_end_reached = False
|
||||
self._current_fx: Optional[str] = None
|
||||
self._batch_idx: Optional[int] = None
|
||||
self._split_idx: Optional[int] = None
|
||||
|
||||
def on_trainer_init(
|
||||
self,
|
||||
|
@ -144,9 +143,6 @@ class LoggerConnector:
|
|||
Train metric updates
|
||||
"""
|
||||
|
||||
def on_train_split_start(self, split_idx: int) -> None:
|
||||
self._split_idx = split_idx
|
||||
|
||||
def update_train_step_metrics(self) -> None:
|
||||
if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
|
||||
return
|
||||
|
@ -185,7 +181,6 @@ class LoggerConnector:
|
|||
def epoch_end_reached(self) -> None:
|
||||
self._epoch_end_reached = True
|
||||
self._batch_idx = None
|
||||
self._split_idx = None
|
||||
|
||||
def on_epoch_end(self) -> None:
|
||||
assert self._epoch_end_reached
|
||||
|
@ -209,10 +204,7 @@ class LoggerConnector:
|
|||
|
||||
def should_reset_tensors(self, fx: str) -> bool:
|
||||
is_different_fx = self._current_fx != fx
|
||||
if self._split_idx is None:
|
||||
is_first_batch = self._batch_idx in (None, 0)
|
||||
else:
|
||||
is_first_batch = bool(self._batch_idx) + self._split_idx == 0
|
||||
is_first_batch = self._batch_idx in (None, 0)
|
||||
return is_different_fx and is_first_batch
|
||||
|
||||
def reset_metrics(self) -> None:
|
||||
|
@ -226,7 +218,6 @@ class LoggerConnector:
|
|||
results.reset()
|
||||
|
||||
self._batch_idx = None
|
||||
self._split_idx = None
|
||||
self._current_fx = None
|
||||
|
||||
@property
|
||||
|
|
|
@ -1944,7 +1944,7 @@ class Trainer:
|
|||
def global_step(self) -> int:
|
||||
"""The number of optimizer steps taken (does not reset each epoch).
|
||||
|
||||
This includes multiple optimizers and TBPTT steps (if enabled).
|
||||
This includes multiple optimizers (if enabled).
|
||||
"""
|
||||
return self.fit_loop.epoch_loop.global_step
|
||||
|
||||
|
|
|
@ -335,8 +335,8 @@ def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
|
|||
|
||||
def _reset_progress(trainer: "pl.Trainer") -> None:
|
||||
if trainer.lightning_module.automatic_optimization:
|
||||
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.reset()
|
||||
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.reset()
|
||||
else:
|
||||
trainer.fit_loop.epoch_loop.batch_loop.manual_loop.optim_step_progress.reset()
|
||||
trainer.fit_loop.epoch_loop.manual_loop.optim_step_progress.reset()
|
||||
|
||||
trainer.fit_loop.epoch_progress.reset()
|
||||
|
|
|
@ -46,7 +46,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
|
|||
"1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking],
|
||||
"1.6.5": [_migrate_loop_batches_that_stepped],
|
||||
"1.9.0": [_migrate_model_checkpoint_save_on_train_epoch_end_default],
|
||||
"2.0.0": [_drop_apex_amp_state],
|
||||
"2.0.0": [_drop_apex_amp_state, _migrate_loop_structure_after_tbptt_removal],
|
||||
}
|
||||
|
||||
|
||||
|
@ -219,3 +219,40 @@ def _drop_apex_amp_state(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
|
|||
rank_zero_warn("This checkpoint contains apex AMP data, but apex support has been removed in v2.0.0.")
|
||||
del checkpoint[key]
|
||||
return checkpoint
|
||||
|
||||
|
||||
def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
|
||||
"""Adjusts the loop structure since it changed when the support for truncated backpropagation was removed. The
|
||||
optimizer loop and the manual loop were previously children of the training batch loop. After its removal, they
|
||||
became the children of the training epoch loop.
|
||||
|
||||
Version: 2.0.0
|
||||
Commit: TBD
|
||||
PR: #16172
|
||||
"""
|
||||
if "loops" not in checkpoint:
|
||||
return checkpoint
|
||||
|
||||
fit_loop = checkpoint["loops"]["fit_loop"]
|
||||
|
||||
# remap `x.batch_loop.y` to `x.y`
|
||||
old_key_new_key_mapping = {
|
||||
"epoch_loop.batch_loop.manual_loop.optim_step_progress": "epoch_loop.manual_loop.optim_step_progress",
|
||||
"epoch_loop.batch_loop.manual_loop.state_dict": "epoch_loop.manual_loop.state_dict",
|
||||
"epoch_loop.batch_loop.optimizer_loop.optim_progress": "epoch_loop.optimizer_loop.optim_progress",
|
||||
"epoch_loop.batch_loop.optimizer_loop.state_dict": "epoch_loop.optimizer_loop.state_dict",
|
||||
}
|
||||
for old, new in list(old_key_new_key_mapping.items()):
|
||||
if old in fit_loop:
|
||||
fit_loop[new] = fit_loop[old]
|
||||
del fit_loop[old]
|
||||
|
||||
# We can safely drop this key: our default implementation of `batch_loop` did not have state.
|
||||
# If there was state from a custom batch loop, we wouldn't be able to load it meaningfully.
|
||||
# But just in case, we save a copy of it in `epoch_loop.state_dict` in case the user wants to process it after
|
||||
# loading the checkpoint.
|
||||
if "epoch_loop.batch_loop.state_dict" in fit_loop and fit_loop["epoch_loop.batch_loop.state_dict"]:
|
||||
fit_loop["epoch_loop.state_dict"]["old_batch_loop_state_dict"] = fit_loop["epoch_loop.batch_loop.state_dict"]
|
||||
fit_loop.pop("epoch_loop.batch_loop.state_dict", None)
|
||||
|
||||
return checkpoint
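A hedged usage sketch of the migration above; the progress values are placeholders, since only the keys are remapped:

```python
old_fit_loop = {
    "epoch_loop.batch_loop.optimizer_loop.optim_progress": {"optimizer_position": 0},
    "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
    "epoch_loop.batch_loop.manual_loop.optim_step_progress": {},
    "epoch_loop.batch_loop.manual_loop.state_dict": {},
    "epoch_loop.batch_loop.state_dict": {},
    "epoch_loop.state_dict": {},
}
checkpoint = {"loops": {"fit_loop": old_fit_loop}}

migrated = _migrate_loop_structure_after_tbptt_removal(checkpoint)["loops"]["fit_loop"]
assert "epoch_loop.optimizer_loop.optim_progress" in migrated
assert "epoch_loop.batch_loop.optimizer_loop.optim_progress" not in migrated
```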
|
||||
|
|
|
@ -148,5 +148,6 @@ def _set_legacy_version(checkpoint: _CHECKPOINT, version: str) -> None:
|
|||
|
||||
def _should_upgrade(checkpoint: _CHECKPOINT, target: str, max_version: Optional[str] = None) -> bool:
|
||||
"""Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target."""
|
||||
is_lte_max_version = max_version is None or Version(target) <= Version(max_version)
|
||||
return Version(_get_version(checkpoint)) < Version(target) and is_lte_max_version
|
||||
target_version = Version(target)
|
||||
is_lte_max_version = max_version is None or target_version <= Version(max_version)
|
||||
return is_lte_max_version and Version(_get_version(checkpoint)) < target_version
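A hedged example of the reordered check above, assuming `_get_version` reads the `pytorch-lightning_version` key that Lightning writes into its checkpoints:

```python
ckpt = {"pytorch-lightning_version": "1.9.0"}

assert _should_upgrade(ckpt, target="2.0.0")  # checkpoint is older than the target
assert not _should_upgrade(ckpt, target="1.6.0")  # checkpoint is already at/above the target
assert not _should_upgrade(ckpt, target="2.0.0", max_version="1.9.0")  # target exceeds the cap
```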
|
||||
|
|
|
@ -659,10 +659,8 @@ def test_get_progress_bar_metrics(tmpdir: str):
|
|||
)
|
||||
model = BoringModel()
|
||||
trainer.fit(model)
|
||||
model.truncated_bptt_steps = 2
|
||||
standard_metrics = progress_bar.get_metrics(trainer, model)
|
||||
assert "loss" in standard_metrics.keys()
|
||||
assert "split_idx" in standard_metrics.keys()
|
||||
assert "v_num" not in standard_metrics.keys()
|
||||
|
||||
|
||||
|
|
|
@ -1,205 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
from pytorch_lightning import LightningModule, Trainer
|
||||
|
||||
|
||||
class LSTMModel(LightningModule):
|
||||
"""LSTM sequence-to-sequence model for testing TBPTT with automatic optimization."""
|
||||
|
||||
def __init__(self, truncated_bptt_steps=2, input_size=1, hidden_size=8):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
|
||||
self.truncated_bptt_steps = truncated_bptt_steps
|
||||
self.automatic_optimization = True
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.SGD(self.parameters(), lr=0.01)
|
||||
|
||||
def training_step(self, batch, batch_idx, hiddens):
|
||||
x, y = batch
|
||||
pred, hiddens = self.lstm(x, hiddens)
|
||||
loss = F.mse_loss(pred, y)
|
||||
return {"loss": loss, "hiddens": hiddens}
|
||||
|
||||
def train_dataloader(self):
|
||||
dataset = TensorDataset(torch.rand(16, 8, self.input_size), torch.rand(16, 8, self.input_size))
|
||||
return DataLoader(dataset=dataset, batch_size=4)
|
||||
|
||||
|
||||
class ManualLSTMModel(LSTMModel):
|
||||
"""LSTM sequence-to-sequence model for testing TBPTT with manual optimization."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def training_step(self, batch, batch_idx, hiddens):
|
||||
out = super().training_step(batch, batch_idx, hiddens)
|
||||
loss, hiddens = out["loss"], out["hiddens"]
|
||||
opt = self.optimizers()
|
||||
opt.zero_grad()
|
||||
self.manual_backward(loss)
|
||||
opt.step()
|
||||
return {"loss": loss, "hiddens": hiddens}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel))
|
||||
def test_persistent_hidden_state_transfer(tmpdir, model_class):
|
||||
"""Test that the hidden state reference gets passed through from one training_step to the next and remains
|
||||
unmodified, apart from its grad_fn being detached."""
|
||||
|
||||
class TBPTTModel(model_class):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.test_hidden = None
|
||||
|
||||
def training_step(self, batch, batch_idx, hiddens):
|
||||
split_idx = self.trainer.fit_loop.split_idx
|
||||
# the hidden state may only be None for the first split_idx
|
||||
assert not ((split_idx == 0) ^ (hiddens is None))
|
||||
# test_hiddens is None when hiddens is None
|
||||
assert not ((hiddens is None) ^ (self.test_hidden is None))
|
||||
# the states are equal (persistent)
|
||||
assert hiddens is None or all(torch.equal(h, th) for h, th in zip(hiddens, self.test_hidden))
|
||||
# the incoming hidden state never has a grad_fn (gets automatically detached)
|
||||
assert hiddens is None or all(h.grad_fn is None for h in hiddens)
|
||||
out = super().training_step(batch, batch_idx, hiddens)
|
||||
|
||||
# store hiddens, assert persistence in next training_step
|
||||
self.test_hidden = out["hiddens"]
|
||||
|
||||
# hiddens may have grad_fn when returning, gets automatically detached
|
||||
assert all(h.grad_fn is not None for h in self.test_hidden)
|
||||
return out
|
||||
|
||||
def on_train_batch_start(self, *_, **__) -> None:
|
||||
self.test_hidden = None
|
||||
|
||||
model = TBPTTModel(truncated_bptt_steps=2, input_size=1, hidden_size=8)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
enable_model_summary=False,
|
||||
logger=False,
|
||||
enable_checkpointing=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel))
|
||||
def test_tbptt_split_shapes(tmpdir, model_class):
|
||||
"""Test that the sequence data gets split correctly and that the outputs are correctly passed from hook to
|
||||
hook."""
|
||||
batch_size = 10
|
||||
truncated_bptt_steps = 2
|
||||
n, t, f = 32, 15, 1 # (num samples, sequence size, input size)
|
||||
assert t % truncated_bptt_steps != 0, "test must run with sequence length not divisible by tbptt steps"
|
||||
|
||||
seq2seq_dataset = TensorDataset(torch.rand(n, t, f), torch.rand(n, t, f))
|
||||
train_dataloader = DataLoader(dataset=seq2seq_dataset, batch_size=batch_size)
|
||||
|
||||
class TBPTTModel(model_class):
|
||||
def training_step(self, batch, batch_idx, hiddens):
|
||||
x, y = batch
|
||||
if self.trainer.fit_loop.epoch_loop.batch_loop.done:
|
||||
# last split idx, not aligned
|
||||
assert x.shape[1] == t % truncated_bptt_steps
|
||||
assert y.shape[1] == t % truncated_bptt_steps
|
||||
else:
|
||||
assert x.shape[1] == truncated_bptt_steps
|
||||
assert y.shape[1] == truncated_bptt_steps
|
||||
return super().training_step(batch, batch_idx, hiddens)
|
||||
|
||||
def training_epoch_end(self, training_step_outputs):
|
||||
training_step_outputs = training_step_outputs[0]
|
||||
assert len(training_step_outputs) == math.ceil(t / self.truncated_bptt_steps)
|
||||
assert all(out["loss"].grad_fn is None for out in training_step_outputs)
|
||||
assert all("hiddens" not in out for out in training_step_outputs)
|
||||
|
||||
model = TBPTTModel(truncated_bptt_steps=truncated_bptt_steps, input_size=f, hidden_size=8)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
enable_model_summary=False,
|
||||
logger=False,
|
||||
enable_checkpointing=False,
|
||||
)
|
||||
trainer.fit(model, train_dataloaders=train_dataloader)
|
||||
|
||||
assert trainer.fit_loop.batch_idx == n // batch_size
|
||||
assert trainer.fit_loop.split_idx == t // truncated_bptt_steps
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_class", (LSTMModel, ManualLSTMModel))
|
||||
def test_tbptt_logging(tmpdir, model_class):
|
||||
"""Test step-level and epoch-level logging works with TBPTT."""
|
||||
|
||||
class TBPTTModel(model_class):
|
||||
def training_step(self, *args, **kwargs):
|
||||
out = super().training_step(*args, **kwargs)
|
||||
self.log("loss", out["loss"], on_step=True, on_epoch=True)
|
||||
return out
|
||||
|
||||
model = TBPTTModel(truncated_bptt_steps=2)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
log_every_n_steps=2,
|
||||
enable_model_summary=False,
|
||||
enable_checkpointing=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
assert set(trainer.logged_metrics) == {"loss_step", "loss_epoch"}
|
||||
|
||||
|
||||
def test_hiddens_multiple_optimizers(tmpdir):
|
||||
class TBPTTModel(LSTMModel):
|
||||
# TODO: `optimizer_idx=n` gets the hiddens from `optimizer_idx=n-1` instead of the hiddens from
|
||||
# `optimizer_idx=n`, `split_idx=m-1`. This is unexpected and should be changed
|
||||
test_hiddens = None
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx, hiddens):
|
||||
if hiddens is None:
|
||||
assert self.test_hiddens is None
|
||||
else:
|
||||
assert all(torch.equal(h, th) for h, th in zip(hiddens, self.test_hiddens))
|
||||
out = super().training_step(batch, batch_idx, hiddens)
|
||||
self.test_hiddens = out["hiddens"]
|
||||
return out
|
||||
|
||||
def configure_optimizers(self):
|
||||
return [super().configure_optimizers(), super().configure_optimizers()]
|
||||
|
||||
model = TBPTTModel(truncated_bptt_steps=2, input_size=1, hidden_size=1)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_train_batches=1,
|
||||
limit_val_batches=0,
|
||||
enable_model_summary=False,
|
||||
logger=False,
|
||||
enable_checkpointing=False,
|
||||
enable_progress_bar=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
assert trainer.global_step == 8 / 2 * 2 # time_dim_length / tbptt_steps * num_optimizers
|
|
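
With these tests (and the built-in splitting they covered) removed, a model that still needs truncated backpropagation has to chunk the sequence itself inside ``training_step``. A minimal sketch using manual optimization; the class name, chunk size, model, and shapes are illustrative, not a drop-in replacement for the removed feature:

    import torch
    import torch.nn.functional as F

    from pytorch_lightning import LightningModule


    class ManualTBPTTModel(LightningModule):
        def __init__(self, input_size=1, hidden_size=8, chunk_size=2):
            super().__init__()
            self.lstm = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
            self.head = torch.nn.Linear(hidden_size, input_size)
            self.chunk_size = chunk_size
            self.automatic_optimization = False

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

        def training_step(self, batch, batch_idx):
            x, y = batch  # shape: (batch, time, features)
            opt = self.optimizers()
            hiddens = None
            for start in range(0, x.size(1), self.chunk_size):
                x_chunk = x[:, start : start + self.chunk_size]
                y_chunk = y[:, start : start + self.chunk_size]
                out, hiddens = self.lstm(x_chunk, hiddens)
                loss = F.mse_loss(self.head(out), y_chunk)
                opt.zero_grad()
                self.manual_backward(loss)
                opt.step()
                # truncate the graph between chunks, as the removed loop used to do
                hiddens = tuple(h.detach() for h in hiddens)
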
@ -32,73 +32,56 @@ _out13 = {"loss": 1.3}
|
|||
|
||||
|
||||
class TestPrepareOutputs:
|
||||
def prepare_outputs(self, fn, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization):
|
||||
def prepare_outputs(self, fn, batch_outputs, num_optimizers, automatic_optimization):
|
||||
lightning_module = LightningModule()
|
||||
lightning_module.automatic_optimization = automatic_optimization
|
||||
lightning_module.truncated_bptt_steps = tbptt_splits
|
||||
return fn(
|
||||
batch_outputs,
|
||||
lightning_module=lightning_module,
|
||||
num_optimizers=num_optimizers, # does not matter for manual optimization
|
||||
)
|
||||
|
||||
def prepare_outputs_training_epoch_end(
|
||||
self, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=True
|
||||
):
|
||||
def prepare_outputs_training_epoch_end(self, batch_outputs, num_optimizers, automatic_optimization=True):
|
||||
return self.prepare_outputs(
|
||||
TrainingEpochLoop._prepare_outputs_training_epoch_end,
|
||||
tbptt_splits,
|
||||
batch_outputs,
|
||||
num_optimizers,
|
||||
automatic_optimization=automatic_optimization,
|
||||
)
|
||||
|
||||
def prepare_outputs_training_batch_end(
|
||||
self, tbptt_splits, batch_outputs, num_optimizers, automatic_optimization=True
|
||||
):
|
||||
def prepare_outputs_training_batch_end(self, batch_outputs, num_optimizers, automatic_optimization=True):
|
||||
return self.prepare_outputs(
|
||||
TrainingEpochLoop._prepare_outputs_training_batch_end,
|
||||
tbptt_splits,
|
||||
batch_outputs,
|
||||
num_optimizers,
|
||||
automatic_optimization=automatic_optimization,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_optimizers,tbptt_splits,batch_outputs,expected",
|
||||
"num_optimizers,batch_outputs,expected",
|
||||
[
|
||||
(1, 0, [], []),
|
||||
(1, 0, [[]], []),
|
||||
(1, [], []),
|
||||
(1, [[]], []),
|
||||
# 1 batch
|
||||
(1, 0, [[{0: _out00}]], [_out00]),
|
||||
(1, [[{0: _out00}]], [_out00]),
|
||||
# 2 batches
|
||||
(1, 0, [[{0: _out00}], [{0: _out01}]], [_out00, _out01]),
|
||||
(1, [[{0: _out00}], [{0: _out01}]], [_out00, _out01]),
|
||||
# 1 batch, 2 optimizers
|
||||
(2, 0, [[{0: _out00, 1: _out01}]], [_out00, _out01]),
|
||||
(2, [[{0: _out00, 1: _out01}]], [_out00, _out01]),
|
||||
# 2 batches, 2 optimizers
|
||||
(2, 0, [[{0: _out00, 1: _out01}], [{0: _out10, 1: _out11}]], [[_out00, _out01], [_out10, _out11]]),
|
||||
(2, [[{0: _out00, 1: _out01}], [{0: _out10, 1: _out11}]], [[_out00, _out01], [_out10, _out11]]),
|
||||
# 4 batches, 2 optimizers, different frequency
|
||||
(
|
||||
2,
|
||||
0,
|
||||
[[{0: _out00}], [{1: _out10}], [{1: _out11}], [{0: _out01}]],
|
||||
[[_out00], [_out10], [_out11], [_out01]],
|
||||
),
|
||||
# 1 batch, tbptt with 2 splits (uneven)
|
||||
(1, 2, [[{0: _out00}, {0: _out01}], [{0: _out03}]], [[_out00, _out01], [_out03]]),
|
||||
# 3 batches, tbptt with 2 splits, 2 optimizers alternating
|
||||
(
|
||||
2,
|
||||
2,
|
||||
[[{0: _out00}, {0: _out01}], [{1: _out10}, {1: _out11}], [{0: _out02}, {0: _out03}]],
|
||||
[[[_out00], [_out01]], [[_out10], [_out11]], [[_out02], [_out03]]],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_prepare_outputs_training_epoch_end_automatic(self, num_optimizers, tbptt_splits, batch_outputs, expected):
|
||||
def test_prepare_outputs_training_epoch_end_automatic(self, num_optimizers, batch_outputs, expected):
|
||||
"""Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
|
||||
currently expects in the case of automatic optimization."""
|
||||
assert self.prepare_outputs_training_epoch_end(tbptt_splits, batch_outputs, num_optimizers) == expected
|
||||
assert self.prepare_outputs_training_epoch_end(batch_outputs, num_optimizers) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_outputs,expected",
|
||||
|
@ -111,37 +94,29 @@ class TestPrepareOutputs:
|
|||
([[_out00], [_out01]], [_out00, _out01]),
|
||||
# skipped outputs
|
||||
([[_out00], [], [], [_out03]], [_out00, _out03]),
|
||||
# tbptt with 2 splits, uneven, skipped output
|
||||
([[_out00, _out01], [_out02, _out03], [], [_out10]], [[_out00, _out01], [_out02, _out03], [_out10]]),
|
||||
],
|
||||
)
|
||||
def test_prepare_outputs_training_epoch_end_manual(self, batch_outputs, expected):
|
||||
"""Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
|
||||
currently expects in the case of manual optimization."""
|
||||
assert self.prepare_outputs_training_epoch_end(0, batch_outputs, -1, automatic_optimization=False) == expected
|
||||
assert self.prepare_outputs_training_epoch_end(batch_outputs, -1, automatic_optimization=False) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_optimizers,tbptt_splits,batch_end_outputs,expected",
|
||||
"num_optimizers,batch_end_outputs,expected",
|
||||
[
|
||||
(1, 0, [], []),
|
||||
(1, 0, [[]], []),
|
||||
(1, [], []),
|
||||
(1, [[]], []),
|
||||
# 1 optimizer
|
||||
(1, 0, [{0: _out00}], _out00),
|
||||
(1, [{0: _out00}], _out00),
|
||||
# 2 optimizers
|
||||
(2, 0, [{0: _out00, 1: _out01}], [_out00, _out01]),
|
||||
# tbptt with 2 splits
|
||||
(1, 2, [{0: _out00}, {0: _out01}], [_out00, _out01]),
|
||||
# 2 optimizers, tbptt with 2 splits
|
||||
(2, 2, [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}], [[_out00, _out01], [_out10, _out11]]),
|
||||
(2, [{0: _out00, 1: _out01}], [_out00, _out01]),
|
||||
],
|
||||
)
|
||||
def test_prepare_outputs_training_batch_end_automatic(
|
||||
self, num_optimizers, tbptt_splits, batch_end_outputs, expected
|
||||
):
|
||||
def test_prepare_outputs_training_batch_end_automatic(self, num_optimizers, batch_end_outputs, expected):
|
||||
"""Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
|
||||
currently expects in the case of automatic optimization."""
|
||||
|
||||
assert self.prepare_outputs_training_batch_end(tbptt_splits, batch_end_outputs, num_optimizers) == expected
|
||||
assert self.prepare_outputs_training_batch_end(batch_end_outputs, num_optimizers) == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_end_outputs,expected",
|
||||
|
@ -150,16 +125,12 @@ class TestPrepareOutputs:
|
|||
([[]], []),
|
||||
# skipped outputs
|
||||
([_out00, None, _out02], [_out00, _out02]),
|
||||
# tbptt with 3 splits, skipped output
|
||||
([_out00, _out01, None, _out03], [_out00, _out01, _out03]),
|
||||
],
|
||||
)
|
||||
def test_prepare_outputs_training_batch_end_manual(self, batch_end_outputs, expected):
|
||||
"""Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
|
||||
currently expects in the case of manual optimization."""
|
||||
assert (
|
||||
self.prepare_outputs_training_batch_end(0, batch_end_outputs, -1, automatic_optimization=False) == expected
|
||||
)
|
||||
assert self.prepare_outputs_training_batch_end(batch_end_outputs, -1, automatic_optimization=False) == expected
|
||||
|
||||
|
||||
def test_no_val_on_train_epoch_loop_restart(tmpdir):
|
||||
|
@ -208,7 +179,7 @@ def test_should_stop_early_stopping_conditions_not_met(
|
|||
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
|
||||
trainer.num_training_batches = 10
|
||||
trainer.should_stop = True
|
||||
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
|
||||
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
|
||||
trainer.fit_loop.epoch_loop.batch_progress.current.ready = global_step
|
||||
trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1
|
||||
|
||||
|
|
|
@ -65,15 +65,15 @@ def test__eval_step__flow(tmpdir):
|
|||
# simulate training manually
|
||||
trainer.state.stage = RunningStage.TRAINING
|
||||
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
|
||||
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
|
||||
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs)
|
||||
|
||||
assert len(train_step_out) == 1
|
||||
train_step_out = train_step_out[0][0]
|
||||
train_step_out = train_step_out[0]
|
||||
assert isinstance(train_step_out["loss"], Tensor)
|
||||
assert train_step_out["loss"].item() == 171
|
||||
|
||||
# make sure the optimizer closure returns the correct things
|
||||
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
|
||||
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
|
||||
opt_closure_result = opt_closure()
|
||||
assert opt_closure_result.item() == 171
|
||||
|
||||
|
@ -126,15 +126,15 @@ def test__eval_step__eval_step_end__flow(tmpdir):
|
|||
trainer.state.stage = RunningStage.TRAINING
|
||||
# make sure training outputs what is expected
|
||||
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
|
||||
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
|
||||
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs)
|
||||
|
||||
assert len(train_step_out) == 1
|
||||
train_step_out = train_step_out[0][0]
|
||||
train_step_out = train_step_out[0]
|
||||
assert isinstance(train_step_out["loss"], Tensor)
|
||||
assert train_step_out["loss"].item() == 171
|
||||
|
||||
# make sure the optimizer closure returns the correct things
|
||||
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
|
||||
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
|
||||
opt_closure_result = opt_closure()
|
||||
assert opt_closure_result.item() == 171
|
||||
|
||||
|
|
|
@ -52,14 +52,13 @@ def test_loops_state_dict_structure():
|
|||
"total": {"ready": 0, "completed": 0},
|
||||
"current": {"ready": 0, "completed": 0},
|
||||
},
|
||||
"epoch_loop.batch_loop.state_dict": {},
|
||||
"epoch_loop.batch_loop.manual_loop.state_dict": {},
|
||||
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
|
||||
"epoch_loop.manual_loop.state_dict": {},
|
||||
"epoch_loop.manual_loop.optim_step_progress": {
|
||||
"total": {"ready": 0, "completed": 0},
|
||||
"current": {"ready": 0, "completed": 0},
|
||||
},
|
||||
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
|
||||
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
|
||||
"epoch_loop.optimizer_loop.state_dict": {},
|
||||
"epoch_loop.optimizer_loop.optim_progress": {
|
||||
"optimizer": {
|
||||
"step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
|
||||
"zero_grad": {
|
||||
|
|
|
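
With the ``batch_loop`` level gone from the loop state dict above, code that introspects the trainer drops that level as well; for example, the counter that backs ``trainer.global_step`` (a sketch assuming a ``trainer`` that has already run ``fit``):

    optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
    completed_optimizer_steps = optim_progress.optimizer.step.total.completed
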
@ -25,7 +25,7 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoad
|
|||
from pytorch_lightning import LightningModule, Trainer
|
||||
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
|
||||
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
|
||||
from pytorch_lightning.loops import EvaluationLoop, Loop, TrainingBatchLoop, TrainingEpochLoop
|
||||
from pytorch_lightning.loops import EvaluationLoop, Loop, OptimizerLoop, TrainingEpochLoop
|
||||
from pytorch_lightning.trainer.progress import BaseProgress
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
|
@ -109,15 +109,15 @@ def test_connect_subloops(tmpdir):
|
|||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
|
||||
|
||||
epoch_loop = trainer.fit_loop.epoch_loop
|
||||
new_batch_loop = TrainingBatchLoop()
|
||||
epoch_loop.connect(batch_loop=new_batch_loop)
|
||||
assert epoch_loop.batch_loop is new_batch_loop
|
||||
new_optimizer_loop = OptimizerLoop()
|
||||
epoch_loop.connect(optimizer_loop=new_optimizer_loop)
|
||||
assert epoch_loop.optimizer_loop is new_optimizer_loop
|
||||
|
||||
with pytest.raises(RuntimeError, match="The loop is not attached to a Trainer"):
|
||||
_ = new_batch_loop.trainer
|
||||
_ = new_optimizer_loop.trainer
|
||||
|
||||
trainer.fit(model)
|
||||
assert new_batch_loop.trainer is trainer
|
||||
assert new_optimizer_loop.trainer is trainer
|
||||
|
||||
|
||||
def test_replace_loops():
|
||||
|
@ -144,22 +144,22 @@ def test_replace_loops():
|
|||
assert trainer.fit_loop.epoch_loop is new_loop
|
||||
assert new_loop.min_steps == 123
|
||||
assert new_loop.max_steps == 321
|
||||
assert new_loop.batch_loop is old_loop.batch_loop
|
||||
assert new_loop.optimizer_loop is old_loop.optimizer_loop
|
||||
assert new_loop.val_loop is old_loop.val_loop
|
||||
assert new_loop.trainer is trainer
|
||||
|
||||
class MyBatchLoop(TrainingBatchLoop):
|
||||
class MyOptimizerLoop(OptimizerLoop):
|
||||
...
|
||||
|
||||
class MyEvalLoop(EvaluationLoop):
|
||||
...
|
||||
|
||||
# test passing more than one where one is an instance and the other a class
|
||||
trainer.fit_loop.epoch_loop.replace(batch_loop=MyBatchLoop, val_loop=MyEvalLoop())
|
||||
new_batch_loop = trainer.fit_loop.epoch_loop.batch_loop
|
||||
trainer.fit_loop.epoch_loop.replace(optimizer_loop=MyOptimizerLoop, val_loop=MyEvalLoop())
|
||||
new_optimizer_loop = trainer.fit_loop.epoch_loop.optimizer_loop
|
||||
new_val_loop = trainer.fit_loop.epoch_loop.val_loop
|
||||
|
||||
assert isinstance(new_batch_loop, MyBatchLoop)
|
||||
assert isinstance(new_optimizer_loop, MyOptimizerLoop)
|
||||
assert isinstance(new_val_loop, MyEvalLoop)
|
||||
|
||||
|
||||
|
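
Outside these tests, the same two entry points can be used to install a custom optimizer loop; a minimal sketch (the ``MyOptimizerLoop`` subclass is illustrative):

    from pytorch_lightning import Trainer
    from pytorch_lightning.loops import OptimizerLoop


    class MyOptimizerLoop(OptimizerLoop):
        ...


    trainer = Trainer(fast_dev_run=True)
    # connect a ready-made instance ...
    trainer.fit_loop.epoch_loop.connect(optimizer_loop=MyOptimizerLoop())
    # ... or let `replace` instantiate the class and wire it up
    trainer.fit_loop.epoch_loop.replace(optimizer_loop=MyOptimizerLoop)
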
@ -436,7 +436,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
|
|||
assert os.path.exists(ckpt_path)
|
||||
checkpoint = torch.load(ckpt_path)
|
||||
|
||||
optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress
|
||||
optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
|
||||
sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress
|
||||
|
||||
# `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
|
||||
|
@ -510,14 +510,13 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
|
|||
"total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
|
||||
"current": {"ready": be_sch_steps, "completed": be_sch_steps},
|
||||
},
|
||||
"epoch_loop.batch_loop.state_dict": ANY,
|
||||
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
|
||||
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
|
||||
"epoch_loop.manual_loop.state_dict": ANY,
|
||||
"epoch_loop.manual_loop.optim_step_progress": {
|
||||
"total": {"ready": 0, "completed": 0},
|
||||
"current": {"ready": 0, "completed": 0},
|
||||
},
|
||||
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
|
||||
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
|
||||
"epoch_loop.optimizer_loop.state_dict": {},
|
||||
"epoch_loop.optimizer_loop.optim_progress": {
|
||||
"optimizer_position": stop_optimizer,
|
||||
"optimizer": {
|
||||
"step": {
|
||||
|
@ -563,8 +562,8 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
|
|||
# test resetting manually, we expect all `ready` counters to be reset to `completed`
|
||||
trainer.fit_loop.reset()
|
||||
trainer.fit_loop.epoch_loop.reset()
|
||||
trainer.fit_loop.epoch_loop.batch_loop.reset()
|
||||
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.reset()
|
||||
trainer.fit_loop.epoch_loop.optimizer_loop.reset()
|
||||
trainer.fit_loop.epoch_loop.manual_loop.reset()
|
||||
|
||||
epoch_progress = trainer.fit_loop.epoch_progress
|
||||
assert epoch_progress.current.ready == stop_epoch
|
||||
|
@ -574,7 +573,7 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
|
|||
assert batch_progress.current.ready == be_batches_completed
|
||||
assert batch_progress.current.completed == be_batches_completed
|
||||
|
||||
optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress
|
||||
optim_progress = trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress
|
||||
assert optim_progress.optimizer.step.current.ready == be_total_opt_steps
|
||||
assert optim_progress.optimizer.step.current.completed == be_total_opt_steps
|
||||
assert optim_progress.optimizer.zero_grad.current.ready == be_total_zero_grad
|
||||
|
@ -677,14 +676,13 @@ def test_loop_state_on_complete_run(n_optimizers, tmpdir):
|
|||
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
|
||||
"current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
|
||||
},
|
||||
"epoch_loop.batch_loop.state_dict": ANY,
|
||||
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
|
||||
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
|
||||
"epoch_loop.manual_loop.state_dict": ANY,
|
||||
"epoch_loop.manual_loop.optim_step_progress": {
|
||||
"total": {"ready": 0, "completed": 0},
|
||||
"current": {"ready": 0, "completed": 0},
|
||||
},
|
||||
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
|
||||
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
|
||||
"epoch_loop.optimizer_loop.state_dict": {},
|
||||
"epoch_loop.optimizer_loop.optim_progress": {
|
||||
"optimizer_position": n_optimizers,
|
||||
"optimizer": {
|
||||
"step": {
|
||||
|
@ -746,7 +744,7 @@ def test_fit_loop_reset(tmpdir):
|
|||
mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=2.ckpt"))
|
||||
fit_loop = trainer.fit_loop
|
||||
epoch_loop = fit_loop.epoch_loop
|
||||
optimizer_loop = epoch_loop.batch_loop.optimizer_loop
|
||||
optimizer_loop = epoch_loop.optimizer_loop
|
||||
assert not fit_loop.restarting
|
||||
assert not epoch_loop.restarting
|
||||
assert not optimizer_loop.restarting
|
||||
|
|
|
@ -217,7 +217,7 @@ def test_should_stop_early_stopping_conditions_met(
|
|||
trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100)
|
||||
trainer.num_training_batches = 10
|
||||
trainer.should_stop = True
|
||||
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
|
||||
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
|
||||
current_epoch * trainer.num_training_batches
|
||||
)
|
||||
trainer.fit_loop.epoch_loop.batch_progress.current.ready = 10
|
||||
|
|
|
@ -147,15 +147,15 @@ def test__training_step__epoch_end__flow_scalar(tmpdir):
|
|||
trainer.state.stage = RunningStage.TRAINING
|
||||
# make sure training outputs what is expected
|
||||
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
|
||||
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
|
||||
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs)
|
||||
|
||||
assert len(train_step_out) == 1
|
||||
train_step_out = train_step_out[0][0]
|
||||
train_step_out = train_step_out[0]
|
||||
assert isinstance(train_step_out["loss"], Tensor)
|
||||
assert train_step_out["loss"].item() == 171
|
||||
|
||||
# make sure the optimizer closure returns the correct things
|
||||
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
|
||||
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
|
||||
opt_closure_result = opt_closure()
|
||||
assert opt_closure_result.item() == 171
|
||||
|
||||
|
@ -217,15 +217,15 @@ def test__training_step__step_end__epoch_end__flow_scalar(tmpdir):
|
|||
trainer.state.stage = RunningStage.TRAINING
|
||||
# make sure training outputs what is expected
|
||||
kwargs = {"batch": next(iter(model.train_dataloader())), "batch_idx": 0}
|
||||
train_step_out = trainer.fit_loop.epoch_loop.batch_loop.run(kwargs)
|
||||
train_step_out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs)
|
||||
|
||||
assert len(train_step_out) == 1
|
||||
train_step_out = train_step_out[0][0]
|
||||
train_step_out = train_step_out[0]
|
||||
assert isinstance(train_step_out["loss"], Tensor)
|
||||
assert train_step_out["loss"].item() == 171
|
||||
|
||||
# make sure the optimizer closure returns the correct things
|
||||
opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
|
||||
opt_closure = trainer.fit_loop.epoch_loop.optimizer_loop._make_closure(kwargs, trainer.optimizers[0])
|
||||
opt_closure_result = opt_closure()
|
||||
assert opt_closure_result.item() == 171
|
||||
|
||||
|
@ -301,9 +301,10 @@ def test_training_step_no_return_when_even(tmpdir):
|
|||
|
||||
# manually check a few batches
|
||||
for batch_idx, batch in enumerate(model.train_dataloader()):
|
||||
out = trainer.fit_loop.epoch_loop.batch_loop.run({"batch": batch, "batch_idx": batch_idx})
|
||||
kwargs = {"batch": batch, "batch_idx": batch_idx}
|
||||
out = trainer.fit_loop.epoch_loop.optimizer_loop.run([(0, trainer.optimizers[0])], kwargs)
|
||||
if not batch_idx % 2:
|
||||
assert out == []
|
||||
assert out == {}
|
||||
|
||||
|
||||
def test_training_step_none_batches(tmpdir):
|
||||
|
|
|
@ -13,32 +13,7 @@
|
|||
# limitations under the License.
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.loops.utilities import _extract_hiddens, _set_sampler_epoch
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
def test_extract_hiddens():
|
||||
# tbptt not enabled, no hiddens return
|
||||
training_step_output = 1 # anything
|
||||
hiddens = _extract_hiddens(training_step_output, 0)
|
||||
assert hiddens is None
|
||||
|
||||
# tbptt enabled, hiddens return
|
||||
hiddens = torch.tensor(321.12, requires_grad=True)
|
||||
training_step_output = {"hiddens": hiddens}
|
||||
hiddens = _extract_hiddens(training_step_output, 2)
|
||||
assert "hiddens" in training_step_output
|
||||
assert not hiddens.requires_grad
|
||||
|
||||
# tbptt not enabled, hiddens return
|
||||
with pytest.raises(MisconfigurationException, match='returned "hiddens" .* but `truncated_bptt_steps` is disabled'):
|
||||
_extract_hiddens(training_step_output, 0)
|
||||
# tbptt enabled, no hiddens return
|
||||
with pytest.raises(MisconfigurationException, match="enabled `truncated_bptt_steps` but did not `return"):
|
||||
_extract_hiddens(None, 1)
|
||||
from pytorch_lightning.loops.utilities import _set_sampler_epoch
|
||||
|
||||
|
||||
def test_set_sampler_epoch():
|
||||
|
|
|
@ -353,7 +353,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files)
|
|||
# emulate callback's calls during the training
|
||||
for i, loss in enumerate(losses, 1):
|
||||
# sets `trainer.global_step`
|
||||
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i
|
||||
trainer.fit_loop.epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = i
|
||||
trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
|
||||
checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
|
||||
trainer.fit_loop.epoch_progress.current.completed = i # sets `trainer.current_epoch`
|
||||
|
|
|
@ -11,6 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
@ -156,3 +158,37 @@ def test_migrate_dropped_apex_amp_state(monkeypatch):
|
|||
with pytest.warns(UserWarning, match="checkpoint contains apex AMP data"):
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy())
|
||||
assert "amp_scaling_state" not in updated_checkpoint
|
||||
|
||||
|
||||
def test_migrate_loop_structure_after_tbptt_removal():
|
||||
"""Test the loop state migration after truncated backpropagation support was removed in 2.0.0, and with it the
|
||||
training batch loop."""
|
||||
# the automatic and manual optimization states are combined into a single checkpoint to simplify testing
|
||||
state_automatic = MagicMock()
|
||||
state_manual = MagicMock()
|
||||
optim_progress_automatic = MagicMock()
|
||||
optim_progress_manual = MagicMock()
|
||||
old_batch_loop_state = MagicMock()
|
||||
old_checkpoint = {
|
||||
"loops": {
|
||||
"fit_loop": {
|
||||
"epoch_loop.state_dict": {"any": "state"},
|
||||
"epoch_loop.batch_loop.state_dict": old_batch_loop_state,
|
||||
"epoch_loop.batch_loop.optimizer_loop.state_dict": state_automatic,
|
||||
"epoch_loop.batch_loop.optimizer_loop.optim_progress": optim_progress_automatic,
|
||||
"epoch_loop.batch_loop.manual_loop.state_dict": state_manual,
|
||||
"epoch_loop.batch_loop.manual_loop.optim_step_progress": optim_progress_manual,
|
||||
}
|
||||
}
|
||||
}
|
||||
_set_version(old_checkpoint, "1.8.0")  # pretend this checkpoint was written before 2.0.0
|
||||
updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0")
|
||||
assert updated_checkpoint["loops"] == {
|
||||
"fit_loop": {
|
||||
"epoch_loop.state_dict": {"any": "state", "old_batch_loop_state_dict": old_batch_loop_state},
|
||||
"epoch_loop.optimizer_loop.state_dict": state_automatic,
|
||||
"epoch_loop.optimizer_loop.optim_progress": optim_progress_automatic,
|
||||
"epoch_loop.manual_loop.state_dict": state_manual,
|
||||
"epoch_loop.manual_loop.optim_step_progress": optim_progress_manual,
|
||||
}
|
||||
}
|
||||
|
|
|
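
To apply the same migration to a checkpoint saved on disk, the helper exercised in the test above can be called directly. A sketch that assumes ``migrate_checkpoint`` is importable from ``pytorch_lightning.utilities.migration`` (the import is outside the visible diff) and uses hypothetical file names:

    import torch

    from pytorch_lightning.utilities.migration import migrate_checkpoint  # assumed import path

    old = torch.load("old-1.8-checkpoint.ckpt", map_location="cpu")  # hypothetical path
    new, _ = migrate_checkpoint(old, target_version="2.0.0")
    assert not any(".batch_loop." in key for key in new["loops"]["fit_loop"])
    torch.save(new, "upgraded-2.0-checkpoint.ckpt")
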
@ -447,21 +447,6 @@ def test_on_train_batch_end_overridden(tmpdir) -> None:
|
|||
trainer.fit(m)
|
||||
|
||||
|
||||
def test_tbptt_split_batch_overridden(tmpdir) -> None:
|
||||
"""Verify that a `MisconfigurationException` is raised when `tbptt_split_batch` is overridden on the
|
||||
`LightningModule`."""
|
||||
|
||||
class InvalidModel(AsyncBoringModel):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.truncated_bptt_steps = 2
|
||||
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
m = InvalidModel()
|
||||
with pytest.raises(MisconfigurationException, match="is incompatible with `truncated_bptt_steps > 0`."):
|
||||
trainer.fit(m)
|
||||
|
||||
|
||||
def test_transfer_hooks_with_unpacking(tmpdir):
|
||||
|
||||
"""This test asserts the `transfer_batch` hooks are called only once per batch."""
|
||||
|
|