# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT


def check_finite_loss(model: "pl.LightningModule", loss: torch.Tensor) -> None:
    """Checks that the loss is finite, then checks the model's parameters.

    Args:
        model: a reference to the ``LightningModule``
        loss: the loss value to check to be finite

    Raises:
        ValueError: if any element of ``loss`` is not finite (NaN or +/-inf).
    """
    if not torch.isfinite(loss).all():
        raise ValueError(f"The loss returned in `training_step` is {loss}.")
    # Parameter validation is delegated to `detect_nan_parameters`; it only runs once the loss is known to be finite.
    detect_nan_parameters(model)


def _check_training_step_output(model: "pl.LightningModule", training_step_output: STEP_OUTPUT) -> None:
    """Sanity checks that training produced a valid output and optimizer step has already been called in manual
    optimization.

    Args:
        model: a reference to the ``LightningModule``
        training_step_output: the output of the training step (before wrapping in an AttributeDict)

    Raises:
        MisconfigurationException: in manual optimization, if a Tensor without a ``grad_fn`` is returned; in
            automatic optimization, if the output is not a Tensor, a Mapping with a ``"loss"`` key, or ``None``.
    """
    if isinstance(training_step_output, torch.Tensor) and not model.automatic_optimization:
        if training_step_output.grad_fn is None:
            # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
            raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
    elif model.automatic_optimization:
        if not any(
            (
                isinstance(training_step_output, torch.Tensor),
                (isinstance(training_step_output, Mapping) and "loss" in training_step_output),
                training_step_output is None,
            )
        ):
            raise MisconfigurationException(
                "In automatic optimization, `training_step` must either return a Tensor, "
                "a dict with key 'loss' or None (where the step will be skipped)."
            )


def _process_training_step_output(
    trainer: "pl.Trainer", training_step_output: STEP_OUTPUT
) -> Tuple[Optional[ResultCollection], Optional[Any]]:
    """Adds the ``training_step_output`` to the trainer's results.

    Args:
        trainer: a reference to the trainer
        training_step_output: the output of the training step (before wrapping into an AttributeDict)

    Returns:
        A tuple ``(results, hiddens)``:
        ``results`` is the trainer's updated results (``None`` if the training_step's output was ``None``);
        ``hiddens`` is the detached ``"hiddens"`` entry of a Mapping output (``None`` if absent or if the output
        was not a Mapping).
    """
    if training_step_output is None:
        return None, None

    results = trainer._results

    loss = None
    hiddens = None

    # handle mapping return. `Mapping` (not just `dict`) so this agrees with `_check_training_step_output`, which
    # accepts any Mapping containing "loss"; otherwise such outputs would be validated but then silently ignored.
    if isinstance(training_step_output, Mapping):
        # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
        loss = training_step_output.get("loss")
        hiddens = training_step_output.get("hiddens")
        # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
        hiddens = apply_to_collection(hiddens, torch.Tensor, lambda t: t.detach())
        # use the setter instead of `dict.update` because it calls `detach` on the tensor items
        results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}

    # handle scalar return
    elif isinstance(training_step_output, torch.Tensor):
        loss = training_step_output

    # map to results under the hood
    results.minimize = loss

    if trainer.move_metrics_to_cpu:
        results.cpu()
    return results, hiddens


def _build_training_step_kwargs(
    lightning_module: "pl.LightningModule",
    optimizers: Sequence[Optimizer],
    batch: Any,
    batch_idx: int,
    opt_idx: Optional[int],
    hiddens: Optional[Tensor],
) -> Dict[str, Any]:
    """Builds the keyword arguments for training_step.

    Args:
        lightning_module: the LightningModule with a `training_step` hook implementation
        optimizers: the list of optimizers from the Trainer
        batch: the batch to train on
        batch_idx: the index of the current batch
        opt_idx: the index of the current optimizer
        hiddens: the hidden state of the previous RNN iteration

    Returns:
        the keyword arguments for the training step

    Raises:
        ValueError: with multiple optimizers, if ``training_step`` declares ``optimizer_idx`` under manual
            optimization, or omits it under automatic optimization.
    """
    # enable not needing to add opt_idx to training_step
    step_kwargs = OrderedDict([("batch", batch)])

    training_step_fx = getattr(lightning_module, "training_step")

    # only pass `batch_idx` when the user's `training_step` signature accepts it
    if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
        step_kwargs["batch_idx"] = batch_idx

    if len(optimizers) > 1:
        has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
        if has_opt_idx_in_train_step:
            if not lightning_module.automatic_optimization:
                raise ValueError(
                    "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
                    " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                    " argument or set `self.automatic_optimization = True`."
                )
            step_kwargs["optimizer_idx"] = opt_idx
        elif lightning_module.automatic_optimization:
            raise ValueError(
                f"Your LightningModule defines {len(optimizers)} optimizers but"
                " `training_step` is missing the `optimizer_idx` argument."
            )

    # pass hiddens if using tbptt
    if lightning_module.truncated_bptt_steps > 0:
        step_kwargs["hiddens"] = hiddens

    return step_kwargs


def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Builds the iterator the loop consumes from ``data_fetcher``.

    Args:
        data_fetcher: the fetcher wrapping the dataloader
        batch_idx: the batch index to start counting from (used to restore iteration)

    Returns:
        For a ``DataLoaderIterDataFetcher``, a plain iterator over the fetcher; for any other fetcher, an
        ``enumerate`` over it whose indices start at ``batch_idx``.
    """
    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        # restore iteration: pair each fetched item with a batch index starting at `batch_idx`
        dataloader_iter = enumerate(data_fetcher, batch_idx)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter


@contextmanager
def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]:
    """Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`.

    This is useful, for example, when accumulating gradients, to reduce communication when it is not needed.

    Args:
        trainer: the trainer instance with a reference to a training type plugin
        block: whether the context manager is enabled or not

    Returns:
        context manager with sync behaviour off
    """
    # only parallel plugins expose `block_backward_sync`; otherwise this is a no-op context
    if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
        with trainer.training_type_plugin.block_backward_sync():
            yield None
    else:
        yield None