2020-08-20 02:03:22 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2021-07-05 11:43:27 +00:00
|
|
|
"""The LightningModule - an nn.Module with many additional features."""
|
2020-11-13 15:05:54 +00:00
|
|
|
|
2019-10-31 10:45:28 +00:00
|
|
|
import collections
|
2020-02-25 15:36:44 +00:00
|
|
|
import inspect
|
2021-03-02 09:47:55 +00:00
|
|
|
import logging
|
2021-05-24 12:13:55 +00:00
|
|
|
import numbers
|
2020-12-01 00:09:46 +00:00
|
|
|
import os
|
|
|
|
import tempfile
|
2021-07-20 18:31:49 +00:00
|
|
|
from contextlib import contextmanager
|
2021-01-27 10:02:16 +00:00
|
|
|
from pathlib import Path
|
2021-10-18 15:29:51 +00:00
|
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union
|
2019-10-22 08:32:40 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
import torch
|
2020-12-01 00:09:46 +00:00
|
|
|
from torch import ScriptModule, Tensor
|
|
|
|
from torch.nn import Module
|
|
|
|
from torch.optim.optimizer import Optimizer
|
2021-06-08 13:04:16 +00:00
|
|
|
from torchmetrics import Metric
|
2021-10-18 15:29:51 +00:00
|
|
|
from typing_extensions import Literal
|
2020-12-01 00:09:46 +00:00
|
|
|
|
2021-09-10 20:58:02 +00:00
|
|
|
import pytorch_lightning as pl
|
2021-09-09 20:53:47 +00:00
|
|
|
from pytorch_lightning.callbacks.progress import base as progress_base
|
2020-09-29 17:51:44 +00:00
|
|
|
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
|
2021-07-19 08:15:59 +00:00
|
|
|
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
|
2020-12-07 12:55:49 +00:00
|
|
|
from pytorch_lightning.core.optimizer import LightningOptimizer
|
2021-07-09 15:10:00 +00:00
|
|
|
from pytorch_lightning.core.saving import ModelIO
|
2021-09-03 13:41:05 +00:00
|
|
|
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
|
2021-10-13 14:45:13 +00:00
|
|
|
from pytorch_lightning.utilities import (
|
2021-10-21 21:01:56 +00:00
|
|
|
_IS_WINDOWS,
|
2021-10-20 16:21:37 +00:00
|
|
|
_TORCH_GREATER_EQUAL_DEV_1_10,
|
2021-10-13 14:45:13 +00:00
|
|
|
GradClipAlgorithmType,
|
|
|
|
rank_zero_deprecation,
|
|
|
|
rank_zero_warn,
|
|
|
|
)
|
2021-01-09 12:37:44 +00:00
|
|
|
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
|
2021-05-21 11:23:15 +00:00
|
|
|
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
2021-06-25 19:16:11 +00:00
|
|
|
from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
|
2020-09-30 02:12:56 +00:00
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
2021-07-30 13:53:40 +00:00
|
|
|
from pytorch_lightning.utilities.memory import get_model_size_mb
|
2021-08-03 22:08:51 +00:00
|
|
|
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
|
2021-07-09 15:10:00 +00:00
|
|
|
from pytorch_lightning.utilities.parsing import collect_init_args
|
2021-05-13 17:33:55 +00:00
|
|
|
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
2021-06-01 11:51:50 +00:00
|
|
|
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, STEP_OUTPUT
|
2021-05-13 17:33:55 +00:00
|
|
|
from pytorch_lightning.utilities.warnings import WarningCache
|
2020-09-21 02:59:21 +00:00
|
|
|
|
2021-05-13 17:33:55 +00:00
|
|
|
# Shared warning helper used by the deprecation paths in this module
# (e.g. ``warning_cache.deprecation(...)``). NOTE(review): presumably it suppresses
# repeated identical messages — confirm against ``WarningCache``.
warning_cache = WarningCache()
# Logger named after this module.
log = logging.getLogger(__name__)
|
2021-02-22 11:01:54 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2020-09-29 17:51:44 +00:00
|
|
|
class LightningModule(
|
|
|
|
DeviceDtypeModuleMixin,
|
2021-07-09 15:10:00 +00:00
|
|
|
HyperparametersMixin,
|
2020-09-29 17:51:44 +00:00
|
|
|
ModelIO,
|
|
|
|
ModelHooks,
|
|
|
|
DataHooks,
|
|
|
|
CheckpointHooks,
|
|
|
|
Module,
|
|
|
|
):
|
2020-09-25 14:20:15 +00:00
|
|
|
# Below is for property support of JIT in PyTorch 1.7
# since none of these are important when using JIT, we are going to ignore them.
# The names below are this module's own properties; the mixins' lists are appended with ``+``
# so that every property TorchScript should skip is collected in one list.
__jit_unused_properties__ = (
    [
        "example_input_array",
        "on_gpu",
        "current_epoch",
        "global_step",
        "global_rank",
        "local_rank",
        "logger",
        "model_size",
        "automatic_optimization",
        "truncated_bptt_steps",
        "loaded_optimizer_states_dict",
    ]
    + DeviceDtypeModuleMixin.__jit_unused_properties__
    + HyperparametersMixin.__jit_unused_properties__
)
|
2020-09-25 14:20:15 +00:00
|
|
|
|
2021-04-27 12:46:45 +00:00
|
|
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the parent classes, then the Lightning-specific bookkeeping attributes.

    Args:
        *args: Positional arguments forwarded unchanged to ``super().__init__``.
        **kwargs: Keyword arguments forwarded unchanged to ``super().__init__``.
    """
    # must run first: the attribute assignments below go through the parent classes' setup
    super().__init__(*args, **kwargs)

    # see (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/
    # torch/nn/modules/module.py#L227)
    # PyTorch API-usage logging, tagged with the concrete subclass name.
    torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}")

    # pointer to the trainer object; properties such as ``current_epoch`` fall back to defaults while None
    self.trainer = None

    # NOTE(review): not read in this section; presumably filled in by the Trainer — confirm
    self._distrib_type = None
    self._device_type = None

    # true if using amp
    self.use_amp: bool = False

    # the precision used
    self.precision: int = 32

    # optionally can be set by user (backing field of ``example_input_array``)
    self._example_input_array = None
    # name of the hook currently running; ``log`` raises MisconfigurationException while this is None
    self._current_fx_name: Optional[str] = None
    # passed by ``log`` as ``dataloader_idx`` when ``add_dataloader_idx`` is True
    self._current_dataloader_idx: Optional[int] = None
    # backing field of ``automatic_optimization``
    self._automatic_optimization: bool = True
    # backing field of ``truncated_bptt_steps``; only positive values enable TBPTT
    self._truncated_bptt_steps: int = 0
    self._param_requires_grad_state = {}
    # built lazily by ``log``: ``id(metric)`` -> attribute name of each ``Metric`` submodule
    self._metric_attributes: Optional[Dict[int, str]] = None
    self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False

    # NOTE(review): helper defined outside this section; presumably registers sharded-tensor
    # state-dict hooks when the installed torch supports them — confirm
    self._register_sharded_tensor_state_dict_hooks_if_available()

    # deprecated, will be removed in 1.6 (backing field of ``loaded_optimizer_states_dict``)
    self._loaded_optimizer_states_dict: dict = {}
|
|
|
|
|
2021-10-18 15:29:51 +00:00
|
|
|
# Typing-only signature: with the default ``use_pl_optimizer=True`` the LightningOptimizer wrapper(s) are returned.
@overload
def optimizers(self, use_pl_optimizer: Literal[True] = True) -> Union[LightningOptimizer, List[LightningOptimizer]]:
    ...

# Typing-only signature: ``use_pl_optimizer=False`` returns the raw torch optimizer(s).
@overload
def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]:
    ...
|
|
|
|
|
2021-07-07 16:57:45 +00:00
|
|
|
def optimizers(
|
2021-07-26 11:37:35 +00:00
|
|
|
self, use_pl_optimizer: bool = True
|
2021-07-07 16:57:45 +00:00
|
|
|
) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
|
2021-09-06 12:49:09 +00:00
|
|
|
"""Returns the optimizer(s) that are being used during training. Useful for manual optimization.
|
2021-07-05 11:43:27 +00:00
|
|
|
|
|
|
|
Args:
|
|
|
|
use_pl_optimizer: If ``True``, will wrap the optimizer(s) in a
|
|
|
|
:class:`~pytorch_lightning.core.optimizer.LightningOptimizer` for automatic handling of precision and
|
|
|
|
profiling.
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
A single optimizer, or a list of optimizers in case multiple ones are present.
|
|
|
|
"""
|
2021-01-08 21:13:12 +00:00
|
|
|
if use_pl_optimizer:
|
|
|
|
opts = list(self.trainer.lightning_optimizers.values())
|
|
|
|
else:
|
|
|
|
opts = self.trainer.optimizers
|
2020-10-11 13:35:51 +00:00
|
|
|
|
|
|
|
# single optimizer
|
2021-07-07 16:57:45 +00:00
|
|
|
if isinstance(opts, list) and len(opts) == 1 and isinstance(opts[0], (Optimizer, LightningOptimizer)):
|
2020-10-11 13:35:51 +00:00
|
|
|
return opts[0]
|
|
|
|
# multiple opts
|
2020-12-01 00:09:46 +00:00
|
|
|
return opts
|
2020-10-10 16:19:22 +00:00
|
|
|
|
2021-04-09 09:32:14 +00:00
|
|
|
def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]:
    """Return the learning-rate scheduler(s) the Trainer is using; handy for manual optimization.

    Returns:
        ``None`` when :meth:`configure_optimizers` returned no schedulers, the scheduler object itself
        when there is exactly one, and a list of scheduler objects otherwise.
    """
    configs = self.trainer.lr_schedulers
    if not configs:
        return None

    # each entry is a config dict; keys such as "interval" or "frequency" are not exposed here
    schedulers = [config["scheduler"] for config in configs]
    return schedulers[0] if len(schedulers) == 1 else schedulers
|
|
|
|
|
2020-06-15 21:05:58 +00:00
|
|
|
@property
def example_input_array(self) -> Any:
    """Sample input describing what :meth:`forward` can consume.

    How the stored value is meant to be fed to ``forward``:

    - Single tensor: passed as the only argument, i.e. ``model.forward(model.example_input_array)``
    - Tuple: unpacked as positional arguments, i.e. ``model.forward(*model.example_input_array)``
    - Dict: unpacked as keyword arguments, i.e. ``model.forward(**model.example_input_array)``
    """
    return self._example_input_array

@example_input_array.setter
def example_input_array(self, value: Any) -> None:
    # stored verbatim; no validation happens here
    self._example_input_array = value
|
|
|
|
|
2020-10-05 15:10:40 +00:00
|
|
|
@property
def current_epoch(self) -> int:
    """The Trainer's current epoch, or ``0`` when no Trainer is attached."""
    trainer = self.trainer
    if not trainer:
        return 0
    return trainer.current_epoch
