# lightning/pytorch_lightning/loops/utilities.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence
import torch
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT
def check_finite_loss(loss: Optional[torch.Tensor]) -> None:
    """Raise if the given loss contains a non-finite value.

    Args:
        loss: the loss value to validate; ``None`` is accepted and not checked

    Raises:
        ValueError: If any element of ``loss`` is NaN or infinite.
    """
    if loss is None:
        return
    all_finite = torch.isfinite(loss).all()
    if not all_finite:
        raise ValueError(f"The loss returned in `training_step` is {loss}.")
def _check_training_step_output(model: "pl.LightningModule", training_step_output: STEP_OUTPUT) -> None:
    """Sanity checks that ``training_step`` produced an output that is valid for the current optimization mode.

    Args:
        model: the LightningModule whose ``training_step`` produced the output
        training_step_output: the output of the training step (before wrapping in an AttributeDict)

    Raises:
        MisconfigurationException:
            If manual optimization is used and ``training_step`` returned a Tensor without a ``grad_fn``,
            or if automatic optimization is used and the output is not a Tensor, a mapping containing the
            key ``"loss"``, or ``None``.
    """
    if model.automatic_optimization:
        is_valid_output = (
            training_step_output is None
            or isinstance(training_step_output, torch.Tensor)
            or (isinstance(training_step_output, Mapping) and "loss" in training_step_output)
        )
        if not is_valid_output:
            raise MisconfigurationException(
                "In automatic optimization, `training_step` must either return a Tensor, "
                "a dict with key 'loss' or None (where the step will be skipped)."
            )
        return

    # manual optimization: only a Tensor that is not attached to an autograd graph is rejected
    if isinstance(training_step_output, torch.Tensor) and training_step_output.grad_fn is None:
        # TODO: in manual optimization, anything returned should be considered an `extra`
        raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
def _extract_hiddens(training_step_output: STEP_OUTPUT, truncated_bptt_steps: int) -> Optional[Any]:
    """Get the hidden state if present from the training step output.

    Args:
        training_step_output: the output returned by ``training_step``
        truncated_bptt_steps: the truncated back-propagation-through-time setting; a falsy value (e.g. ``0``)
            means truncated BPTT is disabled

    Returns:
        The detached ``"hiddens"`` entry of the output when truncated BPTT is enabled, otherwise ``None``.

    Raises:
        MisconfigurationException: If :attr:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps`
            is not enabled and hiddens are returned or vice versa.
    """
    # Accept any Mapping, consistent with `_check_training_step_output`, which accepts a Mapping with "loss".
    has_hiddens = isinstance(training_step_output, Mapping) and "hiddens" in training_step_output

    if not truncated_bptt_steps:
        if has_hiddens:
            raise MisconfigurationException(
                'You returned "hiddens" in your `training_step` but `truncated_bptt_steps` is disabled'
            )
        return None

    if not has_hiddens:
        raise MisconfigurationException(
            'You enabled `truncated_bptt_steps` but did not return "hiddens" in your `training_step`'
        )
    # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
    return recursive_detach(training_step_output["hiddens"])
def _build_training_step_kwargs(
    lightning_module: "pl.LightningModule",
    optimizers: Sequence[Optimizer],
    batch: Any,
    batch_idx: int,
    opt_idx: Optional[int],
    hiddens: Optional[Any],
) -> Dict[str, Any]:
    """Builds the keyword arguments for training_step.

    ``batch`` is always passed. ``batch_idx`` and ``optimizer_idx`` are only passed when the hook's signature
    accepts them (as determined by ``is_param_in_hook_signature``). ``hiddens`` is passed whenever truncated
    BPTT is enabled.

    Args:
        lightning_module: the LightningModule with a `training_step` hook implementation
        optimizers: the list of optimizers from the Trainer
        batch: the batch to train on
        batch_idx: the index of the current batch
        opt_idx: the index of the current optimizer
        hiddens: the hidden state of the previous RNN iteration

    Returns:
        the keyword arguments for the training step

    Raises:
        ValueError:
            If more than one optimizer is used and ``training_step`` declares ``optimizer_idx`` under manual
            optimization, or does not declare it under automatic optimization.
    """
    step_kwargs = OrderedDict([("batch", batch)])
    training_step_fx = lightning_module.training_step

    if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
        step_kwargs["batch_idx"] = batch_idx

    # `optimizer_idx` is only considered with multiple optimizers, so single-optimizer
    # `training_step` implementations do not need to declare it
    if len(optimizers) > 1:
        has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
        if has_opt_idx_in_train_step:
            if not lightning_module.automatic_optimization:
                raise ValueError(
                    "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
                    " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                    " argument or set `self.automatic_optimization = True`."
                )
            step_kwargs["optimizer_idx"] = opt_idx
        elif lightning_module.automatic_optimization:
            raise ValueError(
                f"Your LightningModule defines {len(optimizers)} optimizers but"
                " `training_step` is missing the `optimizer_idx` argument."
            )

    # pass hiddens if using tbptt
    if lightning_module.truncated_bptt_steps > 0:
        step_kwargs["hiddens"] = hiddens
    return step_kwargs
def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Create the iterator the loop consumes from ``data_fetcher``.

    A :class:`DataLoaderIterDataFetcher` is iterated directly. Any other fetcher is wrapped in ``enumerate``
    starting at ``batch_idx``, so the yielded indices continue from the given batch index.

    Args:
        data_fetcher: the fetcher wrapping the dataloader
        batch_idx: the index to start numbering batches from (for non-``DataLoaderIterDataFetcher`` fetchers)

    Returns:
        an iterator over the fetcher (``(index, item)`` pairs for non-``DataLoaderIterDataFetcher`` fetchers)
    """
    if isinstance(data_fetcher, DataLoaderIterDataFetcher):
        return iter(data_fetcher)
    # restore iteration: continue batch numbering from `batch_idx`
    return enumerate(data_fetcher, batch_idx)
@contextmanager
def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]:
"""Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`. This is
useful for example when when accumulating gradients to reduce communication when it is not needed.
Args:
trainer: the trainer instance with a reference to a training type plugin
block: whether the context manager is enabled or not
Returns:
context manager with sync behaviour off
"""
if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
with trainer.training_type_plugin.block_backward_sync():
yield None
else:
yield None