# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, Mapping, Optional, Sequence, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT


def check_finite_loss(model: "pl.LightningModule", loss: torch.Tensor) -> None:
    """Checks for finite parameters and loss values.

    Args:
        model: a reference to the ``LightningModule``
        loss: the loss value to check for finiteness
    """
    if not torch.isfinite(loss).all():
        raise ValueError(f"The loss returned in `training_step` is {loss}.")
    detect_nan_parameters(model)
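
# A minimal usage sketch for ``check_finite_loss`` (the ``model``, ``batch`` and ``loss`` names
# below are hypothetical, not defined in this module). A non-finite loss raises ``ValueError``,
# and the module parameters are additionally scanned for NaN/inf values:
#
#     loss = model.training_step(batch, batch_idx=0)["loss"]
#     check_finite_loss(model, loss)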


def _check_training_step_output(model: "pl.LightningModule", training_step_output: STEP_OUTPUT) -> None:
    """Sanity checks that training produced a valid output and optimizer step has already been called in manual
    optimization.

    Args:
        model: a reference to the ``LightningModule``
        training_step_output: the output of the training step (before wrapping in an AttributeDict)
    """
    if isinstance(training_step_output, torch.Tensor) and not model.automatic_optimization:
        if training_step_output.grad_fn is None:
            # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
            raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
    elif model.automatic_optimization:
        if not any(
            (
                isinstance(training_step_output, torch.Tensor),
                (isinstance(training_step_output, Mapping) and "loss" in training_step_output),
                training_step_output is None,
            )
        ):
            raise MisconfigurationException(
                "In automatic optimization, `training_step` must either return a Tensor, "
                "a dict with key 'loss' or None (where the step will be skipped)."
            )
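
# A rough sketch of the returns this check allows (assumed user code, not part of this module).
# With ``automatic_optimization`` enabled, ``training_step`` may return a Tensor, a dict with a
# "loss" key, or None (the step is skipped):
#
#     return loss
#     return {"loss": loss, "something_else": value}
#     return None
#
# In manual optimization, returning a Tensor with no ``grad_fn`` raises a
# ``MisconfigurationException``.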


def _process_training_step_output(
    trainer: "pl.Trainer", training_step_output: STEP_OUTPUT
) -> Tuple[Optional[ResultCollection], Optional[Any]]:
    """Adds the :param:`training_step_output` to the trainer's results.

    Args:
        trainer: a reference to the trainer
        training_step_output: the output of the training step (before wrapping into an AttributeDict)

    Returns:
        the updated results (None if the training_step's output was None) and the hiddens extracted from the results
    """
    if training_step_output is None:
        return None, None

    results = trainer._results

    loss = None
    hiddens = None

    # handle dict return
    if isinstance(training_step_output, dict):
        # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
        loss = training_step_output.get("loss")
        hiddens = training_step_output.get("hiddens")
        # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
        hiddens = apply_to_collection(hiddens, torch.Tensor, lambda t: t.detach())
        # use the setter instead of `dict.update` because it calls `detach` on the tensor items
        results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}

    # handle scalar return
    elif isinstance(training_step_output, torch.Tensor):
        loss = training_step_output

    # map to results under the hood
    results.minimize = loss

    if trainer.move_metrics_to_cpu:
        results.cpu()
    return results, hiddens
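
# An illustrative mapping of a dict return (assumed values, for illustration only):
# a ``training_step`` returning ``{"loss": loss, "hiddens": h, "foo": foo}`` ends up as
# ``results.minimize == loss``, ``results.extra == {"foo": foo}``, and the returned hiddens
# being ``h`` detached from the graph.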


def _build_training_step_kwargs(
    lightning_module: "pl.LightningModule",
    optimizers: Sequence[Optimizer],
    batch: Any,
    batch_idx: int,
    opt_idx: Optional[int],
    hiddens: Optional[Tensor],
) -> Dict[str, Any]:
    """Builds the keyword arguments for `training_step`.

    Args:
        lightning_module: the LightningModule with a `training_step` hook implementation
        optimizers: the list of optimizers from the Trainer
        batch: the batch to train on
        batch_idx: the index of the current batch
        opt_idx: the index of the current optimizer
        hiddens: the hidden state of the previous RNN iteration

    Returns:
        the keyword arguments for the training step
    """
    # enable not needing to add opt_idx to training_step
    step_kwargs = OrderedDict([("batch", batch)])

    training_step_fx = getattr(lightning_module, "training_step")

    if is_param_in_hook_signature(training_step_fx, "batch_idx", min_args=2):
        step_kwargs["batch_idx"] = batch_idx

    if len(optimizers) > 1:
        has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx")
        if has_opt_idx_in_train_step:
            if not lightning_module.automatic_optimization:
                raise ValueError(
                    "Your `LightningModule.training_step` signature contains an `optimizer_idx` argument but"
                    " in manual optimization optimizers must be handled by the user. Remove the optimizer_idx"
                    " argument or set `self.automatic_optimization = True`."
                )
            step_kwargs["optimizer_idx"] = opt_idx
        elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization:
            raise ValueError(
                f"Your LightningModule defines {len(optimizers)} optimizers but"
                " `training_step` is missing the `optimizer_idx` argument."
            )

    # pass hiddens if using tbptt
    if lightning_module.truncated_bptt_steps > 0:
        step_kwargs["hiddens"] = hiddens

    return step_kwargs
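
# A minimal usage sketch (hypothetical ``model``, ``trainer`` and ``batch`` objects, not part of
# this module); the built kwargs mirror the parameters declared in the user's ``training_step``:
#
#     step_kwargs = _build_training_step_kwargs(
#         lightning_module=model,
#         optimizers=trainer.optimizers,
#         batch=batch,
#         batch_idx=batch_idx,
#         opt_idx=0,
#         hiddens=None,
#     )
#     output = model.training_step(**step_kwargs)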


def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Attaches the dataloader iterator, resuming enumeration at ``batch_idx``."""
    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        # restore iteration
        dataloader_iter = enumerate(data_fetcher, batch_idx)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter
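
# Note: for a regular fetcher the iterator above is ``enumerate(data_fetcher, batch_idx)``, so
# the yielded indices resume at ``batch_idx``; a ``DataLoaderIterDataFetcher`` is returned as a
# plain iterator because the user's ``training_step`` consumes ``dataloader_iter`` itself.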


@contextmanager
def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]:
    """Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`.

    This is useful, for example, when accumulating gradients to reduce communication when it is not needed.

    Args:
        trainer: the trainer instance with a reference to a training type plugin
        block: whether the context manager is enabled or not

    Returns:
        context manager with sync behaviour off
    """
    if isinstance(trainer.training_type_plugin, ParallelPlugin) and block:
        with trainer.training_type_plugin.block_backward_sync():
            yield None
    else:
        yield None
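
# A minimal usage sketch (hypothetical ``trainer`` and closure, not part of this module). While
# accumulating gradients, the backward pass can run with inter-process gradient sync disabled:
#
#     with _block_parallel_sync_behavior(trainer, block=True):
#         training_step_and_backward_closure()  # no all-reduce on this accumulation step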