|
|
|
|
|
|
|
|
@property
def global_step(self) -> int:
    """Number of training batches seen across all epochs, or ``0`` when no Trainer is attached."""
    trainer = self.trainer
    if not trainer:
        return 0
    return trainer.global_step
|
|
|
|
|
2021-02-01 14:28:17 +00:00
|
|
|
@property
def global_rank(self) -> int:
    """Index of this process among all processes on all nodes (``0`` without a Trainer)."""
    trainer = self.trainer
    if not trainer:
        return 0
    return trainer.global_rank
|
|
|
|
|
|
|
|
@property
def local_rank(self) -> int:
    """Index of this process within its own node (``0`` without a Trainer)."""
    trainer = self.trainer
    if not trainer:
        return 0
    return trainer.local_rank
|
|
|
|
|
2021-07-01 21:02:29 +00:00
|
|
|
@property
def loaded_optimizer_states_dict(self) -> dict:
    """Deprecated (since v1.4, removal planned for v1.6) accessor for ``self._loaded_optimizer_states_dict``.

    Reading it emits a deprecation warning through ``warning_cache``.
    """
    warning_cache.deprecation(
        "The `LightningModule.loaded_optimizer_states_dict` property is deprecated"
        " in v1.4 and will be removed in v1.6.",
        stacklevel=6,
    )
    return self._loaded_optimizer_states_dict

@loaded_optimizer_states_dict.setter
def loaded_optimizer_states_dict(self, val: dict) -> None:
    # writes are deprecated too and warn with the very same message
    warning_cache.deprecation(
        "The `LightningModule.loaded_optimizer_states_dict` property is deprecated"
        " in v1.4 and will be removed in v1.6.",
        stacklevel=6,
    )
    self._loaded_optimizer_states_dict = val
|
|
|
|
|
2020-05-17 12:20:51 +00:00
|
|
|
@property
def on_gpu(self):
    """``True`` when the module currently lives on a CUDA device.

    Handy for branching between CPU and GPU behaviour inside the LightningModule.
    """
    device_type = self.device.type
    return device_type == "cuda"
|
2020-05-17 12:20:51 +00:00
|
|
|
|
2020-11-14 04:43:42 +00:00
|
|
|
@property
def automatic_optimization(self) -> bool:
    """Flag for automatic optimization.

    When set to ``False``, user code is responsible for calling ``.backward()``, ``.step()`` and ``.zero_grad()``.
    """
    return self._automatic_optimization

@automatic_optimization.setter
def automatic_optimization(self, value: bool) -> None:
    self._automatic_optimization = value
|
|
|
|
|
2021-05-05 10:21:00 +00:00
|
|
|
@property
def truncated_bptt_steps(self) -> int:
    """Number of :meth:`training_step` calls between backpropagations for Truncated BPTT.

    A positive integer enables `Truncated Backpropagation Through Time` in the Trainer. In that case
    :meth:`training_step` receives an extra ``hiddens`` argument and is expected to return a hidden state.
    """
    return self._truncated_bptt_steps

@truncated_bptt_steps.setter
def truncated_bptt_steps(self, value: int) -> None:
    self._truncated_bptt_steps = value
|
|
|
|
|
2021-02-01 14:28:17 +00:00
|
|
|
@property
def logger(self):
    """The Trainer's logger object, or ``None`` when no Trainer is attached."""
    trainer = self.trainer
    if not trainer:
        return None
    return trainer.logger
|
|
|
|
|
2021-05-13 17:33:55 +00:00
|
|
|
def _apply_batch_transfer_handler(
    self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
) -> Any:
    """Move ``batch`` to ``device`` (``self.device`` when not given) through the transfer hooks.

    The pipeline is ``on_before_batch_transfer`` -> ``transfer_batch_to_device`` -> ``on_after_batch_transfer``.
    Overrides of ``transfer_batch_to_device`` that lack a ``dataloader_idx`` parameter are still called,
    with a deprecation warning.
    """
    target_device = device or self.device
    batch = self.on_before_batch_transfer(batch, dataloader_idx)

    if not is_param_in_hook_signature(self.transfer_batch_to_device, "dataloader_idx"):
        # legacy hook signature: (batch, device)
        warning_cache.deprecation(
            "`transfer_batch_to_device` hook signature has changed in v1.4."
            " `dataloader_idx` parameter has been added to it. Support for"
            " the old signature will be removed in v1.6"
        )
        batch = self.transfer_batch_to_device(batch, target_device)
    else:
        batch = self.transfer_batch_to_device(batch, target_device, dataloader_idx)

    return self.on_after_batch_transfer(batch, dataloader_idx)
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def print(self, *args, **kwargs) -> None:
    r"""
    Print from the global-zero process only, so distributed runs don't produce duplicated output.

    When the Trainer has an enabled progress bar callback, the output is routed through the progress bar;
    otherwise Python's built-in :func:`print` is used.

    Args:
        *args: What to print, exactly as for Python's built-in print function.
        **kwargs: Same keyword arguments as Python's built-in print function.

    Example::

        def forward(self, x):
            self.print(x, 'in forward')
    """
    if not self.trainer.is_global_zero:
        return

    progress_bar = self.trainer.progress_bar_callback
    if progress_bar is None or not progress_bar.is_enabled:
        # inside a method, the bare name ``print`` resolves to the builtin, not to this method
        print(*args, **kwargs)
    else:
        progress_bar.print(*args, **kwargs)
|
2020-02-25 03:30:53 +00:00
|
|
|
|
2020-09-28 00:26:16 +00:00
|
|
|
def log(
    self,
    name: str,
    value: _METRIC_COLLECTION,
    prog_bar: bool = False,
    logger: bool = True,
    on_step: Optional[bool] = None,
    on_epoch: Optional[bool] = None,
    reduce_fx: Union[str, Callable] = "default",  # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6
    tbptt_reduce_fx: Optional[Any] = None,  # todo: Remove in 1.6
    tbptt_pad_token: Optional[Any] = None,  # todo: Remove in 1.6
    enable_graph: bool = False,
    sync_dist: bool = False,
    sync_dist_op: Optional[Any] = None,  # todo: Remove in 1.6
    sync_dist_group: Optional[Any] = None,
    add_dataloader_idx: bool = True,
    batch_size: Optional[int] = None,
    metric_attribute: Optional[str] = None,
    rank_zero_only: Optional[bool] = None,
) -> None:
    """Log a key, value pair.

    Example::

        self.log('train_loss', loss)

    The default behavior per hook is as follows:

    .. csv-table:: ``*`` also applies to the test loop
       :header: "LightningModule Hook", "on_step", "on_epoch", "prog_bar", "logger"
       :widths: 20, 10, 10, 10, 10

       "training_step", "T", "F", "F", "T"
       "training_step_end", "T", "F", "F", "T"
       "training_epoch_end", "F", "T", "F", "T"
       "validation_step*", "F", "T", "F", "T"
       "validation_step_end*", "F", "T", "F", "T"
       "validation_epoch_end*", "F", "T", "F", "T"

    If no Trainer is attached, a warning is emitted and nothing is logged.

    Args:
        name: key to log
        value: value to log. Can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former.
        prog_bar: if True logs to the progress bar
        logger: if True logs to the logger
        on_step: if True logs at this step. None auto-logs at the training_step but not validation/test_step
        on_epoch: if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step
        reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
            The ``"default"`` placeholder resolves to ``"mean"``, or to ``sync_dist_op`` when that deprecated
            argument is given.
        tbptt_reduce_fx: Deprecated; only triggers a deprecation warning. Will be removed in v1.6.
        tbptt_pad_token: Deprecated; only triggers a deprecation warning. Will be removed in v1.6.
        enable_graph: if True, will not auto detach the graph
        sync_dist: if True, reduces the metric across GPUs/TPUs (only when a distributed run is available)
        sync_dist_op: Deprecated in favour of ``reduce_fx``. Will be removed in v1.6.
        sync_dist_group: the ddp group to sync across
        add_dataloader_idx: if True, appends the index of the current dataloader to
            the name (when using multiple). If False, user needs to give unique names for
            each dataloader to not mix values
        batch_size: Current batch_size. This will be directly inferred from the loaded batch,
            but some data structures might need to explicitly provide it.
        metric_attribute: To restore the metric state, Lightning requires the reference of the
            :class:`torchmetrics.Metric` in your model. This is found automatically if it is a model attribute.
        rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
            would produce a deadlock as not all processes would perform this log call.

    Raises:
        ValueError: If ``value`` contains a nested dictionary, or (presumably, via ``wrong_dtype``) a value
            that is not a number, ``Tensor``, ``Metric`` or ``dict``.
        MisconfigurationException: If the Trainer has no ``ResultCollection`` (e.g. in a ``predict`` hook),
            if called outside a Trainer-managed hook, if ``name`` contains ``"/dataloader_idx_"``, if a
            logged ``Metric`` cannot be matched to a module attribute, or if ``training_step`` takes
            ``dataloader_iter`` and no ``batch_size`` was given. ``_FxValidator.check_logging`` may raise
            for hook/flag combinations it rejects.
    """
    # deprecated flags: warn only, they are not forwarded anywhere
    if tbptt_reduce_fx is not None:
        rank_zero_deprecation(
            "`self.log(tbptt_reduce_fx=...)` is no longer supported. The flag will be removed in v1.6."
            " Please, open a discussion explaining your use-case in"
            " `https://github.com/PyTorchLightning/pytorch-lightning/discussions`"
        )
    if tbptt_pad_token is not None:
        rank_zero_deprecation(
            "`self.log(tbptt_pad_token=...)` is no longer supported. The flag will be removed in v1.6."
            " Please, open a discussion explaining your use-case in"
            " `https://github.com/PyTorchLightning/pytorch-lightning/discussions`"
        )
    # resolve the "default" placeholder: the deprecated sync_dist_op wins only if reduce_fx was not set explicitly
    if sync_dist_op is not None:
        rank_zero_deprecation(
            f"`self.log(sync_dist_op='{sync_dist_op}')` is deprecated and will be removed in v.1.6."
            f" Use `self.log(reduce_fx={sync_dist_op})` instead."
        )
        if reduce_fx == "default":
            reduce_fx = sync_dist_op
    elif reduce_fx == "default":
        reduce_fx = "mean"

    # check for invalid values
    apply_to_collection(value, dict, self.__check_not_nested, name)
    apply_to_collection(
        value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict)
    )

    # set the default depending on the fx_name
    on_step = self.__auto_choose_log_on_step(on_step)
    on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

    if self.trainer is None:
        # not an error to support testing the `*_step` methods without a `Trainer` reference
        rank_zero_warn(
            "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
            " This is most likely because the model hasn't been passed to the `Trainer`"
        )
        return
    results = self.trainer._results
    if results is None:
        raise MisconfigurationException(
            "You are trying to `self.log()` but the loop `ResultCollection` is not registered"
            " yet. This is most likely because you are trying to log in a `predict` hook,"
            " but it doesn't support logging"
        )
    if self._current_fx_name is None:
        raise MisconfigurationException(
            "You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
        )
    _FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

    # make sure user doesn't introduce logic for multi-dataloaders
    if "/dataloader_idx_" in name:
        raise MisconfigurationException(
            f"You called `self.log` with the key `{name}`"
            " but it should not contain information about `dataloader_idx`"
        )

    # plain Python numbers become tensors on this module's device
    value = apply_to_collection(value, numbers.Number, self.__to_tensor)

    if self.trainer.logger_connector.should_reset_tensors(self._current_fx_name):
        # if we started a new epoch (running its first batch) the hook name has changed
        # reset any tensors for the new hook name
        results.reset(metrics=False, fx=self._current_fx_name)

    if metric_attribute is None and isinstance(value, Metric):
        if self._metric_attributes is None:
            # compute once
            # (the comprehension's ``name`` is scoped to the comprehension and does not overwrite the argument)
            self._metric_attributes = {
                id(module): name for name, module in self.named_modules() if isinstance(module, Metric)
            }
            if not self._metric_attributes:
                raise MisconfigurationException(
                    "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
                    " You can fix this by setting an attribute for the metric in your `LightningModule`."
                )
        # try to find the passed metric in the LightningModule
        metric_attribute = self._metric_attributes.get(id(value), None)
        if metric_attribute is None:
            raise MisconfigurationException(
                "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
                f" You can fix this by calling `self.log({name}, ..., metric_attribute=name)` where `name` is one"
                f" of {list(self._metric_attributes.values())}"
            )

    # with ``dataloader_iter`` the batch size cannot be inferred from a loaded batch
    if (
        self.trainer.training
        and is_param_in_hook_signature(self.training_step, "dataloader_iter", explicit=True)
        and batch_size is None
    ):
        raise MisconfigurationException(
            "With `def training_step(self, dataloader_iter)`, `self.log(..., batch_size=...)` should be provided."
        )

    results.log(
        self._current_fx_name,
        name,
        value,
        prog_bar=prog_bar,
        logger=logger,
        on_step=on_step,
        on_epoch=on_epoch,
        reduce_fx=reduce_fx,
        enable_graph=enable_graph,
        dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
        batch_size=batch_size,
        # syncing is only requested when a distributed environment is actually available
        sync_dist=sync_dist and distributed_available(),
        # fall back to ``sync_ddp`` when the plugin's ``reduce`` is falsy
        sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
        sync_dist_group=sync_dist_group,
        metric_attribute=metric_attribute,
        rank_zero_only=rank_zero_only,
    )

    self.trainer.logger_connector._current_fx = self._current_fx_name
|
|
|
|
|
2020-09-30 02:12:56 +00:00
|
|
|
def log_dict(
    self,
    dictionary: Mapping[str, _METRIC_COLLECTION],
    prog_bar: bool = False,
    logger: bool = True,
    on_step: Optional[bool] = None,
    on_epoch: Optional[bool] = None,
    reduce_fx: Union[str, Callable] = "default",  # TODO: change to 'mean' when `sync_dist_op` is removed in 1.6
    tbptt_reduce_fx: Optional[Any] = None,  # todo: Remove in 1.6
    tbptt_pad_token: Optional[Any] = None,  # todo: Remove in 1.6
    enable_graph: bool = False,
    sync_dist: bool = False,
    sync_dist_op: Optional[Any] = None,  # todo: Remove in 1.6
    sync_dist_group: Optional[Any] = None,
    add_dataloader_idx: bool = True,
    batch_size: Optional[int] = None,
    rank_zero_only: Optional[bool] = None,
) -> None:
    """Log every entry of ``dictionary`` at once, with identical options.

    Each key/value pair is forwarded to :meth:`log` as ``name=key, value=value``; every other
    argument is shared by all of those calls.

    Example::

        values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
        self.log_dict(values)

    Args:
        dictionary: key value pairs.
            The values can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former.
        prog_bar: if True logs to the progress bar
        logger: if True logs to the logger
        on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step
        on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
        reduce_fx: reduction function over step values for end of epoch. :meth:`torch.mean` by default.
        tbptt_reduce_fx: Deprecated, forwarded to :meth:`log`. Will be removed in v1.6.
        tbptt_pad_token: Deprecated, forwarded to :meth:`log`. Will be removed in v1.6.
        enable_graph: if True, will not auto detach the graph
        sync_dist: if True, reduces the metric across GPUs/TPUs
        sync_dist_op: Deprecated, forwarded to :meth:`log`. Will be removed in v1.6.
        sync_dist_group: the ddp group to sync across
        add_dataloader_idx: if True, appends the index of the current dataloader to
            the name (when using multiple). If False, user needs to give unique names for
            each dataloader to not mix values
        batch_size: Current batch_size. This will be directly inferred from the loaded batch,
            but some data structures might need to explicitly provide it.
        rank_zero_only: Whether the value will be logged only on rank 0. This will prevent synchronization which
            would produce a deadlock as not all processes would perform this log call.
    """
    shared_kwargs = dict(
        prog_bar=prog_bar,
        logger=logger,
        on_step=on_step,
        on_epoch=on_epoch,
        reduce_fx=reduce_fx,
        enable_graph=enable_graph,
        sync_dist=sync_dist,
        sync_dist_group=sync_dist_group,
        sync_dist_op=sync_dist_op,
        tbptt_pad_token=tbptt_pad_token,
        tbptt_reduce_fx=tbptt_reduce_fx,
        add_dataloader_idx=add_dataloader_idx,
        batch_size=batch_size,
        rank_zero_only=rank_zero_only,
    )
    for key, val in dictionary.items():
        self.log(name=key, value=val, **shared_kwargs)
|
|
|
|
|
2021-05-24 12:13:55 +00:00
|
|
|
@staticmethod
def __check_not_nested(value: dict, name: str) -> dict:
    """Return ``value`` unchanged, refusing dictionaries that themselves contain dictionaries.

    Raises:
        ValueError: If any value of ``value`` is a ``dict`` (a self-imposed restriction, for simplicity).
    """
    for item in value.values():
        if isinstance(item, dict):
            raise ValueError(f"`self.log({name}, {value})` was called, but nested dictionaries cannot be logged")
    return value
|
|
|
|
|
|
|
|
@staticmethod
def __check_allowed(v: Any, name: str, value: Any) -> None:
    """Unconditionally reject ``v`` with a ``ValueError`` naming its type.

    ``log`` applies this through ``apply_to_collection(..., wrong_dtype=...)``, presumably so that only
    values outside the supported types reach it.
    """
    type_name = type(v).__name__
    raise ValueError(f"`self.log({name}, {value})` was called, but `{type_name}` values cannot be logged")
|
2021-06-01 11:51:50 +00:00
|
|
|
|
2021-06-09 14:24:45 +00:00
|
|
|
def __to_tensor(self, value: numbers.Number) -> torch.Tensor:
    """Wrap a plain Python number in a tensor placed on this module's device."""
    target = self.device
    return torch.tensor(value, device=target)
|
|
|
|
|
2021-09-06 11:54:07 +00:00
|
|
|
def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
    """Override this method to change the default behaviour of ``log_grad_norm``.

    The default implementation forwards ``grad_norm_dict`` to :meth:`log_dict`. Each entry is logged both
    per step and aggregated per epoch, and is sent to the progress bar as well as to the logger.

    Args:
        grad_norm_dict: Dictionary containing current grad norm metrics

    Example::

        # DEFAULT
        def log_grad_norm(self, grad_norm_dict):
            self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    """
    # keep in sync with the "# DEFAULT" example in the docstring above
    self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
|
|
|
|
2021-05-19 19:25:32 +00:00
|
|
|
def __auto_choose_log_on_step(self, on_step: Optional[bool]) -> bool:
    """Resolve the ``on_step`` flag of ``self.log``.

    An explicit value is returned untouched. When ``on_step`` is ``None``, step-level logging is
    enabled only while the current hook is ``training_step`` or ``training_step_end``.
    """
    if on_step is not None:
        return on_step
    return self._current_fx_name in ("training_step", "training_step_end")
|
|
|
|
|
2021-05-19 19:25:32 +00:00
|
|
|
def __auto_choose_log_on_epoch(self, on_epoch: Optional[bool]) -> bool:
    """Resolve the ``on_epoch`` flag of ``self.log``.

    An explicit value is returned untouched. When ``on_epoch`` is ``None``, epoch-level logging is
    enabled everywhere except inside ``training_step`` and ``training_step_end``.
    """
    if on_epoch is not None:
        return on_epoch
    return self._current_fx_name not in ("training_step", "training_step_end")
|
|
|
|
|
2021-01-09 12:37:44 +00:00
|
|
|
def all_gather(
    self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
):
    r"""
    Gather ``data`` from several distributed processes, accelerator-agnostically.

    Calling ``self.all_gather()`` from the LightningModule delegates to the ``all_gather`` function
    provided by the trainer's training type plugin, so user code does not depend on the accelerator.

    Args:
        data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for the all_gather operation

    Return:
        A tensor of shape (world_size, batch, ...), or if the input was a collection
        the output will also be a collection with tensors of this shape.
    """
    if group is None:
        group = torch.distributed.group.WORLD
    gather_fn = self.trainer.training_type_plugin.all_gather
    # Numbers are turned into tensors on this module's device first, so that every leaf can be gathered.
    tensors = convert_to_tensors(data, device=self.device)
    return apply_to_collection(tensors, torch.Tensor, gather_fn, group=group, sync_grads=sync_grads)
|
2020-12-08 23:20:01 +00:00
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def forward(self, *args, **kwargs) -> Any:
    r"""
    Same as :meth:`torch.nn.Module.forward()`: all arguments are passed on to the parent class's ``forward``.

    Args:
        *args: Whatever you decide to pass into the forward method.
        **kwargs: Keyword arguments are also possible.

    Return:
        Your model's output
    """
    parent_forward = super().forward
    return parent_forward(*args, **kwargs)
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
    r"""
    Here you compute and return the training loss and some additional metrics for e.g.
    the progress bar or logger.

    The default implementation only emits a warning, so this method must be overridden.

    Args:
        batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
            The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
        batch_idx (``int``): Integer displaying index of this batch
        optimizer_idx (``int``): When using multiple optimizers, this argument will also be present.
        hiddens (``Any``): Passed in if
            :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.

    Return:
        Any of.

        - :class:`~torch.Tensor` - The loss tensor
        - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``
        - ``None`` - Training will skip to the next batch. This is only for automatic optimization.
            This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

    In this step you'd normally do the forward pass and calculate the loss for a batch.
    You can also do fancier things like multiple forward passes or something model specific.

    Example::

        def training_step(self, batch, batch_idx):
            x, y, z = batch
            out = self.encoder(x)
            loss = self.loss(out, x)
            return loss

    If you define multiple optimizers, this step will be called with an additional
    ``optimizer_idx`` parameter.

    .. code-block:: python

        # Multiple optimizers (e.g.: GANs)
        def training_step(self, batch, batch_idx, optimizer_idx):
            if optimizer_idx == 0:
                # do training_step with encoder
                ...
            if optimizer_idx == 1:
                # do training_step with decoder
                ...

    If you add truncated back propagation through time you will also get an additional
    argument with the hidden states of the previous step.

    .. code-block:: python

        # Truncated back-propagation through time
        def training_step(self, batch, batch_idx, hiddens):
            # hiddens are the hidden states from the previous truncated backprop step
            out, hiddens = self.lstm(batch, hiddens)
            loss = ...
            return {"loss": loss, "hiddens": hiddens}

    Note:
        The loss value shown in the progress bar is smoothed (averaged) over the last values,
        so it differs from the actual loss returned in train/validation step.
    """
    rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer")
|
2019-08-13 15:37:37 +00:00
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def training_step_end(self, *args, **kwargs) -> STEP_OUTPUT:
    """Combine the per-device outputs of :meth:`training_step` when training with dp or ddp2.

    With those backends :meth:`training_step` only sees a slice of the batch. Implementing this hook is
    still optional; it is only needed for computations that require the whole batch, such as a softmax
    or an NCE loss.

    Note:
        This hook keeps being called after you switch to ddp or another mode, so your code does not
        have to change.

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
        training_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in :meth:`training_step` for each batch part.

    Return:
        Anything

    With the dp/ddp2 distributed backends each ``training_step`` call receives only part of the batch:

    .. code-block:: python

        def training_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self(x)

            # softmax uses only a portion of the batch in the denominator
            loss = self.softmax(out)
            loss = nce_loss(loss)
            return loss

    To work with all parts of the batch at once, return the partial results and combine them here:

    .. code-block:: python

        def training_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.encoder(x)
            return {"pred": out}


        def training_step_end(self, training_step_outputs):
            gpu_0_pred = training_step_outputs[0]["pred"]
            gpu_1_pred = training_step_outputs[1]["pred"]
            gpu_n_pred = training_step_outputs[n]["pred"]

            # this softmax now uses the full batch
            loss = nce_loss([gpu_0_pred, gpu_1_pred, gpu_n_pred])
            return loss

    See Also:
        See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details.
    """
|
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
    """Run once the training epoch is over, receiving the outputs of every training step.

    Override this when you need to process everything returned by :meth:`training_step` together.

    .. code-block:: python

        # the pseudocode for these calls
        train_outs = []
        for train_batch in train_data:
            out = training_step(train_batch)
            train_outs.append(out)
        training_epoch_end(train_outs)

    Args:
        outputs: List of outputs you defined in :meth:`training_step`.
            With multiple optimizers, it is a list holding one list of outputs per optimizer.
            With ``truncated_bptt_steps > 1``, each element is itself a list with the outputs
            of every processed split batch.

    Return:
        None

    Note:
        If this method is not overridden, this won't be called.

    .. code-block:: python

        def training_epoch_end(self, training_step_outputs):
            # do something with all training_step outputs
            for out in training_step_outputs:
                ...
    """
|
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    r"""
    Process a single batch of data from the validation set.

    Use this step to generate examples or to compute anything of interest, such as accuracy.

    .. code-block:: python

        # the pseudocode for these calls
        val_outs = []
        for val_batch in val_data:
            out = validation_step(val_batch)
            val_outs.append(out)
        validation_epoch_end(val_outs)

    Args:
        batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
            The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
        batch_idx (int): The index of this batch
        dataloader_idx (int): The index of the dataloader that produced this batch
            (only if multiple val dataloaders used)

    Return:
        - Any object or value
        - ``None`` - Validation will skip to the next batch

    .. code-block:: python

        # pseudocode of order
        val_outs = []
        for val_batch in val_data:
            out = validation_step(val_batch)
            if defined("validation_step_end"):
                out = validation_step_end(out)
            val_outs.append(out)
        val_outs = validation_epoch_end(val_outs)

    .. code-block:: python

        # if you have one val dataloader:
        def validation_step(self, batch, batch_idx):
            ...


        # if you have multiple val dataloaders:
        def validation_step(self, batch, batch_idx, dataloader_idx):
            ...

    Examples::

        # CASE 1: A single validation dataset
        def validation_step(self, batch, batch_idx):
            x, y = batch

            # implement your own
            out = self(x)
            loss = self.loss(out, y)

            # log 6 example images
            # or generated text... or whatever
            sample_imgs = x[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('example_images', grid, 0)

            # calculate acc
            labels_hat = torch.argmax(out, dim=1)
            val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

            # log the outputs!
            self.log_dict({'val_loss': loss, 'val_acc': val_acc})

    With multiple val dataloaders, :meth:`validation_step` receives one extra argument.

    .. code-block:: python

        # CASE 2: multiple validation dataloaders
        def validation_step(self, batch, batch_idx, dataloader_idx):
            # dataloader_idx tells you which dataset this is.
            ...

    Note:
        If you don't need to validate you don't need to implement this method.

    Note:
        When :meth:`validation_step` is called, the model has been put in eval mode
        and PyTorch gradients have been disabled. At the end of validation,
        the model goes back to training mode and gradients are enabled.
    """
|
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def validation_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    """Combine the per-device outputs of :meth:`validation_step` when validating with dp or ddp2.

    With those backends :meth:`validation_step` only sees a slice of the batch. Implementing this hook is
    still optional; it is only needed for computations that require the whole batch, such as a softmax
    or an NCE loss.

    Note:
        This hook keeps being called after you switch to ddp or another mode, so your code does not
        have to change.

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
        validation_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in :meth:`validation_step`
            for each batch part.

    Return:
        None or anything

    .. code-block:: python

        # WITHOUT validation_step_end
        # if used in DP or DDP2, this batch is 1/num_gpus large
        def validation_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.encoder(x)
            loss = self.softmax(out)
            loss = nce_loss(loss)
            self.log("val_loss", loss)


        # --------------
        # with validation_step_end to do softmax over the full batch
        def validation_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self(x)
            return out


        def validation_step_end(self, val_step_outputs):
            for out in val_step_outputs:
                ...

    See Also:
        See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details.
    """
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
    """Called at the end of the validation epoch with the outputs of all validation steps.

    .. code-block:: python

        # the pseudocode for these calls
        val_outs = []
        for val_batch in val_data:
            out = validation_step(val_batch)
            val_outs.append(out)
        validation_epoch_end(val_outs)

    Args:
        outputs: List of outputs you defined in :meth:`validation_step`, or if there
            are multiple dataloaders, a list containing a list of outputs for each dataloader.

    Return:
        None

    Note:
        If you didn't define a :meth:`validation_step`, this won't be called.

    Examples:
        With a single dataloader:

        .. code-block:: python

            def validation_epoch_end(self, val_step_outputs):
                for out in val_step_outputs:
                    ...

        With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
        one entry per dataloader, while the inner list contains the individual outputs of
        each validation step for that dataloader.

        .. code-block:: python

            def validation_epoch_end(self, outputs):
                final_value = 0
                for dataloader_outputs in outputs:
                    for val_step_out in dataloader_outputs:
                        # do something
                        final_value += val_step_out

                self.log("final_metric", final_value)
    """
|
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    r"""
    Process a single batch of data from the test set.

    Use this step to generate examples or to compute anything of interest,
    such as accuracy.

    .. code-block:: python

        # the pseudocode for these calls
        test_outs = []
        for test_batch in test_data:
            out = test_step(test_batch)
            test_outs.append(out)
        test_epoch_end(test_outs)

    Args:
        batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
            The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
        batch_idx (int): The index of this batch.
        dataloader_idx (int): The index of the dataloader that produced this batch
            (only if multiple test dataloaders used).

    Return:
        Any of.

        - Any object or value
        - ``None`` - Testing will skip to the next batch

    .. code-block:: python

        # if you have one test dataloader:
        def test_step(self, batch, batch_idx):
            ...


        # if you have multiple test dataloaders:
        def test_step(self, batch, batch_idx, dataloader_idx):
            ...

    Examples::

        # CASE 1: A single test dataset
        def test_step(self, batch, batch_idx):
            x, y = batch

            # implement your own
            out = self(x)
            loss = self.loss(out, y)

            # log 6 example images
            # or generated text... or whatever
            sample_imgs = x[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('example_images', grid, 0)

            # calculate acc
            labels_hat = torch.argmax(out, dim=1)
            test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

            # log the outputs!
            self.log_dict({'test_loss': loss, 'test_acc': test_acc})

    With multiple test dataloaders, :meth:`test_step` receives one extra argument.

    .. code-block:: python

        # CASE 2: multiple test dataloaders
        def test_step(self, batch, batch_idx, dataloader_idx):
            # dataloader_idx tells you which dataset this is.
            ...

    Note:
        If you don't need to test you don't need to implement this method.

    Note:
        When :meth:`test_step` is called, the model has been put in eval mode and
        PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
        to training mode and gradients are enabled.
    """
|
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def test_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    """Combine the per-device outputs of :meth:`test_step` when testing with dp or ddp2.

    With those backends :meth:`test_step` only sees a slice of the batch. Implementing this hook is still
    optional; it is only needed for computations that require the whole batch, such as a softmax or an
    NCE loss.

    Note:
        This hook keeps being called after you switch to ddp or another mode, so your code does not
        have to change.

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
        test_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in :meth:`test_step` for each batch part.

    Return:
        None or anything

    .. code-block:: python

        # WITHOUT test_step_end
        # if used in DP or DDP2, this batch is 1/num_gpus large
        def test_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self(x)
            loss = self.softmax(out)
            self.log("test_loss", loss)


        # --------------
        # with test_step_end to do softmax over the full batch
        def test_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.encoder(x)
            return out


        def test_step_end(self, output_results):
            # this out is now the full size of the batch
            all_test_step_outs = output_results.out
            loss = nce_loss(all_test_step_outs)
            self.log("test_loss", loss)

    See Also:
        See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details.
    """
|
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
    """Called at the end of a test epoch with the output of all test steps.

    .. code-block:: python

        # the pseudocode for these calls
        test_outs = []
        for test_batch in test_data:
            out = test_step(test_batch)
            test_outs.append(out)
        test_epoch_end(test_outs)

    Args:
        outputs: List of outputs you defined in :meth:`test_step_end`, or if there
            are multiple dataloaders, a list containing a list of outputs for each dataloader

    Return:
        None

    Note:
        If you didn't define a :meth:`test_step`, this won't be called.

    Examples:
        With a single dataloader:

        .. code-block:: python

            def test_epoch_end(self, outputs):
                # do something with the outputs of all test batches
                some_result = calc_all_results(outputs)
                self.log("some_result", some_result)

        With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
        one entry per dataloader, while the inner list contains the individual outputs of
        each test step for that dataloader.

        .. code-block:: python

            def test_epoch_end(self, outputs):
                final_value = 0
                for dataloader_outputs in outputs:
                    for test_step_out in dataloader_outputs:
                        # do something
                        final_value += test_step_out

                self.log("final_metric", final_value)
    """
|
|
|
|
|
2021-04-19 12:43:16 +00:00
|
|
|
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
    """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, it
    calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`. Override to add any processing
    logic.

    The :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` is used
    to scale inference on multi-devices.

    To prevent an OOM error, it is possible to use :class:`~pytorch_lightning.callbacks.BasePredictionWriter`
    callback to write the predictions to disk or database after each batch or on epoch end.

    The :class:`~pytorch_lightning.callbacks.BasePredictionWriter` should be used while using a spawn
    based accelerator. This happens for ``Trainer(strategy="ddp_spawn")``
    or training on 8 TPU cores with ``Trainer(tpu_cores=8)`` as predictions won't be returned.

    Example ::

        class MyModel(LightningModule):

            def predict_step(self, batch, batch_idx, dataloader_idx):
                return self(batch)

        dm = ...
        model = MyModel()
        trainer = Trainer(gpus=2)
        predictions = trainer.predict(model, dm)


    Args:
        batch: Current batch
        batch_idx: Index of current batch
        dataloader_idx: Index of the current dataloader

    Return:
        Predicted output
    """
    # ``batch_idx`` and ``dataloader_idx`` are unused by the default; the whole batch is fed to ``self(...)``.
    return self(batch)
|
|
|
|
|
2021-02-13 00:27:44 +00:00
|
|
|
def configure_callbacks(self):
    """Return callbacks that belong to this model.

    When the model is attached to a Trainer, for example on ``.fit()`` or ``.test()``, the list returned here
    is merged with the callbacks passed through the Trainer's ``callbacks`` argument. A callback returned here
    whose type matches one or more callbacks already in the Trainer's list takes priority and replaces them.
    Lightning also makes sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
    run last.

    Return:
        A list of callbacks which will extend the list of callbacks in the Trainer.

    Example::

        def configure_callbacks(self):
            early_stop = EarlyStopping(monitor="val_acc", mode="max")
            checkpoint = ModelCheckpoint(monitor="val_loss")
            return [early_stop, checkpoint]

    Note:
        Certain callback methods like :meth:`~pytorch_lightning.callbacks.base.Callback.on_init_start`
        will never be invoked on the new callbacks returned here.
    """
    # No model-specific callbacks by default; a fresh list is built on every call.
    model_callbacks = list()
    return model_callbacks
|
|
|
|
|
2021-02-08 19:29:43 +00:00
|
|
|
def configure_optimizers(self):
    r"""
    Choose what optimizers and learning-rate schedulers to use in your optimization.
    Normally you'd need one. But in the case of GANs or similar you might have multiple.

    Return:
        Any of these 6 options.

        - **Single optimizer**.
        - **List or Tuple** of optimizers.
        - **Two lists** - The first list has multiple optimizers, and the second has multiple LR schedulers
          (or multiple ``lr_scheduler_config``).
        - **Dictionary**, with an ``"optimizer"`` key, and (optionally) a ``"lr_scheduler"``
          key whose value is a single LR scheduler or ``lr_scheduler_config``.
        - **Tuple of dictionaries** as described above, with an optional ``"frequency"`` key.
        - **None** - Fit will run without any optimizer.

    The ``lr_scheduler_config`` is a dictionary which contains the scheduler and its associated configuration.
    The default configuration is shown below.

    .. code-block:: python

        lr_scheduler_config = {
            # REQUIRED: The scheduler instance
            "scheduler": lr_scheduler,
            # The unit of the scheduler's step size, could also be 'step'.
            # 'epoch' updates the scheduler on epoch end whereas 'step'
            # updates it after an optimizer update.
            "interval": "epoch",
            # How many epochs/steps should pass between calls to
            # `scheduler.step()`. 1 corresponds to updating the learning
            # rate after every epoch/step.
            "frequency": 1,
            # Metric to monitor for schedulers like `ReduceLROnPlateau`
            "monitor": "val_loss",
            # If set to `True`, will enforce that the value specified in 'monitor'
            # is available when the scheduler is updated, thus stopping
            # training if not found. If set to `False`, it will only produce a warning
            "strict": True,
            # If using the `LearningRateMonitor` callback to monitor the
            # learning rate progress, this keyword can be used to specify
            # a custom logged name
            "name": None,
        }

    When there are schedulers in which the ``.step()`` method is conditioned on a value, such as the
    :class:`torch.optim.lr_scheduler.ReduceLROnPlateau` scheduler, Lightning requires that the
    ``lr_scheduler_config`` contains the keyword ``"monitor"`` set to the metric name that the scheduler
    should be conditioned on.

    .. testcode::

        # The ReduceLROnPlateau scheduler requires a monitor
        def configure_optimizers(self):
            optimizer = Adam(...)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": ReduceLROnPlateau(optimizer, ...),
                    "monitor": "metric_to_track",
                    # How many epochs/steps should pass between calls to `scheduler.step()`.
                    # If "monitor" references validation metrics, then "frequency" should be set to a
                    # multiple of "trainer.check_val_every_n_epoch".
                    "frequency": 1,
                },
            }


        # In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
        def configure_optimizers(self):
            optimizer1 = Adam(...)
            optimizer2 = SGD(...)
            scheduler1 = ReduceLROnPlateau(optimizer1, ...)
            scheduler2 = LambdaLR(optimizer2, ...)
            return (
                {
                    "optimizer": optimizer1,
                    "lr_scheduler": {
                        "scheduler": scheduler1,
                        "monitor": "metric_to_track",
                    },
                },
                {"optimizer": optimizer2, "lr_scheduler": scheduler2},
            )

    Metrics can be made available to monitor by simply logging them using
    ``self.log('metric_to_track', metric_val)`` in your :class:`~pytorch_lightning.core.lightning.LightningModule`.

    Note:
        The ``frequency`` value specified in a dict along with the ``optimizer`` key is an int corresponding
        to the number of sequential batches optimized with the specific optimizer.
        It should be given to none or to all of the optimizers.
        There is a difference between passing multiple optimizers in a list,
        and passing multiple optimizers in dictionaries with a frequency of 1:

            - In the former case, all optimizers will operate on the given batch in each optimization step.
            - In the latter, only one optimizer will operate on the given batch at every step.

        This is different from the ``frequency`` value specified in the ``lr_scheduler_config`` mentioned above.

        .. code-block:: python

            def configure_optimizers(self):
                optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
                optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
                return [
                    {"optimizer": optimizer_one, "frequency": 5},
                    {"optimizer": optimizer_two, "frequency": 10},
                ]

        In this example, the first optimizer will be used for the first 5 steps,
        the second optimizer for the next 10 steps and that cycle will continue.
        If an LR scheduler is specified for an optimizer using the ``lr_scheduler`` key in the above dict,
        the scheduler will only be updated when its optimizer is being used.

    Examples::

        # most cases. no learning rate scheduler
        def configure_optimizers(self):
            return Adam(self.parameters(), lr=1e-3)

        # multiple optimizer case (e.g.: GAN)
        def configure_optimizers(self):
            gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
            dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
            return gen_opt, dis_opt

        # example with learning rate schedulers
        def configure_optimizers(self):
            gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
            dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
            dis_sch = CosineAnnealingLR(dis_opt, T_max=10)
            return [gen_opt, dis_opt], [dis_sch]

        # example with step-based learning rate schedulers
        # each optimizer has its own scheduler
        def configure_optimizers(self):
            gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
            dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
            gen_sch = {
                'scheduler': ExponentialLR(gen_opt, 0.99),
                'interval': 'step'  # called after each training step
            }
            dis_sch = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
            return [gen_opt, dis_opt], [gen_sch, dis_sch]

        # example with optimizer frequencies
        # see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
        # https://arxiv.org/abs/1704.00028
        def configure_optimizers(self):
            gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
            dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
            n_critic = 5
            return (
                {'optimizer': dis_opt, 'frequency': n_critic},
                {'optimizer': gen_opt, 'frequency': 1}
            )

    Note:
        Some things to know:

        - Lightning calls ``.backward()`` and ``.step()`` on each optimizer and learning rate scheduler as needed.
        - If you use 16-bit precision (``precision=16``), Lightning will automatically handle the optimizers.
        - If you use multiple optimizers, :meth:`training_step` will have an additional ``optimizer_idx`` parameter.
        - If you use :class:`torch.optim.LBFGS`, Lightning handles the closure function automatically for you.
        - If you use multiple optimizers, gradients will be calculated only for the parameters of current optimizer
          at each training step.
        - If you need to control how often those optimizers step or override the default ``.step()`` schedule,
          override the :meth:`optimizer_step` hook.
    """
    # The base implementation only warns and implicitly returns ``None``; subclasses are expected to override it.
    rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2021-07-06 08:18:08 +00:00
|
|
|
def manual_backward(self, loss: Tensor, *args, **kwargs) -> None:
    """Run the backward pass on ``loss`` from inside :meth:`training_step` when optimizing manually.

    Going through this method rather than calling ``loss.backward()`` yourself lets Lightning apply the
    proper scaling when mixed precision is used.

    See :ref:`manual optimization<common/optimizers:Manual optimization>` for more examples.

    Example::

        def training_step(...):
            opt = self.optimizers()
            loss = ...
            opt.zero_grad()
            # automatically applies scaling, etc...
            self.manual_backward(loss)
            opt.step()

    Args:
        loss: The tensor on which to compute gradients. Must have a graph attached.
        *args: Additional positional arguments to be forwarded to :meth:`~torch.Tensor.backward`
        **kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`
    """
    # Guard first: this entry point is meant for manual optimization only
    # (the exact check lives in ``_verify_is_manual_optimization``).
    self._verify_is_manual_optimization("manual_backward")
    # The two ``None`` positionals presumably stand in for optimizer / optimizer_idx,
    # mirroring the ``backward`` hook signature -- TODO confirm against the accelerator API.
    accelerator = self.trainer.accelerator
    accelerator.backward(loss, None, None, *args, **kwargs)
|
2020-11-10 19:44:51 +00:00
|
|
|
|
2021-07-08 14:02:09 +00:00
|
|
|
def backward(
    self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
) -> None:
    """Perform the backward pass on the loss returned by :meth:`training_step`.

    Override this hook with your own implementation if you need to customise how gradients are computed.

    Args:
        loss: The loss tensor returned by :meth:`training_step`. If gradient accumulation is used, the loss here
            holds the normalized value (scaled by 1 / accumulation steps).
        optimizer: Current optimizer being used. ``None`` if using manual optimization.
        optimizer_idx: Index of the current optimizer being used. ``None`` if using manual optimization.

    Example::

        def backward(self, loss, optimizer, optimizer_idx):
            loss.backward()
    """
    # The default ignores ``optimizer``/``optimizer_idx`` and forwards any extra
    # arguments unchanged to ``Tensor.backward``.
    run_backward = loss.backward
    run_backward(*args, **kwargs)
|
2020-10-10 22:44:24 +00:00
|
|
|
|
2021-10-18 15:29:51 +00:00
|
|
|
def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer], optimizer_idx: int) -> None:
    """Restrict gradient computation in the training step to the parameters of ``optimizer``.

    Every parameter of every optimizer returned by ``self.optimizers(use_pl_optimizer=False)`` has its current
    ``requires_grad`` flag recorded and is then set to ``requires_grad=False``. The parameters of ``optimizer``
    then get their recorded flag back. This prevents dangling gradients in a multiple-optimizer setup. The
    recorded flags are stored in ``self._param_requires_grad_state`` so that :meth:`untoggle_optimizer` can
    restore them.

    This is only called automatically when automatic optimization is enabled and multiple optimizers are used.

    Args:
        optimizer: The optimizer to toggle.
        optimizer_idx: The index of the optimizer to toggle. Not read by this implementation.
    """
    # Record each parameter's flag only once (a parameter may be shared by several optimizers), so any
    # ``requires_grad`` choice made in ``configure_optimizers`` is preserved, then switch it off.
    original_flags = {}
    every_param = (
        param
        for opt in self.optimizers(use_pl_optimizer=False)
        for group in opt.param_groups
        for param in group["params"]
    )
    for param in every_param:
        if param not in original_flags:
            original_flags[param] = param.requires_grad
            param.requires_grad = False

    # Give the active optimizer's parameters their recorded flag back.
    for group in optimizer.param_groups:
        for param in group["params"]:
            param.requires_grad = original_flags[param]

    self._param_requires_grad_state = original_flags
|
|
|
|
|
2021-10-18 15:29:51 +00:00
|
|
|
def untoggle_optimizer(self, optimizer_idx: int) -> None:
    """Undo :meth:`toggle_optimizer` by restoring the recorded ``requires_grad`` flags.

    Every parameter belonging to an optimizer other than the one at ``optimizer_idx`` that has an entry in
    ``self._param_requires_grad_state`` gets that recorded flag back. Afterwards the recorded state is reset
    to an empty dict.

    This is only called automatically when automatic optimization is enabled and multiple optimizers are used.

    Args:
        optimizer_idx: The index of the optimizer to untoggle.
    """
    for idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
        if idx == optimizer_idx:
            # The active optimizer's parameters are left untouched.
            continue
        for group in opt.param_groups:
            for param in group["params"]:
                if param in self._param_requires_grad_state:
                    param.requires_grad = self._param_requires_grad_state[param]
    # Save memory: drop the references held by the recorded state.
    self._param_requires_grad_state = {}
|
2020-10-10 18:35:25 +00:00
|
|
|
|
2021-10-13 14:45:13 +00:00
|
|
|
def clip_gradients(
    self,
    optimizer: Optimizer,
    gradient_clip_val: Optional[Union[int, float]] = None,
    gradient_clip_algorithm: Optional[str] = None,
):
    """Handles gradient clipping internally.

    Note:
        Do not override this method. If you want to customize gradient clipping, consider
        using :meth:`configure_gradient_clipping` method.

    Args:
        optimizer: Current optimizer being used.
        gradient_clip_val: The value at which to clip gradients. When ``None``, the Trainer's
            ``gradient_clip_val`` is used, or ``0.0`` if that is unset too.
        gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
            to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. The value is lower-cased.
            When ``None``, the Trainer's ``gradient_clip_algorithm`` is used, or ``"norm"`` if that is unset too.

    Raises:
        MisconfigurationException: If a value passed here conflicts with the one set on the Trainer, or if
            ``gradient_clip_algorithm`` is not a supported algorithm.
        TypeError: If the resolved ``gradient_clip_val`` is not an int or a float.
    """
    if gradient_clip_val is None:
        gradient_clip_val = self.trainer.gradient_clip_val or 0.0
    elif self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val != gradient_clip_val:
        raise MisconfigurationException(
            f"You have set `Trainer(gradient_clip_val={self.trainer.gradient_clip_val!r})`"
            f" and have passed `clip_gradients(gradient_clip_val={gradient_clip_val!r})`."
            " Please use only one of them."
        )

    if gradient_clip_algorithm is None:
        gradient_clip_algorithm = self.trainer.gradient_clip_algorithm or "norm"
    else:
        gradient_clip_algorithm = gradient_clip_algorithm.lower()
        if (
            self.trainer.gradient_clip_algorithm is not None
            and self.trainer.gradient_clip_algorithm != gradient_clip_algorithm
        ):
            # Closing backtick and period added so this message matches the `gradient_clip_val` one above.
            raise MisconfigurationException(
                f"You have set `Trainer(gradient_clip_algorithm={self.trainer.gradient_clip_algorithm.value!r})`"
                f" and have passed `clip_gradients(gradient_clip_algorithm={gradient_clip_algorithm!r})`."
                " Please use only one of them."
            )

    if not isinstance(gradient_clip_val, (int, float)):
        raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")

    # `.lower()` again because the Trainer-provided fallback above was not normalised.
    if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
        raise MisconfigurationException(
            f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid."
            f" Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
        )

    gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
    self.trainer.accelerator.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
|
|
|
|
|
|
|
|
def configure_gradient_clipping(
    self,
    optimizer: Optimizer,
    optimizer_idx: int,
    gradient_clip_val: Optional[Union[int, float]] = None,
    gradient_clip_algorithm: Optional[str] = None,
):
    """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.

    The default implementation delegates to :meth:`clip_gradients` and does not use ``optimizer_idx``.

    Args:
        optimizer: Current optimizer being used.
        optimizer_idx: Index of the current optimizer being used.
        gradient_clip_val: The value at which to clip gradients. By default, the value passed to the Trainer
            will be available here.
        gradient_clip_algorithm: The gradient clipping algorithm to use. By default, the value
            passed to the Trainer will be available here.

    Example::

        # Perform gradient clipping on gradients associated with discriminator (optimizer_idx=1) in GAN
        def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
            if optimizer_idx == 1:
                # Lightning will handle the gradient clipping
                self.clip_gradients(
                    optimizer,
                    gradient_clip_val=gradient_clip_val,
                    gradient_clip_algorithm=gradient_clip_algorithm
                )
            else:
                # implement your own custom logic to clip gradients for generator (optimizer_idx=0)
                ...
    """
    self.clip_gradients(
        optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
    )
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def optimizer_step(
    self,
    epoch: int,
    batch_idx: int,
    optimizer: Union[Optimizer, LightningOptimizer],
    optimizer_idx: int = 0,
    optimizer_closure: Optional[Callable[[], Any]] = None,
    on_tpu: bool = False,
    using_native_amp: bool = False,
    using_lbfgs: bool = False,
) -> None:
    r"""
    Override this method to adjust the default way the
    :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer.

    By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
    once per optimizer. This method (and ``zero_grad()``) won't be called during the
    accumulation phase when ``Trainer(accumulate_grad_batches != 1)``.

    Args:
        epoch: Current epoch
        batch_idx: Index of current batch
        optimizer: A PyTorch optimizer
        optimizer_idx: If you used multiple optimizers, this indexes into that list.
        optimizer_closure: Closure for all optimizers. This closure must be executed as it includes the
            calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``.
        on_tpu: ``True`` if TPU backward is required
        using_native_amp: ``True`` if using native amp
        using_lbfgs: ``True`` if the matching optimizer is :class:`torch.optim.LBFGS`

    Examples::

        # DEFAULT
        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                           optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
            optimizer.step(closure=optimizer_closure)

        # Alternating schedule for optimizer steps (i.e.: GANs)
        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                           optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
            # update generator opt every step
            if optimizer_idx == 0:
                optimizer.step(closure=optimizer_closure)

            # update discriminator opt every 2 steps
            if optimizer_idx == 1:
                if (batch_idx + 1) % 2 == 0:
                    optimizer.step(closure=optimizer_closure)
                else:
                    # call the closure by itself to run `training_step` + `backward` without an optimizer step
                    optimizer_closure()

            # ...
            # add as many optimizers as you want

    Here's another example showing how to use this for more advanced things such as
    learning rate warm-up:

    .. code-block:: python

        # learning rate warm-up
        def optimizer_step(
            self,
            epoch,
            batch_idx,
            optimizer,
            optimizer_idx,
            optimizer_closure,
            on_tpu,
            using_native_amp,
            using_lbfgs,
        ):
            # warm up lr
            if self.trainer.global_step < 500:
                lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
                for pg in optimizer.param_groups:
                    pg["lr"] = lr_scale * self.learning_rate

            # update params
            optimizer.step(closure=optimizer_closure)

    """
    # Default: a single step that runs the closure (``training_step``, ``zero_grad`` and ``backward``).
    run_step = optimizer.step
    run_step(closure=optimizer_closure)
|
2020-04-16 16:01:41 +00:00
|
|
|
|
2021-02-08 19:29:43 +00:00
|
|
|
def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
    """Reset the gradients of ``optimizer``; override to change how ``optimizer.zero_grad()`` is called.

    Args:
        epoch: Current epoch
        batch_idx: Index of current batch
        optimizer: A PyTorch optimizer
        optimizer_idx: If you used multiple optimizers this indexes into that list.

    Examples::

        # DEFAULT
        def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
            optimizer.zero_grad()

        # Set gradients to `None` instead of zero to improve performance.
        def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
            optimizer.zero_grad(set_to_none=True)

    See :meth:`torch.optim.Optimizer.zero_grad` for the explanation of the above example.
    """
    # The default ignores epoch/batch/optimizer index and calls zero_grad with no arguments.
    reset_grads = optimizer.zero_grad
    reset_grads()
|
|
|
|
|
2021-09-09 07:45:52 +00:00
|
|
|
def tbptt_split_batch(self, batch: Any, split_size: int) -> List[Any]:
    r"""
    When using truncated backpropagation through time, each batch must be split along the
    time dimension. Lightning handles this by default, but for custom behavior override
    this function.

    Args:
        batch: Current batch
        split_size: The size of the split

    Return:
        List of batch splits. Each split will be passed to :meth:`training_step` to enable truncated
        back propagation through time. The default implementation splits root level Tensors and
        Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.
        Root level elements that are neither Tensors nor Sequences are passed through unchanged
        into every split.

    Examples::

        def tbptt_split_batch(self, batch, split_size):
            time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.abc.Sequence))]
            splits = []
            for t in range(0, time_dims[0], split_size):
                batch_split = []
                for x in batch:
                    if isinstance(x, torch.Tensor):
                        split_x = x[:, t:t + split_size]
                    elif isinstance(x, collections.abc.Sequence):
                        split_x = [sample[t:t + split_size] for sample in x]
                    else:
                        split_x = x
                    batch_split.append(split_x)
                splits.append(batch_split)
            return splits

    Note:
        Called in the training loop after
        :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
        if :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.
        Each returned batch split is passed separately to :meth:`training_step`.
    """
    # `collections.Sequence` was removed in Python 3.10; the ABC lives in `collections.abc`.
    from collections.abc import Sequence

    time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, Sequence))]
    assert len(time_dims) >= 1, "Unable to determine batch time dimension"
    assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"

    splits = []
    for t in range(0, time_dims[0], split_size):
        batch_split = []
        for x in batch:
            if isinstance(x, torch.Tensor):
                split_x = x[:, t : t + split_size]
            elif isinstance(x, Sequence):
                split_x = [sample[t : t + split_size] for sample in x]
            else:
                # Not splittable along time: keep it as-is instead of reusing the previous element's split
                # (or hitting an unbound local on the first element).
                split_x = x
            batch_split.append(split_x)
        splits.append(batch_split)
    return splits
|
|
|
|
|
2021-07-01 10:08:16 +00:00
|
|
|
def summarize(self, mode: Optional[str] = "top", max_depth: Optional[int] = None) -> Optional[ModelSummary]:
    """Summarize this LightningModule.

    .. deprecated:: v1.5
        This method was deprecated in v1.5 in favor of `pytorch_lightning.utilities.model_summary.summarize`
        and will be removed in v1.7.

    Args:
        mode: Can be either ``'top'`` (summarize only direct submodules) or ``'full'`` (summarize all layers).

            .. deprecated:: v1.4
                This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6.

        max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
            layer summary off. Default: 1.

    Return:
        The model summary object
    """
    deprecation_message = (
        "The `LightningModule.summarize` method is deprecated in v1.5 and will be removed in v1.7. "
        "Use `pytorch_lightning.utilities.model_summary.summarize` instead."
    )
    warning_cache.deprecation(deprecation_message, stacklevel=6)
    # Delegates to the utility function of the same name (resolved at module scope, not this method).
    return summarize(self, mode, max_depth)
|
2019-07-25 16:01:52 +00:00
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def freeze(self) -> None:
    r"""Freeze all params for inference.

    Every tensor yielded by ``self.parameters()`` has ``requires_grad`` set to ``False``,
    after which the module is switched to evaluation mode with ``self.eval()``.

    Example::

        model = MyLightningModule(...)
        model.freeze()

    """
    for parameter in self.parameters():
        # stop autograd from tracking gradients for this parameter
        parameter.requires_grad = False
    self.eval()
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def unfreeze(self) -> None:
    """Unfreeze all parameters for training.

    Re-enables ``requires_grad`` on every tensor yielded by ``self.parameters()`` and
    switches the module back to training mode with ``self.train()``.

    .. code-block:: python

        model = MyLightningModule(...)
        model.unfreeze()
    """
    for parameter in self.parameters():
        # let autograd track gradients for this parameter again
        parameter.requires_grad = True
    self.train()
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-04-24 00:46:18 +00:00
|
|
|
def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
    r"""
    .. deprecated:: v1.5
        This method was deprecated in v1.5 in favor of
        `pytorch_lightning.callbacks.progress.base.get_standard_metrics` and will be removed in v1.7.

    Implement this to override the default items displayed in the progress bar.
    By default it includes the average loss value, split index of BPTT (if used)
    and the version of the experiment when using a logger.

    .. code-block::

        Epoch 1:   4%|▎         | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]

    Here is an example of how to override the defaults:

    .. code-block:: python

        def get_progress_bar_dict(self):
            # don't show the version number
            items = super().get_progress_bar_dict()
            items.pop("v_num", None)
            return items

    Return:
        Dictionary with the items to be displayed in the progress bar.
    """
    trainer = self.trainer
    # The default implementation just forwards to the standalone helper that supersedes this hook.
    return progress_base.get_standard_metrics(trainer, self)
|
2020-04-24 00:46:18 +00:00
|
|
|
|
2020-10-11 17:12:35 +00:00
|
|
|
def _verify_is_manual_optimization(self, fn_name):
|
2021-04-26 05:36:26 +00:00
|
|
|
if self.automatic_optimization:
|
2020-12-10 10:01:33 +00:00
|
|
|
raise MisconfigurationException(
|
2021-07-26 11:37:35 +00:00
|
|
|
f"to use {fn_name}, please disable automatic optimization:"
|
|
|
|
" set model property `automatic_optimization` as False"
|
2020-12-10 10:01:33 +00:00
|
|
|
)
|
2020-10-11 17:12:35 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
@classmethod
def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
    """Collect all module arguments in the current constructor and all child constructors. The child
    constructors are all the ``__init__`` methods that reach the current class through (chained)
    ``super().__init__()`` calls.

    Args:
        frame: instance frame. When not given, the frame of this call is used, so the collection starts
            from this method's caller (``frame.f_back``).

    Returns:
        self_arguments: arguments dictionary of the first instance
        parents_arguments: arguments dictionary of the parent's instances
    """
    if not frame:
        frame = inspect.currentframe()

    frame_args = collect_init_args(frame.f_back, [])
    # The last collected entry is treated as this instance's own arguments.
    self_arguments = frame_args[-1]

    # Merge the arguments of all remaining (parent) entries; later entries win on key clashes.
    parents_arguments = {}
    for args in frame_args[:-1]:
        parents_arguments.update(args)

    return self_arguments, parents_arguments
|
2020-05-24 22:59:08 +00:00
|
|
|
|
2020-12-12 10:17:03 +00:00
|
|
|
@torch.no_grad()
def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs):
    """Saves the model in ONNX format.

    Args:
        file_path: The path of the file the onnx model should be saved to.
        input_sample: An input for tracing. Default: None (Use self.example_input_array)
        **kwargs: Will be passed to torch.onnx.export function.

    Raises:
        ValueError: If neither ``input_sample`` nor ``self.example_input_array`` is set.

    Note:
        The module's train/eval mode is restored afterwards, even if the forward pass or the export fails.

    Example:
        >>> class SimpleModel(LightningModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
        ...
        ...     def forward(self, x):
        ...         return torch.relu(self.l1(x.view(x.size(0), -1)))

        >>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
        ...     model = SimpleModel()
        ...     input_sample = torch.randn((1, 64))
        ...     model.to_onnx(tmpfile.name, input_sample, export_params=True)
        ...     os.path.isfile(tmpfile.name)
        True
    """
    mode = self.training

    if input_sample is None:
        if self.example_input_array is None:
            raise ValueError(
                "Could not export to ONNX since neither `input_sample` nor"
                " `model.example_input_array` attribute is set."
            )
        input_sample = self.example_input_array

    input_sample = self._apply_batch_transfer_handler(input_sample)

    try:
        if "example_outputs" not in kwargs:
            self.eval()
            # a tuple sample is unpacked into positional arguments of ``forward``
            if isinstance(input_sample, tuple):
                kwargs["example_outputs"] = self(*input_sample)
            else:
                kwargs["example_outputs"] = self(input_sample)

        torch.onnx.export(self, input_sample, file_path, **kwargs)
    finally:
        # restore the caller's train/eval mode even when the forward pass or the export raises
        self.train(mode)
|
2020-07-31 10:27:57 +00:00
|
|
|
|
2020-12-12 10:17:03 +00:00
|
|
|
@torch.no_grad()
def to_torchscript(
    self,
    file_path: Optional[Union[str, Path]] = None,
    method: Optional[str] = "script",
    example_inputs: Optional[Any] = None,
    **kwargs,
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
    """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing,
    please provide the argument ``method='trace'`` and make sure that either the `example_inputs` argument is
    provided, or the model has :attr:`example_input_array` set. If you would like to customize the modules that
    are scripted you should override this method. In case you want to return multiple modules, we recommend
    using a dictionary.

    Args:
        file_path: Path where to save the torchscript. Default: None (no file saved).
        method: Whether to use TorchScript's script or trace method. Default: 'script'
        example_inputs: An input to be used to do tracing when method is set to 'trace'.
          Default: None (uses :attr:`example_input_array`)
        **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
          :func:`torch.jit.trace` function.

    Raises:
        ValueError: If ``method`` is neither ``'script'`` nor ``'trace'``, or if tracing is requested without
            ``example_inputs`` or ``example_input_array``.

    Note:
        - Requires the implementation of the
          :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method.
        - The exported script will be set to evaluation mode.
        - The module's own train/eval mode is restored afterwards, even if compilation fails.
        - It is recommended that you install the latest supported version of PyTorch
          to use this feature without limitations. See also the :mod:`torch.jit`
          documentation for supported features.

    Example:
        >>> class SimpleModel(LightningModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
        ...
        ...     def forward(self, x):
        ...         return torch.relu(self.l1(x.view(x.size(0), -1)))
        ...
        >>> model = SimpleModel()
        >>> torch.jit.save(model.to_torchscript(), "model.pt")  # doctest: +SKIP
        >>> os.path.isfile("model.pt")  # doctest: +SKIP
        >>> torch.jit.save(model.to_torchscript(file_path="model_trace.pt", method='trace', # doctest: +SKIP
        ...                                     example_inputs=torch.randn(1, 64)))  # doctest: +SKIP
        >>> os.path.isfile("model_trace.pt")  # doctest: +SKIP
        True

    Return:
        This LightningModule as a torchscript, regardless of whether `file_path` is
        defined or not.
    """
    mode = self.training

    # Validate arguments and prepare inputs before touching the train/eval mode.
    if method == "trace":
        # if no example inputs are provided, try to see if model has example_input_array set
        if example_inputs is None:
            if self.example_input_array is None:
                raise ValueError(
                    "Choosing method=`trace` requires either `example_inputs`"
                    " or `model.example_input_array` to be defined."
                )
            example_inputs = self.example_input_array

        # automatically send example inputs to the right device and use trace
        example_inputs = self._apply_batch_transfer_handler(example_inputs)
    elif method != "script":
        raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")

    try:
        if method == "script":
            torchscript_module = torch.jit.script(self.eval(), **kwargs)
        else:
            torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
    finally:
        # restore the caller's train/eval mode even when scripting/tracing raises
        self.train(mode)

    if file_path is not None:
        fs = get_filesystem(file_path)
        with fs.open(file_path, "wb") as f:
            torch.jit.save(torchscript_module, f)

    return torchscript_module
|
2020-09-03 18:24:44 +00:00
|
|
|
|
2021-02-11 12:04:57 +00:00
|
|
|
@property
def model_size(self) -> float:
    """Return ``get_model_size_mb(self)`` after emitting a deprecation warning.

    .. deprecated:: v1.5
        Use `pytorch_lightning.utilities.memory.get_model_size_mb` instead; this property will be removed
        in v1.7.
    """
    deprecation_message = (
        "The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7."
        " Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`."
    )
    rank_zero_deprecation(deprecation_message, stacklevel=5)
    # NOTE(review): unit presumably megabytes, per the helper's name — confirm
    return get_model_size_mb(self)
|
2021-06-23 01:19:37 +00:00
|
|
|
|
|
|
|
def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
    """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
    sharing, we cast the data to numpy.

    The work is delegated to the trainer's ``DDPSpawnPlugin``; without an attached trainer, or with any
    other training type plugin, this is a no-op.

    Args:
        queue: the instance of the queue to append the data.

    .. deprecated:: v1.5
        This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.add_to_queue`
        and will be removed in v1.7.
    """
    trainer = self.trainer
    if not trainer:
        return
    plugin = trainer.training_type_plugin
    if isinstance(plugin, pl.plugins.training_type.DDPSpawnPlugin):
        plugin.add_to_queue(trainer, queue)
|
2021-06-23 01:19:37 +00:00
|
|
|
|
|
|
|
def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
    """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
    we cast back the data to ``torch.Tensor``.

    The work is delegated to the trainer's ``DDPSpawnPlugin``; without an attached trainer, or with any
    other training type plugin, this is a no-op.

    Args:
        queue: the instance of the queue from where to get the data.

    .. deprecated:: v1.5
        This method was deprecated in v1.5 in favor of `DDPSpawnPlugin.get_from_queue`
        and will be removed in v1.7.
    """
    trainer = self.trainer
    if not trainer:
        return
    plugin = trainer.training_type_plugin
    if isinstance(plugin, pl.plugins.training_type.DDPSpawnPlugin):
        plugin.get_from_queue(trainer, queue)
|
2021-07-20 18:31:49 +00:00
|
|
|
|
|
|
|
@contextmanager
|
|
|
|
def _prevent_trainer_and_dataloaders_deepcopy(self) -> None:
|
|
|
|
self._should_prevent_trainer_and_dataloaders_deepcopy = True
|
|
|
|
yield
|
|
|
|
self._should_prevent_trainer_and_dataloaders_deepcopy = False
|
|
|
|
|
|
|
|
def __getstate__(self) -> Dict[str, Any]:
|
|
|
|
state = dict(self.__dict__)
|
|
|
|
if self._should_prevent_trainer_and_dataloaders_deepcopy:
|
|
|
|
state["trainer"] = None
|
|
|
|
state.pop("train_dataloader", None)
|
|
|
|
state.pop("val_dataloader", None)
|
|
|
|
state.pop("test_dataloader", None)
|
|
|
|
state.pop("predict_dataloader", None)
|
|
|
|
return state
|
2021-08-23 19:59:38 +00:00
|
|
|
|
|
|
|
def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
    """Adds ShardedTensor state dict hooks if ShardedTensors are supported.

    These hooks ensure that ShardedTensors are included when saving, and are loaded into the LightningModule
    correctly. Nothing is registered on Windows or when ``_TORCH_GREATER_EQUAL_DEV_1_10`` is falsy.
    """
    if _TORCH_GREATER_EQUAL_DEV_1_10 and not _IS_WINDOWS:
        # imported lazily: this private torch module is only imported once the version/platform check passes
        from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

        self._register_state_dict_hook(state_dict_hook)
        self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
|