2020-08-20 02:03:22 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2020-11-13 15:05:54 +00:00
|
|
|
"""nn.Module with additional great features."""
|
|
|
|
|
2019-10-31 10:45:28 +00:00
|
|
|
import collections
|
2020-10-15 12:30:49 +00:00
|
|
|
import copy
|
2020-02-25 15:36:44 +00:00
|
|
|
import inspect
|
2020-12-01 00:09:46 +00:00
|
|
|
import os
|
|
|
|
import tempfile
|
2021-02-11 12:04:57 +00:00
|
|
|
import uuid
|
2021-01-27 10:02:16 +00:00
|
|
|
from abc import ABC
|
|
|
|
from argparse import Namespace
|
|
|
|
from functools import partial
|
|
|
|
from pathlib import Path
|
2021-02-22 11:01:54 +00:00
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
2019-10-22 08:32:40 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
import torch
|
2020-12-01 00:09:46 +00:00
|
|
|
from torch import ScriptModule, Tensor
|
|
|
|
from torch.nn import Module
|
|
|
|
from torch.optim.optimizer import Optimizer
|
|
|
|
|
2020-03-17 22:44:00 +00:00
|
|
|
from pytorch_lightning import _logger as log
|
2019-11-27 03:39:18 +00:00
|
|
|
from pytorch_lightning.core.grads import GradInformation
|
2020-09-29 17:51:44 +00:00
|
|
|
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
|
2020-01-21 20:18:32 +00:00
|
|
|
from pytorch_lightning.core.memory import ModelSummary
|
2020-12-07 12:55:49 +00:00
|
|
|
from pytorch_lightning.core.optimizer import LightningOptimizer
|
2021-01-09 12:37:44 +00:00
|
|
|
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
|
2020-11-02 20:51:43 +00:00
|
|
|
from pytorch_lightning.core.step_result import Result
|
2020-12-21 09:15:04 +00:00
|
|
|
from pytorch_lightning.utilities import rank_zero_warn
|
2021-01-09 12:37:44 +00:00
|
|
|
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
|
2020-07-24 15:42:15 +00:00
|
|
|
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
|
2020-09-30 02:12:56 +00:00
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
2020-11-26 23:37:48 +00:00
|
|
|
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
|
2020-09-21 02:59:21 +00:00
|
|
|
|
2021-02-22 11:01:54 +00:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
from pytorch_lightning.trainer.states import RunningStage
|
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2020-09-29 17:51:44 +00:00
|
|
|
class LightningModule(
|
|
|
|
ABC,
|
|
|
|
DeviceDtypeModuleMixin,
|
|
|
|
GradInformation,
|
|
|
|
ModelIO,
|
|
|
|
ModelHooks,
|
|
|
|
DataHooks,
|
|
|
|
CheckpointHooks,
|
|
|
|
Module,
|
|
|
|
):
|
2020-09-25 14:20:15 +00:00
|
|
|
# Property support for TorchScript (PyTorch 1.7): properties listed here are declared
# unused for JIT. None of them matter for a scripted model, so they are ignored.
__jit_unused_properties__ = [
    "datamodule", "example_input_array", "hparams", "hparams_initial", "on_gpu",
    "current_epoch", "global_step", "running_stage", "global_rank", "local_rank",
    "logger", "model_size",
] + DeviceDtypeModuleMixin.__jit_unused_properties__
|
2020-09-25 14:20:15 +00:00
|
|
|
|
2019-07-25 16:08:00 +00:00
|
|
|
def __init__(self, *args, **kwargs):
    """
    Initialize the module and reset the Lightning-managed attributes to their defaults.

    ``*args`` and ``**kwargs`` are forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(*args, **kwargs)

    # Register API usage the same way torch.nn.Module does, see
    # (https://github.com/pytorch/pytorch/blob/3e6bb5233f9ca2c5aa55d9cda22a7ee85439aa6e/
    # torch/nn/modules/module.py#L227)
    torch._C._log_api_usage_once(f"lightning.module.{self.__class__.__name__}")

    # bookkeeping
    self.exp_save_path = None
    self.loaded_optimizer_states_dict = {}

    #: Pointer to the trainer object (``None`` by default)
    self.trainer = None

    # distributed / device type, both unset by default (assigned left to right)
    self._distrib_type = self._device_type = None

    #: True if using amp
    self.use_amp = False
    #: The precision used
    self.precision = 32

    # optionally can be set by user
    self._example_input_array = None
    self._datamodule = None

    # logging and hook-tracking state
    self._results: Optional[Result] = None
    self._current_fx_name = ''
    self._running_manual_backward = False
    self._current_hook_fx_name = self._current_dataloader_idx = None

    self._automatic_optimization: bool = True
|
2020-06-15 21:05:58 +00:00
|
|
|
|
2021-01-08 21:13:12 +00:00
|
|
|
def optimizers(
    self, use_pl_optimizer: bool = True
) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
    """
    Return the optimizer(s) configured on the attached trainer.

    Args:
        use_pl_optimizer: If ``True``, return the wrapped optimizers taken from
            ``trainer.lightning_optimizers``; otherwise return the raw ``trainer.optimizers``.

    Returns:
        The single optimizer when exactly one is configured, otherwise the full collection.
    """
    if use_pl_optimizer:
        opts = list(self.trainer.lightning_optimizers.values())
    else:
        opts = self.trainer.optimizers

    # single optimizer: unwrap the one-element list for convenience. LightningOptimizer is
    # listed explicitly so the check does not depend on the wrapper also being an Optimizer.
    if isinstance(opts, list) and len(opts) == 1 and isinstance(opts[0], (Optimizer, LightningOptimizer)):
        return opts[0]
    # multiple opts (or none)
    return opts
|
2020-10-10 16:19:22 +00:00
|
|
|
|
2020-06-15 21:05:58 +00:00
|
|
|
@property
def example_input_array(self) -> Any:
    """
    Example input stored on the module in ``_example_input_array``.

    ``None`` unless it has been assigned.
    """
    stored_example = self._example_input_array
    return stored_example
|
|
|
|
|
2020-10-05 15:10:40 +00:00
|
|
|
@property
def current_epoch(self) -> int:
    """The current epoch as reported by the trainer; 0 when no trainer is attached."""
    trainer = self.trainer
    if not trainer:
        return 0
    return trainer.current_epoch
|
|
|
|
|
|
|
|
@property
def global_step(self) -> int:
    """Total training batches seen across all epochs; 0 when no trainer is attached."""
    trainer = self.trainer
    if not trainer:
        return 0
    return trainer.global_step
|
|
|
|
|
2021-02-01 14:28:17 +00:00
|
|
|
@property
def global_rank(self) -> int:
    """Index of the current process across all nodes and devices; 0 without a trainer."""
    trainer = self.trainer
    if not trainer:
        return 0
    return trainer.global_rank
|
|
|
|
|
|
|
|
@property
def local_rank(self) -> int:
    """Index of the current process within a single node; 0 without a trainer."""
    trainer = self.trainer
    if not trainer:
        return 0
    return trainer.local_rank
|
|
|
|
|
2020-06-15 21:05:58 +00:00
|
|
|
@example_input_array.setter
def example_input_array(self, example: Any) -> None:
    # store the user-provided example verbatim; no validation is performed
    setattr(self, "_example_input_array", example)
|
|
|
|
|
2020-07-24 15:42:15 +00:00
|
|
|
@property
def datamodule(self) -> Any:
    """The datamodule stored on this module (``None`` until assigned)."""
    attached_datamodule = self._datamodule
    return attached_datamodule
|
|
|
|
|
|
|
|
@datamodule.setter
def datamodule(self, datamodule: Any) -> None:
    # keep a reference to the provided datamodule as-is
    setattr(self, "_datamodule", datamodule)
|
|
|
|
|
2020-05-17 12:20:51 +00:00
|
|
|
@property
def on_gpu(self):
    """
    ``True`` when the module's device type is ``"cuda"``.

    Handy for branching between CPU and GPU behavior inside the LightningModule.
    """
    device_kind = self.device.type
    return device_kind == "cuda"
|
2020-05-17 12:20:51 +00:00
|
|
|
|
2020-11-14 04:43:42 +00:00
|
|
|
@property
def automatic_optimization(self) -> bool:
    """
    Whether Lightning drives the optimization loop automatically.

    When ``False`` you are responsible for calling ``.backward``, ``.step`` and ``zero_grad``.
    """
    enabled = self._automatic_optimization
    return enabled
|
|
|
|
|
2021-02-22 11:01:54 +00:00
|
|
|
@property
def running_stage(self) -> Optional["RunningStage"]:
    """The trainer's ``_running_stage``; ``None`` when no trainer is attached."""
    trainer = self.trainer
    if not trainer:
        return None
    return trainer._running_stage
|
|
|
|
|
2021-01-11 16:21:10 +00:00
|
|
|
@automatic_optimization.setter
def automatic_optimization(self, automatic_optimization: bool) -> None:
    # store the flag unchanged; it is read back by the matching property getter
    setattr(self, "_automatic_optimization", automatic_optimization)
|
|
|
|
|
2021-02-01 14:28:17 +00:00
|
|
|
@property
def logger(self):
    """The Trainer's logger object, or ``None`` when no trainer is attached."""
    trainer = self.trainer
    if not trainer:
        return None
    return trainer.logger
|
|
|
|
|
2021-02-18 19:24:19 +00:00
|
|
|
def _apply_batch_transfer_handler(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0):
    """
    Run the batch-transfer hook chain on ``batch``.

    The hooks run in this order, each receiving the previous hook's output:
    ``on_before_batch_transfer(batch, dataloader_idx)``, ``transfer_batch_to_device(batch, device)``
    and ``on_after_batch_transfer(batch, dataloader_idx)``.

    Returns:
        Whatever ``on_after_batch_transfer`` returns.
    """
    prepared = self.on_before_batch_transfer(batch, dataloader_idx)
    moved = self.transfer_batch_to_device(prepared, device)
    return self.on_after_batch_transfer(moved, dataloader_idx)
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def print(self, *args, **kwargs) -> None:
    r"""
    Prints only from process 0. Use this in any distributed mode to log only once.

    If no trainer is attached, the module is treated as the only process and the
    arguments are passed straight to Python's built-in print function.

    Args:
        *args: The thing to print. The same as for Python's built-in print function.
        **kwargs: The same as for Python's built-in print function.

    Example::

        def forward(self, x):
            self.print(x, 'in forward')

    """
    trainer = self.trainer
    if trainer is None:
        # Without a trainer there is no distributed context; this mirrors ``global_rank``,
        # which reports rank 0 in that case, instead of raising AttributeError.
        print(*args, **kwargs)
        return

    if not trainer.is_global_zero:
        return

    progress_bar = trainer.progress_bar_callback
    if progress_bar is not None and progress_bar.is_enabled:
        # delegate to the enabled progress bar (presumably so output does not corrupt its display)
        progress_bar.print(*args, **kwargs)
    else:
        print(*args, **kwargs)
|
2020-02-25 03:30:53 +00:00
|
|
|
|
2020-09-28 00:26:16 +00:00
|
|
|
def log(
    self,
    name: str,
    value: Any,
    prog_bar: bool = False,
    logger: bool = True,
    on_step: Optional[bool] = None,
    on_epoch: Optional[bool] = None,
    reduce_fx: Callable = torch.mean,
    tbptt_reduce_fx: Callable = torch.mean,
    tbptt_pad_token: int = 0,
    enable_graph: bool = False,
    sync_dist: bool = False,
    sync_dist_op: Union[Any, str] = 'mean',
    sync_dist_group: Optional[Any] = None,
):
    """
    Log a key, value pair.

    Example::

        self.log('train_loss', loss)

    The default behavior per hook is as follows

    .. csv-table:: ``*`` also applies to the test loop
       :header: "LightningModule Hook", "on_step", "on_epoch", "prog_bar", "logger"
       :widths: 20, 10, 10, 10, 10

       "training_step", "T", "F", "F", "T"
       "training_step_end", "T", "F", "F", "T"
       "training_epoch_end", "F", "T", "F", "T"
       "validation_step*", "F", "T", "F", "T"
       "validation_step_end*", "F", "T", "F", "T"
       "validation_epoch_end*", "F", "T", "F", "T"

    Args:
        name: key name
        value: value to log
        prog_bar: if True logs to the progress bar
        logger: if True logs to the logger
        on_step: if True logs at this step. None auto-logs at the training_step but not validation/test_step
        on_epoch: if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step
        reduce_fx: reduction function over step values for end of epoch. ``torch.mean`` by default
        tbptt_reduce_fx: function to reduce on truncated back prop
        tbptt_pad_token: token to use for padding
        enable_graph: if True, will not auto detach the graph
        sync_dist: if True, reduces the metric across GPUs/TPUs
        sync_dist_op: the op to sync across GPUs/TPUs
        sync_dist_group: the ddp group

    Raises:
        MisconfigurationException: if ``on_step=True`` or ``on_epoch=False`` is requested from a
            hook whose name contains ``epoch_end``, or if ``name`` contains ``"/dataloader_idx_"``.

    Note:
        The call does nothing while no result object is active (``self._results is None``).
    """
    if self._results is not None:
        # in any epoch end can't log step metrics (only epoch metric)
        if 'epoch_end' in self._current_fx_name and on_step:
            m = f'on_step=True cannot be used on {self._current_fx_name} method'
            raise MisconfigurationException(m)

        if 'epoch_end' in self._current_fx_name and on_epoch is False:
            m = f'on_epoch cannot be False when called from the {self._current_fx_name} method'
            raise MisconfigurationException(m)

        # TODO: if logged twice fail with crash

        # resolve ``None`` defaults for on_step / on_epoch depending on the current fx_name
        on_step = self.__auto_choose_log_on_step(on_step)
        on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

        if self._current_hook_fx_name is not None:
            self.trainer.logger_connector.check_logging_in_callbacks(
                self._current_hook_fx_name, on_step=on_step, on_epoch=on_epoch
            )

        # make sure user doesn't introduce logic for multi-dataloaders
        if "/dataloader_idx_" in name:
            raise MisconfigurationException(
                f"Logged key: {name} should not contain information about dataloader_idx."
            )

        training_type_plugin = self.trainer.training_type_plugin

        self._results.log(
            name,
            value,
            prog_bar,
            logger,
            on_step,
            on_epoch,
            reduce_fx,
            tbptt_reduce_fx,
            tbptt_pad_token,
            enable_graph,
            sync_dist,
            sync_dist_op,
            sync_dist_group,
            training_type_plugin.reduce,
            self._current_dataloader_idx,
            self.device,
        )
|
|
|
|
|
2020-09-30 02:12:56 +00:00
|
|
|
def log_dict(
    self,
    dictionary: dict,
    prog_bar: bool = False,
    logger: bool = True,
    on_step: Optional[bool] = None,
    on_epoch: Optional[bool] = None,
    reduce_fx: Callable = torch.mean,
    tbptt_reduce_fx: Callable = torch.mean,
    tbptt_pad_token: int = 0,
    enable_graph: bool = False,
    sync_dist: bool = False,
    sync_dist_op: Union[Any, str] = 'mean',
    sync_dist_group: Optional[Any] = None,
):
    """
    Log a dictionary of values at once.

    Every key/value pair is forwarded to :meth:`log` (in the dictionary's iteration
    order) with the same logging options.

    Example::

        values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
        self.log_dict(values)

    Args:
        dictionary: key value pairs (str, tensors)
        prog_bar: if True logs to the progress bar
        logger: if True logs to the logger
        on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step
        on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
        reduce_fx: reduction function over step values for end of epoch. :func:`torch.mean` by default
        tbptt_reduce_fx: function to reduce on truncated back prop
        tbptt_pad_token: token to use for padding
        enable_graph: if True, will not auto detach the graph
        sync_dist: if True, reduces the metric across GPUs/TPUs
        sync_dist_op: the op to sync across GPUs/TPUs
        sync_dist_group: the ddp group
    """
    for k, v in dictionary.items():
        self.log(
            name=k,
            value=v,
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
            reduce_fx=reduce_fx,
            enable_graph=enable_graph,
            sync_dist=sync_dist,
            sync_dist_group=sync_dist_group,
            sync_dist_op=sync_dist_op,
            tbptt_pad_token=tbptt_pad_token,
            tbptt_reduce_fx=tbptt_reduce_fx,
        )
|
|
|
|
|
2021-02-11 14:32:32 +00:00
|
|
|
def write_prediction(
    self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt'
):
    """
    Save predictions to disk with ``torch.save``.

    The prediction is handed to ``self.trainer.evaluation_loop.predictions``, which
    collects it under ``name`` for the target ``filename``.

    Example::

        self.write_prediction('pred', torch.tensor(...), filename='my_predictions.pt')

    Args:
        name: key under which the predictions are stored
        value: the predictions, a single :class:`~torch.Tensor` or a list of them
        filename: name of the file the predictions are written to

    Note:
        In distributed mode each device writes its own file, named
        ``filename_rank_0.pt``, ``filename_rank_1.pt``, ...
    """
    collector = self.trainer.evaluation_loop.predictions
    collector._add_prediction(name, value, filename)
|
|
|
|
|
2021-02-11 14:32:32 +00:00
|
|
|
def write_prediction_dict(self, predictions_dict: Dict[str, Any], filename: str = 'predictions.pt'):
    """
    Write a dictionary of predictions to disk at once using ``torch.save``.

    Each entry is forwarded to :meth:`write_prediction`, in the dictionary's
    iteration order, with the same ``filename``.

    Example::

        pred_dict = {'pred1': torch.tensor(...), 'pred2': torch.tensor(...)}
        self.write_prediction_dict(pred_dict)

    Args:
        predictions_dict: dict containing predictions, where each prediction should
            either be a single :class:`~torch.Tensor` or a list of them
        filename: name of the file to save the predictions to

    Note:
        when running in distributed mode, calling ``write_prediction_dict`` will create a file for
        each device with respective names: ``filename_rank_0.pt``, ``filename_rank_1.pt``, ...
    """
    for k, v in predictions_dict.items():
        self.write_prediction(k, v, filename)
|
|
|
|
|
2020-09-29 06:00:28 +00:00
|
|
|
def __auto_choose_log_on_step(self, on_step):
    """
    Resolve the ``on_step`` logging flag.

    An explicit (non-``None``) value is returned unchanged. When ``on_step`` is ``None``
    the flag defaults to ``True`` only while inside ``training_step`` /
    ``training_step_end`` and to ``False`` for every other hook.
    """
    if on_step is not None:
        return on_step
    return self._current_fx_name in {'training_step', 'training_step_end'}
|
|
|
|
|
|
|
|
def __auto_choose_log_on_epoch(self, on_epoch):
    """
    Resolve the ``on_epoch`` logging flag.

    An explicit (non-``None``) value is returned unchanged. When ``on_epoch`` is ``None``
    the flag defaults to ``False`` only while inside ``training_step`` /
    ``training_step_end`` and to ``True`` for every other hook.
    """
    if on_epoch is not None:
        return on_epoch
    return self._current_fx_name not in {'training_step', 'training_step_end'}
|
|
|
|
|
2021-01-09 12:37:44 +00:00
|
|
|
def all_gather(
    self,
    data: Union[torch.Tensor, Dict, List, Tuple],
    group: Optional[Any] = None,
    sync_grads: bool = False,
):
    r"""
    Allows users to call ``self.all_gather()`` from the LightningModule, thus making
    the ``all_gather`` operation accelerator agnostic.

    ``all_gather`` is a function provided by accelerators to gather a tensor from several
    distributed processes.

    Args:
        data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for all_gather op

    Return:
        A tensor of shape (world_size, batch, ...), or if the input was a collection
        the output will also be a collection with tensors of this shape.
    """
    # ``None`` selects the default group containing every process.
    group = group if group is not None else torch.distributed.group.WORLD
    all_gather = self.trainer.accelerator.all_gather
    # Bring the input onto this module's device as tensors before gathering.
    data = convert_to_tensors(data, device=self.device)
    all_gather = partial(all_gather, group=group, sync_grads=sync_grads)
    # Gather every tensor found in the (possibly nested) collection.
    return apply_to_collection(data, torch.Tensor, all_gather)
|
2020-12-08 23:20:01 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
def forward(self, *args, **kwargs):
    r"""
    Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
    the operations you want to use for prediction (i.e.: on a server or as a feature extractor).

    Normally you'd call ``self()`` from your :meth:`training_step` method.
    This makes it easy to write a complex system for training with the outputs
    you'd want in a prediction setting.

    You may also find the :func:`~pytorch_lightning.core.decorators.auto_move_data` decorator useful
    when using the module outside Lightning in a production setting.

    Args:
        *args: Whatever you decide to pass into the forward method.
        **kwargs: Keyword arguments are also possible.

    Return:
        Predicted output

    Examples::

        # example if we were using this model as a feature extractor
        def forward(self, x):
            feature_maps = self.convnet(x)
            return feature_maps

        def training_step(self, batch, batch_idx):
            x, y = batch
            feature_maps = self(x)
            logits = self.classifier(feature_maps)

            # ...
            return loss

        # splitting it this way allows model to be used as a feature extractor
        model = MyModelAbove()

        inputs = server.get_request()
        results = model(inputs)
        server.write_results(results)

        # -------------
        # This is in stark contrast to torch.nn.Module where normally you would have this:
        def forward(self, batch):
            x, y = batch
            feature_maps = self.convnet(x)
            logits = self.classifier(feature_maps)
            return logits
    """
    # Defer to the parent class; a LightningModule is meant to override this.
    return super().forward(*args, **kwargs)
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2020-08-11 23:39:43 +00:00
|
|
|
def training_step(self, *args, **kwargs):
    r"""
    Here you compute and return the training loss and some additional metrics for e.g.
    the progress bar or logger.

    Args:
        batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
            The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
        batch_idx (int): Integer displaying index of this batch
        optimizer_idx (int): When using multiple optimizers, this argument will also be present.
        hiddens (:class:`~torch.Tensor`): Passed in if
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0.

    Return:
        Any of.

        - :class:`~torch.Tensor` - The loss tensor
        - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``
        - ``None`` - Training will skip to the next batch

    Note:
        Returning ``None`` is currently not supported for multi-GPU or TPU, or with 16-bit precision enabled.

    In this step you'd normally do the forward pass and calculate the loss for a batch.
    You can also do fancier things like multiple forward passes or something model specific.

    Example::

        def training_step(self, batch, batch_idx):
            x, y, z = batch
            out = self.encoder(x)
            loss = self.loss(out, x)
            return loss

    If you define multiple optimizers, this step will be called with an additional
    ``optimizer_idx`` parameter.

    .. code-block:: python

        # Multiple optimizers (e.g.: GANs)
        def training_step(self, batch, batch_idx, optimizer_idx):
            if optimizer_idx == 0:
                # do training_step with encoder
                ...
            if optimizer_idx == 1:
                # do training_step with decoder
                ...

    If you add truncated back propagation through time you will also get an additional
    argument with the hidden states of the previous step.

    .. code-block:: python

        # Truncated back-propagation through time
        def training_step(self, batch, batch_idx, hiddens):
            # hiddens are the hidden states from the previous truncated backprop step
            ...
            out, hiddens = self.lstm(data, hiddens)
            ...
            return {'loss': loss, 'hiddens': hiddens}

    Note:
        The loss value shown in the progress bar is smoothed (averaged) over the last values,
        so it differs from the actual loss returned in train/validation step.
    """
    # The default implementation only warns (on rank zero); subclasses override this hook.
    rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer")
|
2019-08-13 15:37:37 +00:00
|
|
|
|
2020-08-11 23:39:43 +00:00
|
|
|
def training_step_end(self, *args, **kwargs):
    """
    Use this when training with dp or ddp2 because :meth:`training_step`
    will operate on only part of the batch. However, this is still optional
    and only needed for things like softmax or NCE loss.

    Note:
        If you later switch to ddp or some other mode, this will still be called
        so that you don't have to change your code

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [training_step(sub_batch) for sub_batch in sub_batches]
        training_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in :meth:`training_step` for each batch part.

    Return:
        Anything

    When using dp/ddp2 distributed backends, only a portion of the batch is inside the training_step:

    .. code-block:: python

        def training_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self(x)

            # softmax uses only a portion of the batch in the denominator
            loss = self.softmax(out)
            loss = nce_loss(loss)
            return loss

    If you wish to do something with all the parts of the batch, then use this method to do it:

    .. code-block:: python

        def training_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.encoder(x)
            return {'pred': out}

        def training_step_end(self, training_step_outputs):
            gpu_0_pred = training_step_outputs[0]['pred']
            gpu_1_pred = training_step_outputs[1]['pred']
            gpu_n_pred = training_step_outputs[n]['pred']

            # this softmax now uses the full batch
            loss = nce_loss([gpu_0_pred, gpu_1_pred, gpu_n_pred])
            return loss

    See Also:
        See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details.
    """
|
|
|
|
|
2020-10-07 11:40:38 +00:00
|
|
|
def training_epoch_end(self, outputs: List[Any]) -> None:
    """
    Called at the end of the training epoch with the outputs of all training steps.
    Use this in case you need to do something with all the outputs for every training_step.

    .. code-block:: python

        # the pseudocode for these calls
        train_outs = []
        for train_batch in train_data:
            out = training_step(train_batch)
            train_outs.append(out)
        training_epoch_end(train_outs)

    Args:
        outputs: List of outputs you defined in :meth:`training_step`, or if there are
            multiple dataloaders, a list containing a list of outputs for each dataloader.

    Return:
        None

    Note:
        If this method is not overridden, this won't be called.

    Example::

        def training_epoch_end(self, training_step_outputs):
            # do something with all training_step outputs
            for out in training_step_outputs:
                ...

    With multiple dataloaders, ``outputs`` will be a list of lists. The outer list contains
    one entry per dataloader, while the inner list contains the individual outputs of
    each training step for that dataloader.

    .. code-block:: python

        def training_epoch_end(self, training_step_outputs):
            for dataloader_outputs in training_step_outputs:
                for out in dataloader_outputs:
                    # do something here
                    ...
    """
|
|
|
|
|
2020-09-30 12:31:16 +00:00
|
|
|
def validation_step(self, *args, **kwargs):
    r"""
    Operates on a single batch of data from the validation set.
    In this step you might generate examples or calculate anything of interest like accuracy.

    .. code-block:: python

        # the pseudocode for these calls
        val_outs = []
        for val_batch in val_data:
            out = validation_step(val_batch)
            val_outs.append(out)
        validation_epoch_end(val_outs)

    Args:
        batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
            The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
        batch_idx (int): The index of this batch
        dataloader_idx (int): The index of the dataloader that produced this batch
            (only if multiple val dataloaders used)

    Return:
        Any of.

        - Any object or value
        - ``None`` - Validation will skip to the next batch

    .. code-block:: python

        # pseudocode of order
        out = validation_step()
        if defined('validation_step_end'):
            out = validation_step_end(out)
        out = validation_epoch_end(out)

    .. code-block:: python

        # if you have one val dataloader:
        def validation_step(self, batch, batch_idx)

        # if you have multiple val dataloaders:
        def validation_step(self, batch, batch_idx, dataloader_idx)

    Examples::

        # CASE 1: A single validation dataset
        def validation_step(self, batch, batch_idx):
            x, y = batch

            # implement your own
            out = self(x)
            loss = self.loss(out, y)

            # log 6 example images
            # or generated text... or whatever
            sample_imgs = x[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('example_images', grid, 0)

            # calculate acc
            labels_hat = torch.argmax(out, dim=1)
            val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

            # log the outputs!
            self.log_dict({'val_loss': loss, 'val_acc': val_acc})

    If you pass in multiple val dataloaders, :meth:`validation_step` will have an additional argument.

    .. code-block:: python

        # CASE 2: multiple validation dataloaders
        def validation_step(self, batch, batch_idx, dataloader_idx):
            # dataloader_idx tells you which dataset this is.
            ...

    Note:
        If you don't need to validate you don't need to implement this method.

    Note:
        When the :meth:`validation_step` is called, the model has been put in eval mode
        and PyTorch gradients have been disabled. At the end of validation,
        the model goes back to training mode and gradients are enabled.
    """
|
|
|
|
|
2020-09-30 12:31:16 +00:00
|
|
|
def validation_step_end(self, *args, **kwargs):
    """
    Hook for combining the per-device outputs of :meth:`validation_step` under dp or ddp2.

    With those backends each :meth:`validation_step` call only sees a slice of the batch.
    Implementing this hook is optional; it is only needed when something (e.g. a softmax
    or an NCE loss) must see the full batch.

    Note:
        This hook is still called if you later switch to ddp or another mode,
        so your code does not need to change.

    .. code-block:: python

        # pseudocode
        sub_batches = split_batches_for_dp(batch)
        batch_parts_outputs = [validation_step(sub_batch) for sub_batch in sub_batches]
        validation_step_end(batch_parts_outputs)

    Args:
        batch_parts_outputs: What you return in :meth:`validation_step`
            for each batch part.

    Return:
        None or anything

    .. code-block:: python

        # WITHOUT validation_step_end
        # if used in DP or DDP2, this batch is 1/num_gpus large
        def validation_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self.encoder(x)
            loss = self.softmax(out)
            loss = nce_loss(loss)
            self.log('val_loss', loss)

        # --------------
        # with validation_step_end to do softmax over the full batch
        def validation_step(self, batch, batch_idx):
            # batch is 1/num_gpus big
            x, y = batch

            out = self(x)
            return out

        def validation_step_end(self, val_step_outputs):
            for out in val_step_outputs:
                # do something with these
                ...

    See Also:
        See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details.
    """
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-12-24 18:37:30 +00:00
|
|
|
def validation_epoch_end(self, outputs: List[Any]) -> None:
    """
    Called at the end of the validation epoch with the outputs of all validation steps.

    .. code-block:: python

        # the pseudocode for these calls
        val_outs = []
        for val_batch in val_data:
            out = validation_step(val_batch)
            val_outs.append(out)
        validation_epoch_end(val_outs)

    Args:
        outputs: List of outputs you defined in :meth:`validation_step`, or if there
            are multiple dataloaders, a list containing a list of outputs for each dataloader.

    Return:
        None

    Note:
        If you didn't define a :meth:`validation_step`, this won't be called.

    Examples:
        With a single dataloader:

        .. code-block:: python

            def validation_epoch_end(self, val_step_outputs):
                for out in val_step_outputs:
                    # do something
                    ...

        With multiple dataloaders, ``outputs`` will be a list of lists. The outer list contains
        one entry per dataloader, while the inner list contains the individual outputs of
        each validation step for that dataloader.

        .. code-block:: python

            def validation_epoch_end(self, outputs):
                for dataloader_outputs in outputs:
                    for out in dataloader_outputs:
                        # do something
                        ...

                final_value = ...
                self.log('final_metric', final_value)
    """
|
|
|
|
|
2020-09-30 12:31:16 +00:00
|
|
|
def test_step(self, *args, **kwargs):
|
2020-03-05 17:32:45 +00:00
|
|
|
r"""
|
2020-04-06 12:12:44 +00:00
|
|
|
Operates on a single batch of data from the test set.
|
2020-03-05 23:52:17 +00:00
|
|
|
In this step you'd normally generate examples or calculate anything of interest
|
|
|
|
such as accuracy.
|
2020-03-05 17:32:45 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# the pseudocode for these calls
|
|
|
|
test_outs = []
|
|
|
|
for test_batch in test_data:
|
2020-04-06 12:12:44 +00:00
|
|
|
out = test_step(test_batch)
|
2020-03-06 15:33:17 +00:00
|
|
|
test_outs.append(out)
|
2020-03-05 17:32:45 +00:00
|
|
|
test_epoch_end(test_outs)
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-03-05 17:32:45 +00:00
|
|
|
Args:
|
2020-04-06 12:12:44 +00:00
|
|
|
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
|
|
|
|
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
|
|
|
|
batch_idx (int): The index of this batch.
|
2020-03-05 23:52:17 +00:00
|
|
|
dataloader_idx (int): The index of the dataloader that produced this batch
|
2021-01-11 11:02:30 +00:00
|
|
|
(only if multiple test dataloaders used).
|
2020-03-05 17:32:45 +00:00
|
|
|
|
|
|
|
Return:
|
2020-10-13 09:52:04 +00:00
|
|
|
Any of.
|
|
|
|
|
|
|
|
- Any object or value
|
2021-01-20 17:27:32 +00:00
|
|
|
- ``None`` - Testing will skip to the next batch
|
2020-03-05 17:32:45 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# if you have one test dataloader:
|
|
|
|
def test_step(self, batch, batch_idx)
|
|
|
|
|
|
|
|
# if you have multiple test dataloaders:
|
|
|
|
def test_step(self, batch, batch_idx, dataloader_idx)
|
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
Examples::
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# CASE 1: A single test dataset
|
|
|
|
def test_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# implement your own
|
|
|
|
out = self(x)
|
|
|
|
loss = self.loss(out, y)
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# log 6 example images
|
|
|
|
# or generated text... or whatever
|
|
|
|
sample_imgs = x[:6]
|
|
|
|
grid = torchvision.utils.make_grid(sample_imgs)
|
|
|
|
self.logger.experiment.add_image('example_images', grid, 0)
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# calculate acc
|
|
|
|
labels_hat = torch.argmax(out, dim=1)
|
|
|
|
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# log the outputs!
|
|
|
|
self.log_dict({'test_loss': loss, 'test_acc': test_acc})
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2021-01-11 11:02:30 +00:00
|
|
|
If you pass in multiple test dataloaders, :meth:`test_step` will have an additional argument.
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
.. code-block:: python
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2021-01-11 11:02:30 +00:00
|
|
|
# CASE 2: multiple test dataloaders
|
2021-01-26 09:44:54 +00:00
|
|
|
def test_step(self, batch, batch_idx, dataloader_idx):
|
|
|
|
# dataloader_idx tells you which dataset this is.
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2020-04-06 12:12:44 +00:00
|
|
|
Note:
|
2021-01-11 11:02:30 +00:00
|
|
|
If you don't need to test you don't need to implement this method.
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2020-04-06 12:12:44 +00:00
|
|
|
Note:
|
|
|
|
When the :meth:`test_step` is called, the model has been put in eval mode and
|
2020-03-06 15:33:17 +00:00
|
|
|
PyTorch gradients have been disabled. At the end of the test epoch, the model goes back
|
2020-03-05 23:52:17 +00:00
|
|
|
to training mode and gradients are enabled.
|
2020-03-05 17:32:45 +00:00
|
|
|
"""
|
|
|
|
|
2020-09-30 12:31:16 +00:00
|
|
|
def test_step_end(self, *args, **kwargs):
|
2020-03-05 17:32:45 +00:00
|
|
|
"""
|
2020-04-06 12:12:44 +00:00
|
|
|
Use this when testing with dp or ddp2 because :meth:`test_step` will operate
|
2020-03-05 17:32:45 +00:00
|
|
|
on only part of the batch. However, this is still optional
|
|
|
|
and only needed for things like softmax or NCE loss.
|
|
|
|
|
2020-04-06 12:12:44 +00:00
|
|
|
Note:
|
|
|
|
If you later switch to ddp or some other mode, this will still be called
|
|
|
|
so that you don't have to change your code.
|
2020-03-05 17:32:45 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# pseudocode
|
|
|
|
sub_batches = split_batches_for_dp(batch)
|
2020-03-06 15:33:17 +00:00
|
|
|
batch_parts_outputs = [test_step(sub_batch) for sub_batch in sub_batches]
|
2020-03-05 17:32:45 +00:00
|
|
|
test_step_end(batch_parts_outputs)
|
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
Args:
|
2020-04-06 12:12:44 +00:00
|
|
|
batch_parts_outputs: What you return in :meth:`test_step` for each batch part.
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
Return:
|
2020-09-30 12:31:16 +00:00
|
|
|
None or anything
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2020-08-11 23:39:43 +00:00
|
|
|
.. code-block:: python
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2020-08-11 23:39:43 +00:00
|
|
|
# WITHOUT test_step_end
|
|
|
|
# if used in DP or DDP2, this batch is 1/num_gpus large
|
|
|
|
def test_step(self, batch, batch_idx):
|
|
|
|
# batch is 1/num_gpus big
|
|
|
|
x, y = batch
|
|
|
|
|
|
|
|
out = self(x)
|
|
|
|
loss = self.softmax(out)
|
2020-10-08 14:00:04 +00:00
|
|
|
self.log('test_loss', loss)
|
2020-08-11 23:39:43 +00:00
|
|
|
|
|
|
|
# --------------
|
|
|
|
# with test_step_end to do softmax over the full batch
|
|
|
|
def test_step(self, batch, batch_idx):
|
|
|
|
# batch is 1/num_gpus big
|
|
|
|
x, y = batch
|
|
|
|
|
2020-09-30 12:31:16 +00:00
|
|
|
out = self.encoder(x)
|
|
|
|
return out
|
2020-08-11 23:39:43 +00:00
|
|
|
|
2020-12-24 18:37:30 +00:00
|
|
|
def test_step_end(self, output_results):
|
2020-08-11 23:39:43 +00:00
|
|
|
# this out is now the full size of the batch
|
|
|
|
all_test_step_outs = output_results
|
|
|
|
loss = nce_loss(all_test_step_outs)
|
2020-09-30 12:31:16 +00:00
|
|
|
self.log('test_loss', loss)
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2020-04-06 12:12:44 +00:00
|
|
|
See Also:
|
2021-01-26 20:07:07 +00:00
|
|
|
See the :ref:`advanced/multi_gpu:Multi-GPU training` guide for more details.
|
2019-03-31 01:45:16 +00:00
|
|
|
"""
|
|
|
|
|
2021-02-08 19:29:43 +00:00
|
|
|
def test_epoch_end(self, outputs: List[Any]) -> None:
|
2020-03-05 17:32:45 +00:00
|
|
|
"""
|
2020-04-06 12:12:44 +00:00
|
|
|
Called at the end of a test epoch with the output of all test steps.
|
2020-03-05 17:32:45 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
# the pseudocode for these calls
|
|
|
|
test_outs = []
|
|
|
|
for test_batch in test_data:
|
|
|
|
out = test_step(test_batch)
|
|
|
|
test_outs.append(out)
|
|
|
|
test_epoch_end(test_outs)
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-03-05 17:32:45 +00:00
|
|
|
Args:
|
2020-04-06 12:12:44 +00:00
|
|
|
outputs: List of outputs you defined in :meth:`test_step_end`, or if there
|
|
|
|
are multiple dataloaders, a list containing a list of outputs for each dataloader
|
2020-03-05 17:32:45 +00:00
|
|
|
|
|
|
|
Return:
|
2020-09-30 12:31:16 +00:00
|
|
|
None
|
2020-04-06 12:12:44 +00:00
|
|
|
|
|
|
|
Note:
|
|
|
|
If you didn't define a :meth:`test_step`, this won't be called.
|
2020-03-05 17:32:45 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
Examples:
|
2020-04-06 12:12:44 +00:00
|
|
|
With a single dataloader:
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
.. code-block:: python
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
def test_epoch_end(self, outputs):
|
2020-08-11 23:39:43 +00:00
|
|
|
# do something with the outputs of all test batches
|
|
|
|
all_test_preds = [out.predictions for out in outputs]
|
|
|
|
|
2020-09-30 12:31:16 +00:00
|
|
|
some_result = calc_all_results(all_test_preds)
|
|
|
|
self.log('some_result', some_result)
|
2020-03-05 23:52:17 +00:00
|
|
|
|
|
|
|
With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
|
|
|
|
one entry per dataloader, while the inner list contains the individual outputs of
|
|
|
|
each test step for that dataloader.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
def test_epoch_end(self, outputs):
|
2020-09-30 12:31:16 +00:00
|
|
|
final_value = 0
|
|
|
|
for dataloader_outputs in outputs:
|
|
|
|
for test_step_out in dataloader_outputs:
|
|
|
|
# do something
|
|
|
|
final_value += test_step_out
|
2020-08-11 23:39:43 +00:00
|
|
|
|
2020-09-30 12:31:16 +00:00
|
|
|
self.log('final_metric', final_value)
|
2019-08-30 22:56:09 +00:00
|
|
|
"""
|
|
|
|
|
2021-01-27 16:38:14 +00:00
|
|
|
def predict(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None):
|
|
|
|
"""
|
|
|
|
Use this function with trainer.predict(...). Override if you need to add any processing logic.
|
|
|
|
"""
|
|
|
|
return self(batch)
|
|
|
|
|
2021-02-13 00:27:44 +00:00
|
|
|
def configure_callbacks(self):
|
|
|
|
"""
|
|
|
|
Configure model-specific callbacks.
|
|
|
|
When the model gets attached, e.g., when ``.fit()`` or ``.test()`` gets called,
|
|
|
|
the list returned here will be merged with the list of callbacks passed to the Trainer's ``callbacks`` argument.
|
|
|
|
If a callback returned here has the same type as one or several callbacks already present in
|
|
|
|
the Trainer's callbacks list, it will take priority and replace them.
|
|
|
|
In addition, Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
|
|
|
|
callbacks run last.
|
|
|
|
|
|
|
|
Return:
|
|
|
|
A list of callbacks which will extend the list of callbacks in the Trainer.
|
|
|
|
|
|
|
|
Example::
|
|
|
|
|
|
|
|
def configure_callbacks(self):
|
|
|
|
early_stop = EarlyStopping(monitor="val_acc", mode="max")
|
|
|
|
checkpoint = ModelCheckpoint(monitor="val_loss")
|
|
|
|
return [early_stop, checkpoint]
|
|
|
|
|
|
|
|
Note:
|
|
|
|
Certain callback methods like :meth:`~pytorch_lightning.callbacks.base.Callback.on_init_start`
|
|
|
|
will never be invoked on the new callbacks returned here.
|
|
|
|
"""
|
|
|
|
return []
|
|
|
|
|
2021-02-08 19:29:43 +00:00
|
|
|
def configure_optimizers(self):
|
2020-01-17 11:03:31 +00:00
|
|
|
r"""
|
2020-03-05 23:52:17 +00:00
|
|
|
Choose what optimizers and learning-rate schedulers to use in your optimization.
|
|
|
|
Normally you'd need one. But in the case of GANs or similar you might have multiple.
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-04-06 12:12:44 +00:00
|
|
|
Return:
|
|
|
|
Any of these 6 options.
|
|
|
|
|
2020-03-31 16:41:24 +00:00
|
|
|
- Single optimizer.
|
|
|
|
- List or Tuple - List of optimizers.
|
2020-06-04 15:23:44 +00:00
|
|
|
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict).
|
2020-11-12 19:22:06 +00:00
|
|
|
- Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler'
|
2020-12-14 07:38:10 +00:00
|
|
|
key whose value is a single LR scheduler or lr_dict.
|
2020-04-06 12:12:44 +00:00
|
|
|
- Tuple of dictionaries as described, with an optional 'frequency' key.
|
2020-04-02 15:48:53 +00:00
|
|
|
- None - Fit will run without any optimizer.
|
2020-03-31 16:41:24 +00:00
|
|
|
|
|
|
|
Note:
|
2020-04-06 12:12:44 +00:00
|
|
|
The 'frequency' value is an int corresponding to the number of sequential batches
|
2020-03-31 16:41:24 +00:00
|
|
|
optimized with the specific optimizer. It should be given to none or to all of the optimizers.
|
2020-04-06 12:12:44 +00:00
|
|
|
There is a difference between passing multiple optimizers in a list,
|
2020-03-31 16:41:24 +00:00
|
|
|
and passing multiple optimizers in dictionaries with a frequency of 1:
|
|
|
|
In the former case, all optimizers will operate on the given batch in each optimization step.
|
|
|
|
In the latter, only one optimizer will operate on the given batch at every step.
|
2020-03-05 23:52:17 +00:00
|
|
|
|
2020-12-14 07:38:10 +00:00
|
|
|
The lr_dict is a dictionary which contains the scheduler and its associated configuration.
|
|
|
|
The default configuration is shown below.
|
2020-06-04 15:23:44 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
{
|
2020-12-14 07:38:10 +00:00
|
|
|
'scheduler': lr_scheduler, # The LR scheduler instance (required)
|
2020-06-04 15:23:44 +00:00
|
|
|
'interval': 'epoch', # The unit of the scheduler's step size
|
|
|
|
'frequency': 1, # The frequency of the scheduler
|
|
|
|
'reduce_on_plateau': False, # For ReduceLROnPlateau scheduler
|
2020-10-21 12:14:37 +00:00
|
|
|
'monitor': 'val_loss', # Metric for ReduceLROnPlateau to monitor
|
2020-12-14 07:38:10 +00:00
|
|
|
'strict': True, # Whether to crash the training if `monitor` is not found
|
|
|
|
'name': None, # Custom name for LearningRateMonitor to use
|
2020-06-04 15:23:44 +00:00
|
|
|
}
|
|
|
|
|
2020-12-14 07:38:10 +00:00
|
|
|
Only the ``scheduler`` key is required, the rest will be set to the defaults above.
|
2020-06-04 15:23:44 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
Examples::
|
|
|
|
|
|
|
|
# most cases
|
|
|
|
def configure_optimizers(self):
|
|
|
|
opt = Adam(self.parameters(), lr=1e-3)
|
|
|
|
return opt
|
|
|
|
|
|
|
|
# multiple optimizer case (e.g.: GAN)
|
|
|
|
def configure_optimizers(self):
|
|
|
|
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
|
|
|
|
discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
|
|
|
|
return generator_opt, discriminator_opt
|
|
|
|
|
|
|
|
# example with learning rate schedulers
|
|
|
|
def configure_optimizers(self):
|
|
|
|
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
|
|
|
|
discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
|
|
|
|
discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
|
|
|
|
return [generator_opt, discriminator_opt], [discriminator_sched]
|
|
|
|
|
|
|
|
# example with step-based learning rate schedulers
|
|
|
|
def configure_optimizers(self):
|
|
|
|
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
|
|
|
|
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
|
|
|
|
gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
|
|
|
|
'interval': 'step'} # called after each training step
|
|
|
|
dis_sched = CosineAnnealingLR(dis_opt, T_max=10) # called every epoch
|
|
|
|
return [gen_opt, dis_opt], [gen_sched, dis_sched]
|
|
|
|
|
|
|
|
# example with optimizer frequencies
|
|
|
|
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
|
|
|
|
# https://arxiv.org/abs/1704.00028
|
|
|
|
def configure_optimizers(self):
|
|
|
|
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
|
|
|
|
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
|
|
|
|
n_critic = 5
|
|
|
|
return (
|
|
|
|
{'optimizer': dis_opt, 'frequency': n_critic},
|
|
|
|
{'optimizer': gen_opt, 'frequency': 1}
|
|
|
|
)
|
2020-03-31 16:41:24 +00:00
|
|
|
|
2020-03-20 19:49:01 +00:00
|
|
|
Note:
|
|
|
|
|
|
|
|
Some things to know:
|
|
|
|
|
2020-03-06 11:25:24 +00:00
|
|
|
- Lightning calls ``.backward()`` and ``.step()`` on each optimizer
|
2020-03-20 19:49:01 +00:00
|
|
|
and learning rate scheduler as needed.
|
|
|
|
|
2020-03-06 11:25:24 +00:00
|
|
|
- If you use 16-bit precision (``precision=16``), Lightning will automatically
|
2020-03-20 19:49:01 +00:00
|
|
|
handle the optimizers for you.
|
|
|
|
|
2020-04-06 12:12:44 +00:00
|
|
|
- If you use multiple optimizers, :meth:`training_step` will have an additional
|
2020-03-20 19:49:01 +00:00
|
|
|
``optimizer_idx`` parameter.
|
|
|
|
|
2020-04-06 12:12:44 +00:00
|
|
|
- If you use LBFGS, Lightning handles the closure function automatically for you.
|
2020-03-20 19:49:01 +00:00
|
|
|
|
2020-03-06 11:25:24 +00:00
|
|
|
- If you use multiple optimizers, gradients will be calculated only
|
2020-03-20 19:49:01 +00:00
|
|
|
for the parameters of the current optimizer at each training step.
|
|
|
|
|
2020-03-06 11:25:24 +00:00
|
|
|
- If you need to control how often those optimizers step or override the
|
2020-04-06 12:12:44 +00:00
|
|
|
default ``.step()`` schedule, override the :meth:`optimizer_step` hook.
|
2020-03-20 19:49:01 +00:00
|
|
|
|
2020-04-06 12:12:44 +00:00
|
|
|
- If you only want to call a learning rate scheduler every ``x`` step or epoch,
|
2020-06-04 15:23:44 +00:00
|
|
|
or want to monitor a custom metric, you can specify these in a lr_dict:
|
2020-03-20 19:49:01 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
{
|
|
|
|
'scheduler': lr_scheduler,
|
2020-06-04 15:23:44 +00:00
|
|
|
'interval': 'step', # or 'epoch'
|
2020-03-20 19:49:01 +00:00
|
|
|
'monitor': 'val_f1',
|
2020-06-04 15:23:44 +00:00
|
|
|
'frequency': x,
|
2020-03-20 19:49:01 +00:00
|
|
|
}
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
"""
|
2021-02-08 19:29:43 +00:00
|
|
|
rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")
|
2019-03-31 01:45:16 +00:00
|
|
|
|
[accelerator][FeatBugFix] Improve manual optimization API (#5771)
* fix trainer.model access
* move properties
* fix test_transfer_batch_hook
* fix auto_select_gpus
* fix omegaconf test
* fix test that needs to simulate slurm ddp
* add horovod plugin
* fix test with named arguments
* clean up whitespace
* fix datamodules test
* remove old accelerators
* fix naming
* move old plugins
* move to plugins
* create precision subpackage
* create training_type subpackage
* fix all new import errors
* fix wrong arguments order passed to test
* fix LR finder
* Added sharded training type and amp plugin
* Move clip grad to precision plugin
* Added sharded spawn, select accelerators based on distributed_backend + enable custom fp16 plugin automatically
* Fix import issue, attempting to fix tests
* Fix initial test
* Reflect hook logic from master, should wrap model after move to device
* Optional state consolidation, since master has optimizers not wrapped
* change attribute for instance test
* reset optimizers
optimizers are not used in main process, so state would be wrong.
* legacy
* imports in accel
* legacy2
* trainer imports
* fix import errors after rebase
* move hook to new setup location
* provide unwrapping logic
* fix trainer callback system
* added ddp2 implementation
* fix imports .legacy
* move plugins
* restore legacy
* drop test.py from root
* add tpu accelerator and plugins
* fixes
* fix lightning optimizer merge
* reset bugreportmodel
* unwrapping
* step routing forward
* model access
* unwrap
* opt
* integrate distrib_type
* sync changes
* sync
* fixes
* add forgotten generators
* add missing logic
* update
* import
* missed imports
* import fixes
* isort
* mv f
* changelog
* format
* move helper to parallel plugin
* d
* add world size
* clean up
* duplicate
* activate ddp_sharded and tpu
* set nvidia flags
* remove unused colab var
* use_tpu <-> on_tpu attrs
* make some ddp_cpu and clusterplugin tests pass
* Ref/accelerator connector (#5742)
* final cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* connector cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* trainer cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* accelerator cleanup + missing logic in accelerator connector
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* add missing changes to callbacks
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* reflect accelerator changes to lightning module
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* clean cluster envs
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* cleanup plugins
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* add broadcasting
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* yapf
* remove plugin connector
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* plugins
* manual optimization
* update optimizer routing
* add rank to torchelastic
* fix memory mixed precision
* setstate on trainer for pickling in ddp spawn
* add predict method
* add back commented accelerator code
* adapt test for sync_batch_norm to new plugin
* fix deprecated tests
* fix ddp cpu choice when no num_processes are given
* yapf format
* skip a memory test that cannot pass anymore
* update on comments
* fix pickle error in spawn plugin
* x
* avoid
* x
* fix cyclic import in docs build
* add support for sharded
* update typing
* add sharded and sharded_spawn to distributed types
* make unwrap model default
* refactor LightningShardedDataParallel similar to LightningDistributedDataParallel
* update sharded spawn to reflect changes
* update sharded to reflect changes
* Merge 1.1.5 changes
* fix merge
* fix merge
* yapf isort
* fix merge
* yapf isort
* fix indentation in test
* copy over reinit scheduler implementation from dev1.2
* fix apex tracking calls with dev_debugger
* reduce diff to dev1.2, clean up
* fix trainer config test when gpus>0 and num_processes >0 and ddp_cpu
* sort plugin tests legacy/new
* fix error handling for amp on cpu
* fix merge
fix merge
fix merge
* [Feat] Resolve manual_backward (#5837)
* resolve manual_backward
* resolve flake8
* update
* resolve for ddp_spawn
* resolve flake8
* resolve flake8
* resolve flake8
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* fix tests/accelerator tests on cpu
* [BugFix] Resolve manual optimization (#5852)
* resolve manual_optimization
* update
* update
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* Remove copy trainer parameters to happen earlier within the loop and add safe guard to get ref model (#5856)
* resovle a bug
* Accelerator refactor sharded rpc (#5854)
* rpc branch
* merge
* update handling of rpc
* make devices etc. Optional in RPC
* set devices etc. later if necessary
* remove devices from sequential
* make devices optional in rpc
* fix import
* uncomment everything
* fix cluster selection
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* resolve bug
* fix assert in rpc test
* resolve a test
* fix docs compilation
* accelerator refactor - fix for sharded parity test (#5866)
* fix memory issue with ddp_spawn
* x
x
x
x
x
x
x
x
x
* x
* Remove DDP2 as this does not apply
* Add missing pre optimizer hook to ensure lambda closure is called
* fix apex docstring
* [accelerator][BugFix] Resolve some test for 1 gpu (#5863)
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* update
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* revert init
* update
* resolve flake8
* update
* update
* update
* update
* update
* all_gather
* update
* make plugins work, add misconfig for RPC
* update
* update
* remove breaking test
* resolve some tests
* resolve flake8
* revert to ddp_spawn
Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
* yapf isort
* resolve flake8
* fix apex doctests
* fix apex doctests 2
* resolve docs
* update drone
* clean env
* update
* update
* update
* update
* merge
* Fix RPC related tests, clean out old API, update for new accelerator API [skip ci] (#5881)
* Fix RPC related tests, clean out old API, update for new accelerator API
* Move tests out of legacy folder, update paths and names
* Update test_remove_1-4.py
* Expose properties for tpu cores/gpus/num_gpus
* Add root GPU property
* Move properties to properties.py
* move tests that were previously in drone
* Fix root GPU property (#5908)
* Move root GPU to property, remove horovod set as this is handled in horovod plugin, ensure we mock correctly to set GPU accelerator
* Add missing tests back
* fix best model path transfer when no checkpoint callback available
* Fix setup hook order [wip] (#5858)
* Call trainer setup hook before accelerator setup
* Add test case
* add new test
* typo
* fix callback order in test
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* rename ddp sequential -> rpc sequential for special test
* revert
* fix stupid merge problem
* Use property in connector for sampler (#5913)
* merge the import conflicts
* fix spawning of processes in slurm
* [wip] Fix some bugs for TPU [skip ci] (#5878)
* fixed for single tpu
* fixed spawn
* fixed spawn
* update
* update
* wip
* resolve bugs
* resolve bug
* update on comment
* removed decorator
* resolve comments
* set to 4
* update
* update
* need cleaning
* update
* update
* update
* resolve flake8
* resolve bugs
* exclude broadcast
* resolve bugs
* change test
* update
* update
* skip if meet fails
* properly raise trace
* update
* add catch
* wrap test
* resolve typo
* update
* typo
Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
* resolve some tests
* update
* fix imports
* update
* resolve flake8
* update azure pipeline
* skip a sharded test on cpu that requires a gpu
* resolve tpus
* resolve bug
* resolve flake8
* update
* updat utils
* revert permission change on files
* suggestions from carlos
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* remove unrelated formatting changes
* remove incomplete comment
* Update pytorch_lightning/accelerators/__init__.py
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* remove unrelated formatting change
* add types
* warn 1.7 ddp manual backward only if ddp kwarg unset
* yapf + isort
* pep8 unused imports
* fix cyclic import in docs
* Apply suggestions from code review
* typer in accelerator.py
* typo
* Apply suggestions from code review
* formatting
* update on comments
* update typo
* Update pytorch_lightning/trainer/properties.py
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* update
* update on comments
* resolve some comments
* update on comments
* resolve test
* add toggle_model
* update
* update on comments
* update doc
* typo
* update
* typo
* remove space
* update
* update on comments
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: justusschock <justus.schock@posteo.de>
Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
2021-02-16 21:00:35 +00:00
|
|
|
def manual_backward(self, loss: Tensor, optimizer: Optional[Optimizer] = None, *args, **kwargs) -> None:
|
2020-10-10 16:19:22 +00:00
|
|
|
"""
|
|
|
|
Call this directly from your training_step when doing optimizations manually.
|
|
|
|
Using this ensures that all the proper scaling (e.g. for 16-bit precision) has been applied for you.
|
|
|
|
|
|
|
|
This function forwards all args to the .backward() call as well.
|
|
|
|
|
2020-10-11 17:12:35 +00:00
|
|
|
.. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set
|
|
|
|
|
2020-11-12 19:22:06 +00:00
|
|
|
.. tip:: In manual mode we still automatically accumulate grad over batches if
|
2020-12-01 00:09:46 +00:00
|
|
|
Trainer(accumulate_grad_batches=x) is set and you use `optimizer.step()`
|
2020-11-10 19:44:51 +00:00
|
|
|
|
2020-10-10 16:19:22 +00:00
|
|
|
Example::
|
|
|
|
|
|
|
|
def training_step(...):
|
|
|
|
(opt_a, opt_b) = self.optimizers()
|
|
|
|
loss = ...
|
|
|
|
# automatically applies scaling, etc...
|
2020-10-10 22:44:24 +00:00
|
|
|
self.manual_backward(loss, opt_a)
|
2020-12-01 00:09:46 +00:00
|
|
|
opt_a.step()
|
2020-10-10 16:19:22 +00:00
|
|
|
"""
|
[accelerator][FeatBugFix] Improve manual optimization API (#5771)
* fix trainer.model access
* move properties
* fix test_transfer_batch_hook
* fix auto_select_gpus
* fix omegaconf test
* fix test that needs to simulate slurm ddp
* add horovod plugin
* fix test with named arguments
* clean up whitespace
* fix datamodules test
* remove old accelerators
* fix naming
* move old plugins
* move to plugins
* create precision subpackage
* create training_type subpackage
* fix all new import errors
* fix wrong arguments order passed to test
* fix LR finder
* Added sharded training type and amp plugin
* Move clip grad to precision plugin
* Added sharded spawn, select accelerators based on distributed_backend + enable custom fp16 plugin automatically
* Fix import issue, attempting to fix tests
* Fix initial test
* Reflect hook logic from master, should wrap model after move to device
* Optional state consolidation, since master has optimizers not wrapped
* change attribute for instance test
* reset optimizers
optimizers are not used in main process, so state would be wrong.
* legacy
* imports in accel
* legacy2
* trainer imports
* fix import errors after rebase
* move hook to new setup location
* provide unwrapping logic
* fix trainer callback system
* added ddp2 implementation
* fix imports .legacy
* move plugins
* restore legacy
* drop test.py from root
* add tpu accelerator and plugins
* fixes
* fix lightning optimizer merge
* reset bugreportmodel
* unwrapping
* step routing forward
* model access
* unwrap
* opt
* integrate distrib_type
* sync changes
* sync
* fixes
* add forgotten generators
* add missing logic
* update
* import
* missed imports
* import fixes
* isort
* mv f
* changelog
* format
* move helper to parallel plugin
* d
* add world size
* clean up
* duplicate
* activate ddp_sharded and tpu
* set nvidia flags
* remove unused colab var
* use_tpu <-> on_tpu attrs
* make some ddp_cpu and clusterplugin tests pass
* Ref/accelerator connector (#5742)
* final cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* connector cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* trainer cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* accelerator cleanup + missing logic in accelerator connector
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* add missing changes to callbacks
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* reflect accelerator changes to lightning module
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* clean cluster envs
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* cleanup plugins
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* add broadcasting
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* yapf
* remove plugin connector
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* plugins
* manual optimization
* update optimizer routing
* add rank to torchelastic
* fix memory mixed precision
* setstate on trainer for pickling in ddp spawn
* add predict method
* add back commented accelerator code
* adapt test for sync_batch_norm to new plugin
* fix deprecated tests
* fix ddp cpu choice when no num_processes are given
* yapf format
* skip a memory test that cannot pass anymore
* update on comments
* fix pickle error in spawn plugin
* x
* avoid
* x
* fix cyclic import in docs build
* add support for sharded
* update typing
* add sharded and sharded_spawn to distributed types
* make unwrap model default
* refactor LightningShardedDataParallel similar to LightningDistributedDataParallel
* update sharded spawn to reflect changes
* update sharded to reflect changes
* Merge 1.1.5 changes
* fix merge
* fix merge
* yapf isort
* fix merge
* yapf isort
* fix indentation in test
* copy over reinit scheduler implementation from dev1.2
* fix apex tracking calls with dev_debugger
* reduce diff to dev1.2, clean up
* fix trainer config test when gpus>0 and num_processes >0 and ddp_cpu
* sort plugin tests legacy/new
* fix error handling for amp on cpu
* fix merge
fix merge
fix merge
* [Feat] Resolve manual_backward (#5837)
* resolve manual_backward
* resolve flake8
* update
* resolve for ddp_spawn
* resolve flake8
* resolve flake8
* resolve flake8
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* fix tests/accelerator tests on cpu
* [BugFix] Resolve manual optimization (#5852)
* resolve manual_optimization
* update
* update
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* Remove copy trainer parameters to happen earlier within the loop and add safe guard to get ref model (#5856)
* resovle a bug
* Accelerator refactor sharded rpc (#5854)
* rpc branch
* merge
* update handling of rpc
* make devices etc. Optional in RPC
* set devices etc. later if necessary
* remove devices from sequential
* make devices optional in rpc
* fix import
* uncomment everything
* fix cluster selection
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* resolve bug
* fix assert in rpc test
* resolve a test
* fix docs compilation
* accelerator refactor - fix for sharded parity test (#5866)
* fix memory issue with ddp_spawn
* x
x
x
x
x
x
x
x
x
* x
* Remove DDP2 as this does not apply
* Add missing pre optimizer hook to ensure lambda closure is called
* fix apex docstring
* [accelerator][BugFix] Resolve some test for 1 gpu (#5863)
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* update
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* revert init
* update
* resolve flake8
* update
* update
* update
* update
* update
* all_gather
* update
* make plugins work, add misconfig for RPC
* update
* update
* remove breaking test
* resolve some tests
* resolve flake8
* revert to ddp_spawn
Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
* yapf isort
* resolve flake8
* fix apex doctests
* fix apex doctests 2
* resolve docs
* update drone
* clean env
* update
* update
* update
* update
* merge
* Fix RPC related tests, clean out old API, update for new accelerator API [skip ci] (#5881)
* Fix RPC related tests, clean out old API, update for new accelerator API
* Move tests out of legacy folder, update paths and names
* Update test_remove_1-4.py
* Expose properties for tpu cores/gpus/num_gpus
* Add root GPU property
* Move properties to properties.py
* move tests that were previously in drone
* Fix root GPU property (#5908)
* Move root GPU to property, remove horovod set as this is handled in horovod plugin, ensure we mock correctly to set GPU accelerator
* Add missing tests back
* fix best model path transfer when no checkpoint callback available
* Fix setup hook order [wip] (#5858)
* Call trainer setup hook before accelerator setup
* Add test case
* add new test
* typo
* fix callback order in test
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* rename ddp sequential -> rpc sequential for special test
* revert
* fix stupid merge problem
* Use property in connector for sampler (#5913)
* merge the import conflicts
* fix spawning of processes in slurm
* [wip] Fix some bugs for TPU [skip ci] (#5878)
* fixed for single tpu
* fixed spawn
* fixed spawn
* update
* update
* wip
* resolve bugs
* resolve bug
* update on comment
* removed decorator
* resolve comments
* set to 4
* update
* update
* need cleaning
* update
* update
* update
* resolve flake8
* resolve bugs
* exclude broadcast
* resolve bugs
* change test
* update
* update
* skip if meet fails
* properly raise trace
* update
* add catch
* wrap test
* resolve typo
* update
* typo
Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
* resolve some tests
* update
* fix imports
* update
* resolve flake8
* update azure pipeline
* skip a sharded test on cpu that requires a gpu
* resolve tpus
* resolve bug
* resolve flake8
* update
* updat utils
* revert permission change on files
* suggestions from carlos
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* remove unrelated formatting changes
* remove incomplete comment
* Update pytorch_lightning/accelerators/__init__.py
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* remove unrelated formatting change
* add types
* warn 1.7 ddp manual backward only if ddp kwarg unset
* yapf + isort
* pep8 unused imports
* fix cyclic import in docs
* Apply suggestions from code review
* typer in accelerator.py
* typo
* Apply suggestions from code review
* formatting
* update on comments
* update typo
* Update pytorch_lightning/trainer/properties.py
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* update
* update on comments
* resolve some comments
* update on comments
* resolve test
* add toggle_model
* update
* update on comments
* update doc
* typo
* update
* typo
* remove space
* update
* update on comments
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: justusschock <justus.schock@posteo.de>
Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
2021-02-16 21:00:35 +00:00
|
|
|
if optimizer is not None:
|
|
|
|
rank_zero_warn(
|
|
|
|
"`optimizer` argument to `manual_backward` is deprecated in v1.2 and will be removed in v1.4",
|
|
|
|
DeprecationWarning
|
|
|
|
)
|
|
|
|
|
2020-10-11 17:12:35 +00:00
|
|
|
# make sure we're using manual opt
|
|
|
|
self._verify_is_manual_optimization('manual_backward')
|
|
|
|
|
|
|
|
# backward
|
2020-11-10 19:44:51 +00:00
|
|
|
self._running_manual_backward = True
|
[accelerator][FeatBugFix] Improve manual optimization API (#5771)
* fix trainer.model access
* move properties
* fix test_transfer_batch_hook
* fix auto_select_gpus
* fix omegaconf test
* fix test that needs to simulate slurm ddp
* add horovod plugin
* fix test with named arguments
* clean up whitespace
* fix datamodules test
* remove old accelerators
* fix naming
* move old plugins
* move to plugins
* create precision subpackage
* create training_type subpackage
* fix all new import errors
* fix wrong arguments order passed to test
* fix LR finder
* Added sharded training type and amp plugin
* Move clip grad to precision plugin
* Added sharded spawn, select accelerators based on distributed_backend + enable custom fp16 plugin automatically
* Fix import issue, attempting to fix tests
* Fix initial test
* Reflect hook logic from master, should wrap model after move to device
* Optional state consolidation, since master has optimizers not wrapped
* change attribute for instance test
* reset optimizers
optimizers are not used in main process, so state would be wrong.
* legacy
* imports in accel
* legacy2
* trainer imports
* fix import errors after rebase
* move hook to new setup location
* provide unwrapping logic
* fix trainer callback system
* added ddp2 implementation
* fix imports .legacy
* move plugins
* restore legacy
* drop test.py from root
* add tpu accelerator and plugins
* fixes
* fix lightning optimizer merge
* reset bugreportmodel
* unwrapping
* step routing forward
* model access
* unwrap
* opt
* integrate distrib_type
* sync changes
* sync
* fixes
* add forgotten generators
* add missing logic
* update
* import
* missed imports
* import fixes
* isort
* mv f
* changelog
* format
* move helper to parallel plugin
* d
* add world size
* clean up
* duplicate
* activate ddp_sharded and tpu
* set nvidia flags
* remove unused colab var
* use_tpu <-> on_tpu attrs
* make some ddp_cpu and clusterplugin tests pass
* Ref/accelerator connector (#5742)
* final cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* connector cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* trainer cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* accelerator cleanup + missing logic in accelerator connector
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* add missing changes to callbacks
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* reflect accelerator changes to lightning module
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* clean cluster envs
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* cleanup plugins
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* add broadcasting
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* yapf
* remove plugin connector
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* plugins
* manual optimization
* update optimizer routing
* add rank to torchelastic
* fix memory mixed precision
* setstate on trainer for pickling in ddp spawn
* add predict method
* add back commented accelerator code
* adapt test for sync_batch_norm to new plugin
* fix deprecated tests
* fix ddp cpu choice when no num_processes are given
* yapf format
* skip a memory test that cannot pass anymore
* update on comments
* fix pickle error in spawn plugin
* x
* avoid
* x
* fix cyclic import in docs build
* add support for sharded
* update typing
* add sharded and sharded_spawn to distributed types
* make unwrap model default
* refactor LightningShardedDataParallel similar to LightningDistributedDataParallel
* update sharded spawn to reflect changes
* update sharded to reflect changes
* Merge 1.1.5 changes
* fix merge
* fix merge
* yapf isort
* fix merge
* yapf isort
* fix indentation in test
* copy over reinit scheduler implementation from dev1.2
* fix apex tracking calls with dev_debugger
* reduce diff to dev1.2, clean up
* fix trainer config test when gpus>0 and num_processes >0 and ddp_cpu
* sort plugin tests legacy/new
* fix error handling for amp on cpu
* fix merge
fix merge
fix merge
* [Feat] Resolve manual_backward (#5837)
* resolve manual_backward
* resolve flake8
* update
* resolve for ddp_spawn
* resolve flake8
* resolve flake8
* resolve flake8
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* fix tests/accelerator tests on cpu
* [BugFix] Resolve manual optimization (#5852)
* resolve manual_optimization
* update
* update
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* Remove copy trainer parameters to happen earlier within the loop and add safe guard to get ref model (#5856)
* resovle a bug
* Accelerator refactor sharded rpc (#5854)
* rpc branch
* merge
* update handling of rpc
* make devices etc. Optional in RPC
* set devices etc. later if necessary
* remove devices from sequential
* make devices optional in rpc
* fix import
* uncomment everything
* fix cluster selection
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* resolve bug
* fix assert in rpc test
* resolve a test
* fix docs compilation
* accelerator refactor - fix for sharded parity test (#5866)
* fix memory issue with ddp_spawn
* x
x
x
x
x
x
x
x
x
* x
* Remove DDP2 as this does not apply
* Add missing pre optimizer hook to ensure lambda closure is called
* fix apex docstring
* [accelerator][BugFix] Resolve some test for 1 gpu (#5863)
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* update
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* revert init
* update
* resolve flake8
* update
* update
* update
* update
* update
* all_gather
* update
* make plugins work, add misconfig for RPC
* update
* update
* remove breaking test
* resolve some tests
* resolve flake8
* revert to ddp_spawn
Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
* yapf isort
* resolve flake8
* fix apex doctests
* fix apex doctests 2
* resolve docs
* update drone
* clean env
* update
* update
* update
* update
* merge
* Fix RPC related tests, clean out old API, update for new accelerator API [skip ci] (#5881)
* Fix RPC related tests, clean out old API, update for new accelerator API
* Move tests out of legacy folder, update paths and names
* Update test_remove_1-4.py
* Expose properties for tpu cores/gpus/num_gpus
* Add root GPU property
* Move properties to properties.py
* move tests that were previously in drone
* Fix root GPU property (#5908)
* Move root GPU to property, remove horovod set as this is handled in horovod plugin, ensure we mock correctly to set GPU accelerator
* Add missing tests back
* fix best model path transfer when no checkpoint callback available
* Fix setup hook order [wip] (#5858)
* Call trainer setup hook before accelerator setup
* Add test case
* add new test
* typo
* fix callback order in test
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* rename ddp sequential -> rpc sequential for special test
* revert
* fix stupid merge problem
* Use property in connector for sampler (#5913)
* merge the import conflicts
* fix spawning of processes in slurm
* [wip] Fix some bugs for TPU [skip ci] (#5878)
* fixed for single tpu
* fixed spawn
* fixed spawn
* update
* update
* wip
* resolve bugs
* resolve bug
* update on comment
* removed decorator
* resolve comments
* set to 4
* update
* update
* need cleaning
* update
* update
* update
* resolve flake8
* resolve bugs
* exclude broadcast
* resolve bugs
* change test
* update
* update
* skip if meet fails
* properly raise trace
* update
* add catch
* wrap test
* resolve typo
* update
* typo
Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
* resolve some tests
* update
* fix imports
* update
* resolve flake8
* update azure pipeline
* skip a sharded test on cpu that requires a gpu
* resolve tpus
* resolve bug
* resolve flake8
* update
* updat utils
* revert permission change on files
* suggestions from carlos
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* remove unrelated formatting changes
* remove incomplete comment
* Update pytorch_lightning/accelerators/__init__.py
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* remove unrelated formatting change
* add types
* warn 1.7 ddp manual backward only if ddp kwarg unset
* yapf + isort
* pep8 unused imports
* fix cyclic import in docs
* Apply suggestions from code review
* typer in accelerator.py
* typo
* Apply suggestions from code review
* formatting
* update on comments
* update typo
* Update pytorch_lightning/trainer/properties.py
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* update
* update on comments
* resolve some comments
* update on comments
* resolve test
* add toggle_model
* update
* update on comments
* update doc
* typo
* update
* typo
* remove space
* update
* update on comments
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: justusschock <justus.schock@posteo.de>
Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
2021-02-16 21:00:35 +00:00
|
|
|
self.trainer.train_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
|
2020-11-10 19:44:51 +00:00
|
|
|
self._running_manual_backward = False
|
|
|
|
|
2020-10-21 18:34:29 +00:00
|
|
|
def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
    """
    Override backward with your own implementation if you need to.

    Called to perform backward step. Feel free to override as needed.
    The loss passed in has already been scaled for accumulated gradients if requested.

    The default implementation calls ``loss.backward(*args, **kwargs)`` only when automatic
    optimization is enabled, or while ``manual_backward`` is running (its
    ``_running_manual_backward`` flag is set). Otherwise it does nothing.

    Args:
        loss: Loss is already scaled by accumulated grads
        optimizer: Current optimizer being used
        optimizer_idx: Index of the current optimizer being used

    Example::

        def backward(self, loss, optimizer, optimizer_idx):
            loss.backward()
    """
    train_loop = self.trainer.train_loop
    if not (train_loop.automatic_optimization or self._running_manual_backward):
        # Manual optimization, but not called through ``manual_backward``: skip.
        return
    loss.backward(*args, **kwargs)
|
2020-10-10 22:44:24 +00:00
|
|
|
|
2020-10-10 18:35:25 +00:00
|
|
|
def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
    """
    Makes sure only the gradients of the current optimizer's parameters are calculated
    in the training step to prevent dangling gradients in multiple-optimizer setup.

    .. note:: Only called when using multiple optimizers

    Override for your own behavior

    It works with ``untoggle_optimizer`` to make sure param_requires_grad_state is properly reset.

    The original ``requires_grad`` flag of every parameter is stored in
    ``self._param_requires_grad_state``.

    Args:
        optimizer: Current optimizer used in training_loop
        optimizer_idx: Current optimizer idx in training_loop
    """
    # Record each parameter's original ``requires_grad`` flag across all optimizers.
    # This preserves flags pre-defined in ``configure_optimizers``. Then freeze the parameter.
    # A parameter shared by several optimizers is recorded only once, on first sight.
    saved_flags = {}
    for opt in self.optimizers(use_pl_optimizer=False):
        for group in opt.param_groups:
            for param in group['params']:
                if param not in saved_flags:
                    saved_flags[param] = param.requires_grad
                    param.requires_grad = False

    # Restore the recorded flag only for the parameters of the optimizer being stepped.
    for group in optimizer.param_groups:
        for param in group['params']:
            param.requires_grad = saved_flags[param]

    # Kept on the module so ``untoggle_optimizer`` can restore the other parameters later.
    self._param_requires_grad_state = saved_flags
|
|
|
|
|
|
|
|
def untoggle_optimizer(self, optimizer_idx: int):
    """
    Restore the ``requires_grad`` flags saved in ``self._param_requires_grad_state`` for the
    parameters of every optimizer except the one at ``optimizer_idx``.

    The saved state is deleted afterwards.

    .. note:: Only called when using multiple optimizers

    Override for your own behavior

    Args:
        optimizer_idx: Current optimizer idx in training_loop
    """
    for idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
        if idx == optimizer_idx:
            # The current optimizer's parameters are left as they are.
            continue
        for group in opt.param_groups:
            for param in group['params']:
                if param in self._param_requires_grad_state:
                    param.requires_grad = self._param_requires_grad_state[param]

    # save memory
    del self._param_requires_grad_state
|
2020-10-10 18:35:25 +00:00
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def optimizer_step(
|
2020-07-24 15:42:15 +00:00
|
|
|
self,
|
2020-11-12 19:22:06 +00:00
|
|
|
epoch: int = None,
|
|
|
|
batch_idx: int = None,
|
|
|
|
optimizer: Optimizer = None,
|
|
|
|
optimizer_idx: int = None,
|
|
|
|
optimizer_closure: Optional[Callable] = None,
|
|
|
|
on_tpu: bool = None,
|
|
|
|
using_native_amp: bool = None,
|
|
|
|
using_lbfgs: bool = None,
|
2020-03-12 16:47:23 +00:00
|
|
|
) -> None:
|
2020-01-17 11:03:31 +00:00
|
|
|
r"""
|
2020-04-06 12:12:44 +00:00
|
|
|
Override this method to adjust the default way the
|
|
|
|
:class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer.
|
|
|
|
By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
|
2020-03-05 23:52:17 +00:00
|
|
|
once per optimizer.
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-11-02 22:13:34 +00:00
|
|
|
Warning:
|
|
|
|
If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
|
|
|
|
to ``optimizer.step()`` function as shown in the examples. This ensures that
|
|
|
|
``train_step_and_backward_closure`` is called within
|
|
|
|
:meth:`~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch`.
|
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
Args:
|
2020-03-12 16:47:23 +00:00
|
|
|
epoch: Current epoch
|
|
|
|
batch_idx: Index of current batch
|
|
|
|
optimizer: A PyTorch optimizer
|
2020-04-06 12:12:44 +00:00
|
|
|
optimizer_idx: If you used multiple optimizers this indexes into that list.
|
2020-10-21 18:34:29 +00:00
|
|
|
optimizer_closure: closure for all optimizers
|
2020-06-25 20:02:16 +00:00
|
|
|
on_tpu: true if TPU backward is required
|
|
|
|
using_native_amp: True if using native amp
|
|
|
|
using_lbfgs: True if the matching optimizer is lbfgs
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
Examples::
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# DEFAULT
|
|
|
|
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
|
|
|
|
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
|
|
|
|
optimizer.step(closure=optimizer_closure)
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# Alternating schedule for optimizer steps (i.e.: GANs)
|
|
|
|
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
|
|
|
|
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
|
|
|
|
# update generator opt every 2 steps
|
|
|
|
if optimizer_idx == 0:
|
|
|
|
if batch_idx % 2 == 0 :
|
|
|
|
optimizer.step(closure=optimizer_closure)
|
|
|
|
optimizer.zero_grad()
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# update discriminator opt every 4 steps
|
|
|
|
if optimizer_idx == 1:
|
|
|
|
if batch_idx % 4 == 0 :
|
|
|
|
optimizer.step(closure=optimizer_closure)
|
|
|
|
optimizer.zero_grad()
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# ...
|
|
|
|
# add as many optimizers as you want
|
2019-11-28 17:48:55 +00:00
|
|
|
|
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
Here's another example showing how to use this for more advanced things such as
|
|
|
|
learning rate warm-up:
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
.. code-block:: python
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# learning rate warm-up
|
|
|
|
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
|
|
|
|
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
|
|
|
|
# warm up lr
|
|
|
|
if self.trainer.global_step < 500:
|
|
|
|
lr_scale = min(1., float(self.trainer.global_step + 1) / 500.)
|
|
|
|
for pg in optimizer.param_groups:
|
|
|
|
pg['lr'] = lr_scale * self.learning_rate
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2021-01-26 09:44:54 +00:00
|
|
|
# update params
|
|
|
|
optimizer.step(closure=optimizer_closure)
|
|
|
|
optimizer.zero_grad()
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2019-08-13 13:32:45 +00:00
|
|
|
"""
|
2020-12-07 12:55:49 +00:00
|
|
|
if not isinstance(optimizer, LightningOptimizer):
|
|
|
|
# wraps into LightingOptimizer only for running step
|
PoC: Accelerator refactor (#5743)
* restoring the result from subprocess
* fix queue.get() order for results
* add missing "block_backward_sync" context manager
* add missing "block_backward_sync" context manager
* fix sync_batchnorm
* fix supported gpu-ids for tuple
* fix clip gradients and inf recursion
* accelerator selection: added cluster_environment plugin
* fix torchelastic test
* fix reduce early stopping decision for DDP
* fix tests: callbacks, conversion to lightning optimizer
* fix lightning optimizer does not pickle
* fix setting benchmark and deterministic option
* fix slurm amp test
* fix prepare_data test and determine node_rank
* fix retrieving last path when testing
* remove obsolete plugin argument
* fix test: test_trainer_config
* fix torchscript tests
* fix trainer.model access
* move properties
* fix test_transfer_batch_hook
* fix auto_select_gpus
* fix omegaconf test
* fix test that needs to simulate slurm ddp
* add horovod plugin
* fix test with named arguments
* clean up whitespace
* fix datamodules test
* remove old accelerators
* fix naming
* move old plugins
* move to plugins
* create precision subpackage
* create training_type subpackage
* fix all new import errors
* fix wrong arguments order passed to test
* fix LR finder
* Added sharded training type and amp plugin
* Move clip grad to precision plugin
* Added sharded spawn, select accelerators based on distributed_backend + enable custom fp16 plugin automatically
* Fix import issue, attempting to fix tests
* Fix initial test
* Reflect hook logic from master, should wrap model after move to device
* Optional state consolidation, since master has optimizers not wrapped
* change attribute for instance test
* reset optimizers
optimizers are not used in main process, so state would be wrong.
* legacy
* imports in accel
* legacy2
* trainer imports
* fix import errors after rebase
* move hook to new setup location
* provide unwrapping logic
* fix trainer callback system
* added ddp2 implementation
* fix imports .legacy
* move plugins
* restore legacy
* drop test.py from root
* add tpu accelerator and plugins
* fixes
* fix lightning optimizer merge
* reset bugreportmodel
* unwrapping
* step routing forward
* model access
* unwrap
* opt
* integrate distrib_type
* sync changes
* sync
* fixes
* add forgotten generators
* add missing logic
* update
* import
* missed imports
* import fixes
* isort
* mv f
* changelog
* format
* move helper to parallel plugin
* d
* add world size
* clean up
* duplicate
* activate ddp_sharded and tpu
* set nvidia flags
* remove unused colab var
* use_tpu <-> on_tpu attrs
* make some ddp_cpu and clusterplugin tests pass
* Ref/accelerator connector (#5742)
* final cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* connector cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* trainer cleanup
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* accelerator cleanup + missing logic in accelerator connector
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* add missing changes to callbacks
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* reflect accelerator changes to lightning module
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* clean cluster envs
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* cleanup plugins
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* add broadcasting
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* yapf
* remove plugin connector
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* plugins
* manual optimization
* update optimizer routing
* add rank to torchelastic
* fix memory mixed precision
* setstate on trainer for pickling in ddp spawn
* add predict method
* add back commented accelerator code
* adapt test for sync_batch_norm to new plugin
* fix deprecated tests
* fix ddp cpu choice when no num_processes are given
* yapf format
* skip a memory test that cannot pass anymore
* fix pickle error in spawn plugin
* x
* avoid
* x
* fix cyclic import in docs build
* add support for sharded
* update typing
* add sharded and sharded_spawn to distributed types
* make unwrap model default
* refactor LightningShardedDataParallel similar to LightningDistributedDataParallel
* update sharded spawn to reflect changes
* update sharded to reflect changes
* Merge 1.1.5 changes
* fix merge
* fix merge
* yapf isort
* fix merge
* yapf isort
* fix indentation in test
* copy over reinit scheduler implementation from dev1.2
* fix apex tracking calls with dev_debugger
* reduce diff to dev1.2, clean up
* fix trainer config test when gpus>0 and num_processes >0 and ddp_cpu
* sort plugin tests legacy/new
* fix error handling for amp on cpu
* fix merge
fix merge
fix merge
* [Feat] Resolve manual_backward (#5837)
* resolve manual_backward
* resolve flake8
* update
* resolve for ddp_spawn
* resolve flake8
* resolve flake8
* resolve flake8
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* fix tests/accelerator tests on cpu
* [BugFix] Resolve manual optimization (#5852)
* resolve manual_optimization
* update
* update
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* Remove copy trainer parameters to happen earlier within the loop and add safe guard to get ref model (#5856)
* resovle a bug
* Accelerator refactor sharded rpc (#5854)
* rpc branch
* merge
* update handling of rpc
* make devices etc. Optional in RPC
* set devices etc. later if necessary
* remove devices from sequential
* make devices optional in rpc
* fix import
* uncomment everything
* fix cluster selection
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
* resolve bug
* fix assert in rpc test
* resolve a test
* fix docs compilation
* accelerator refactor - fix for sharded parity test (#5866)
* fix memory issue with ddp_spawn
* x
x
x
x
x
x
x
x
x
* x
* Remove DDP2 as this does not apply
* Add missing pre optimizer hook to ensure lambda closure is called
* fix apex docstring
* [accelerator][BugFix] Resolve some test for 1 gpu (#5863)
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* update
* update
* revert init
* resolve a bug
* update
* resolve flake8
* update
* update
* update
* revert init
* update
* resolve flake8
* update
* update
* update
* update
* update
* all_gather
* update
* make plugins work, add misconfig for RPC
* update
* update
* remove breaking test
* resolve some tests
* resolve flake8
* revert to ddp_spawn
Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
* yapf isort
* resolve flake8
* fix apex doctests
* fix apex doctests 2
* resolve docs
* update drone
* clean env
* update
* update
* update
* update
* merge
* Fix RPC related tests, clean out old API, update for new accelerator API [skip ci] (#5881)
* Fix RPC related tests, clean out old API, update for new accelerator API
* Move tests out of legacy folder, update paths and names
* Update test_remove_1-4.py
* Expose properties for tpu cores/gpus/num_gpus
* Add root GPU property
* Move properties to properties.py
* move tests that were previously in drone
* Fix root GPU property (#5908)
* Move root GPU to property, remove horovod set as this is handled in horovod plugin, ensure we mock correctly to set GPU accelerator
* Add missing tests back
* fix best model path transfer when no checkpoint callback available
* Fix setup hook order [wip] (#5858)
* Call trainer setup hook before accelerator setup
* Add test case
* add new test
* typo
* fix callback order in test
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* rename ddp sequential -> rpc sequential for special test
* revert
* fix stupid merge problem
* Use property in connector for sampler (#5913)
* merge the import conflicts
* fix spawning of processes in slurm
* [wip] Fix some bugs for TPU [skip ci] (#5878)
* fixed for single tpu
* fixed spawn
* fixed spawn
* update
* update
* wip
* resolve bugs
* resolve bug
* update on comment
* removed decorator
* resolve comments
* set to 4
* update
* update
* need cleaning
* update
* update
* update
* resolve flake8
* resolve bugs
* exclude broadcast
* resolve bugs
* change test
* update
* update
* skip if meet fails
* properly raise trace
* update
* add catch
* wrap test
* resolve typo
* update
* typo
Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
* resolve some tests
* update
* fix imports
* update
* resolve flake8
* update azure pipeline
* skip a sharded test on cpu that requires a gpu
* resolve tpus
* resolve bug
* resolve flake8
* update
* updat utils
* revert permission change on files
* suggestions from carlos
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* remove unrelated formatting changes
* remove incomplete comment
* Update pytorch_lightning/accelerators/__init__.py
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
* remove unrelated formatting change
* add types
* warn 1.7 ddp manual backward only if ddp kwarg unset
* yapf + isort
* pep8 unused imports
* fix cyclic import in docs
* Apply suggestions from code review
* typer in accelerator.py
* typo
* Apply suggestions from code review
* formatting
* update on comments
* update typo
* Update pytorch_lightning/trainer/properties.py
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* update
* suggestion from code review
* suggestion from code review
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: root <root@ip-172-31-88-60.ec2.internal>
Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
2021-02-12 20:48:56 +00:00
|
|
|
optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, optimizer_idx)
|
2020-12-11 19:24:59 +00:00
|
|
|
optimizer.step(closure=optimizer_closure)
|
2020-04-16 16:01:41 +00:00
|
|
|
|
2021-02-08 19:29:43 +00:00
|
|
|
def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
    """
    Reset the gradients held by ``optimizer``.

    The default implementation simply calls ``optimizer.zero_grad()``; subclasses may override it
    to change how (or whether) gradients are cleared for a given step.

    Args:
        epoch: Current epoch (unused by the default implementation)
        batch_idx: Index of current batch (unused by the default implementation)
        optimizer: The optimizer whose gradients are cleared
        optimizer_idx: Index of ``optimizer`` when multiple optimizers are used
            (unused by the default implementation)
    """
    clear_gradients = optimizer.zero_grad
    clear_gradients()
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list:
    r"""
    When using truncated backpropagation through time, each batch must be split along the
    time dimension. Lightning handles this by default, but for custom behavior override
    this function.

    Args:
        batch: Current batch (iterated element by element; each root level element is split on its own)
        split_size: The size of the split

    Return:
        List of batch splits. Each split will be passed to :meth:`training_step` to enable truncated
        back propagation through time. The default implementation splits root level Tensors and
        Sequences at dim=1 (i.e. time dim). It assumes that each time dim is the same length.
        Root level elements that are neither Tensors nor Sequences are passed through unchanged
        into every split.

    Examples::

        def tbptt_split_batch(self, batch, split_size):
            time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, collections.abc.Sequence))]
            splits = []
            for t in range(0, time_dims[0], split_size):
                batch_split = []
                for x in batch:
                    if isinstance(x, torch.Tensor):
                        split_x = x[:, t:t + split_size]
                    elif isinstance(x, collections.abc.Sequence):
                        split_x = [sample[t:t + split_size] for sample in x]
                    else:
                        split_x = x
                    batch_split.append(split_x)
                splits.append(batch_split)
            return splits

    Note:
        Called in the training loop after
        :meth:`~pytorch_lightning.callbacks.base.Callback.on_batch_start`
        if :paramref:`~pytorch_lightning.trainer.Trainer.truncated_bptt_steps` > 0.
        Each returned batch split is passed separately to :meth:`training_step`.

    """
    # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in ``collections.abc``.
    from collections.abc import Sequence

    # the time length is read from the first entry of every splittable element
    time_dims = [len(x[0]) for x in batch if isinstance(x, (torch.Tensor, Sequence))]
    assert len(time_dims) >= 1, "Unable to determine batch time dimension"
    assert all(x == time_dims[0] for x in time_dims), "Batch time dimension length is ambiguous"

    splits = []
    for t in range(0, time_dims[0], split_size):
        batch_split = []
        for x in batch:
            if isinstance(x, torch.Tensor):
                # tensors are sliced along dim=1, the time dimension
                split_x = x[:, t:t + split_size]
            elif isinstance(x, Sequence):
                # sequences hold one entry per sample; slice each entry along its own first axis
                split_x = [sample[t:t + split_size] for sample in x]
            else:
                # nothing to split along time: hand the element to every split unchanged
                # (otherwise the previous element's split would be appended again)
                split_x = x
            batch_split.append(split_x)
        splits.append(batch_split)

    return splits
|
|
|
|
|
2021-01-05 07:43:18 +00:00
|
|
|
def summarize(self, mode: Optional[str] = ModelSummary.MODE_DEFAULT) -> Optional[ModelSummary]:
    """
    Build a :class:`ModelSummary` of this module and log it at info level.

    Args:
        mode: One of ``ModelSummary.MODES``, or ``None`` to skip summarizing.

    Returns:
        The created summary, or ``None`` when ``mode`` is ``None``.

    Raises:
        MisconfigurationException: If ``mode`` is neither ``None`` nor one of ``ModelSummary.MODES``.
    """
    if mode in ModelSummary.MODES:
        summary = ModelSummary(self, mode=mode)
        log.info("\n" + str(summary))
        return summary

    if mode is not None:
        raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")

    return None
|
2019-07-25 16:01:52 +00:00
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def freeze(self) -> None:
    r"""
    Freeze all params for inference.

    Every parameter gets ``requires_grad`` switched off and the module is put into eval mode.

    Example::

        model = MyLightningModule(...)
        model.freeze()

    """
    for parameter in self.parameters():
        parameter.requires_grad_(False)

    self.eval()
|
|
|
|
|
2020-03-12 16:47:23 +00:00
|
|
|
def unfreeze(self) -> None:
    """
    Unfreeze all parameters for training.

    Every parameter gets ``requires_grad`` switched back on and the module is put into train mode.

    .. code-block:: python

        model = MyLightningModule(...)
        model.unfreeze()

    """
    for parameter in self.parameters():
        parameter.requires_grad_(True)

    self.train()
|
2019-11-28 17:48:55 +00:00
|
|
|
|
2020-04-24 00:46:18 +00:00
|
|
|
def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
    r"""
    Implement this to override the default items displayed in the progress bar.
    By default it includes the average loss value, split index of BPTT (if used)
    and the version of the experiment when using a logger.

    .. code-block::

        Epoch 1:   4%|▎         | 40/1095 [00:03<01:37, 10.84it/s, loss=4.5, v_num=10]

    Here is an example how to override the defaults:

    .. code-block:: python

        def get_progress_bar_dict(self):
            # don't show the version number
            items = super().get_progress_bar_dict()
            items.pop("v_num", None)
            return items

    Return:
        Dictionary with the items to be displayed in the progress bar. The loss is formatted with
        3 significant digits (``"nan"`` when no running loss exists under automatic optimization),
        and string logger versions are shortened to their last 4 characters.
    """
    trainer = self.trainer
    items = {}

    # call .item() only once but store elements without graphs
    running_loss = trainer.train_loop.running_loss.mean()
    if running_loss is not None:
        items["loss"] = f"{running_loss.cpu().item():.3g}"
    elif trainer.train_loop.automatic_optimization:
        items["loss"] = f"{float('NaN'):.3g}"

    if trainer.truncated_bptt_steps is not None:
        items["split_idx"] = trainer.split_idx

    logger = trainer.logger
    if logger is not None and logger.version is not None:
        version = logger.version
        # show last 4 places of long version strings
        items["v_num"] = version[-4:] if isinstance(version, str) else version

    return items
|
2020-04-24 00:46:18 +00:00
|
|
|
|
2020-10-11 17:12:35 +00:00
|
|
|
def _verify_is_manual_optimization(self, fn_name):
|
|
|
|
if self.trainer.train_loop.automatic_optimization:
|
2020-12-10 10:01:33 +00:00
|
|
|
raise MisconfigurationException(
|
|
|
|
f'to use {fn_name}, please disable automatic optimization:'
|
|
|
|
' set model property `automatic_optimization` as False'
|
|
|
|
)
|
2020-10-11 17:12:35 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
@classmethod
def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]:
    """
    Collect all module arguments in the current constructor and all child constructors.
    The child constructors are all the ``__init__`` methods that reach the current class through
    (chained) ``super().__init__()`` calls.

    Args:
        frame: instance frame. Defaults to the frame that called this method.

    Returns:
        self_arguments: arguments dictionary of the first instance
        parents_arguments: arguments dictionary of the parent's instances, merged in
            collection order (later entries overwrite earlier ones on key clashes)
    """
    if not frame:
        # ``currentframe()`` is this method's frame; its ``f_back`` (used below) is the caller
        frame = inspect.currentframe()

    frame_args = collect_init_args(frame.f_back, [])
    # the last collected entry is treated as this instance's own arguments
    self_arguments = frame_args[-1]

    # add all arguments from parents
    parents_arguments = {}
    for args in frame_args[:-1]:
        parents_arguments.update(args)

    return self_arguments, parents_arguments
|
2020-05-24 22:59:08 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
def save_hyperparameters(self, *args, frame=None) -> None:
    """Save all model arguments.

    Args:
        args: single object of `dict`, `Namespace` or `OmegaConf`
            or string names or arguments from class `__init__`
        frame: frame of the ``__init__`` whose arguments are inspected.
            Defaults to the frame that called this method.

    >>> class ManuallyArgsModel(LightningModule):
    ...     def __init__(self, arg1, arg2, arg3):
    ...         super().__init__()
    ...         # manually assign arguments
    ...         self.save_hyperparameters('arg1', 'arg3')
    ...     def forward(self, *args, **kwargs):
    ...         ...
    >>> model = ManuallyArgsModel(1, 'abc', 3.14)
    >>> model.hparams
    "arg1": 1
    "arg3": 3.14

    >>> class AutomaticArgsModel(LightningModule):
    ...     def __init__(self, arg1, arg2, arg3):
    ...         super().__init__()
    ...         # equivalent automatic
    ...         self.save_hyperparameters()
    ...     def forward(self, *args, **kwargs):
    ...         ...
    >>> model = AutomaticArgsModel(1, 'abc', 3.14)
    >>> model.hparams
    "arg1": 1
    "arg2": abc
    "arg3": 3.14

    >>> class SingleArgModel(LightningModule):
    ...     def __init__(self, params):
    ...         super().__init__()
    ...         # manually assign single argument
    ...         self.save_hyperparameters(params)
    ...     def forward(self, *args, **kwargs):
    ...         ...
    >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
    >>> model.hparams
    "p1": 1
    "p2": abc
    "p3": 3.14
    """
    if not frame:
        # must stay in this function body: ``f_back`` is the caller's (``__init__``) frame
        frame = inspect.currentframe().f_back
    init_args = get_init_args(frame)
    assert init_args, "failed to inspect the self init"

    if args:
        # take only listed arguments in `save_hparams`
        non_str_positions = [idx for idx, item in enumerate(args) if not isinstance(item, str)]
        if len(non_str_positions) == 1:
            # a single config object: remember which ``__init__`` argument it came from
            hp = args[non_str_positions[0]]
            matching = [name for name, value in init_args.items() if value == hp]
            self._hparams_name = matching[0] if matching else None
        else:
            hp = {name: init_args[name] for name in args if isinstance(name, str)}
            self._hparams_name = "kwargs"
    else:
        # take all arguments
        hp = init_args
        self._hparams_name = "kwargs" if hp else None

    # `hparams` are expected here
    if not hp:
        return

    self._set_hparams(hp)
    # make deep copy so there is not other runtime changes reflected
    self._hparams_initial = copy.deepcopy(self._hparams)
|
2020-06-08 11:19:34 +00:00
|
|
|
|
|
|
|
def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
    """
    Store ``hp`` as this module's hyperparameters.

    A ``Namespace`` is converted to a dict, and any dict is wrapped in an ``AttributeDict``.
    When both the new value and the current ``hparams`` are dicts, the current ones are updated
    in place; otherwise ``hp`` replaces them.

    Raises:
        ValueError: If ``hp`` is one of ``PRIMITIVE_TYPES`` or not one of ``ALLOWED_CONFIG_TYPES``.
    """
    config = vars(hp) if isinstance(hp, Namespace) else hp

    if isinstance(config, dict):
        config = AttributeDict(config)
    elif isinstance(config, PRIMITIVE_TYPES):
        raise ValueError(f"Primitives {PRIMITIVE_TYPES} are not allowed.")
    elif not isinstance(config, ALLOWED_CONFIG_TYPES):
        raise ValueError(f"Unsupported config type of {type(config)}.")

    # ``self.hparams`` is only touched when ``config`` is a dict (short-circuit)
    merge_into_existing = isinstance(config, dict) and isinstance(self.hparams, dict)
    if merge_into_existing:
        self.hparams.update(config)
    else:
        self._hparams = config
|
2020-06-04 12:35:50 +00:00
|
|
|
|
2020-12-12 10:17:03 +00:00
|
|
|
@torch.no_grad()
def to_onnx(
    self,
    file_path: Union[str, Path],
    input_sample: Optional[Any] = None,
    **kwargs,
):
    """
    Saves the model in ONNX format

    Args:
        file_path: The path of the file the onnx model should be saved to.
        input_sample: An input for tracing. Default: None (Use self.example_input_array)
        **kwargs: Will be passed to torch.onnx.export function.

    Raises:
        ValueError: If neither ``input_sample`` nor ``self.example_input_array`` is set.

    Note:
        The module's train/eval mode is restored afterwards, also when the forward pass
        or the export raises.

    Example:
        >>> class SimpleModel(LightningModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
        ...
        ...     def forward(self, x):
        ...         return torch.relu(self.l1(x.view(x.size(0), -1)))

        >>> with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
        ...     model = SimpleModel()
        ...     input_sample = torch.randn((1, 64))
        ...     model.to_onnx(tmpfile.name, input_sample, export_params=True)
        ...     os.path.isfile(tmpfile.name)
        True
    """
    mode = self.training

    if input_sample is None:
        if self.example_input_array is None:
            raise ValueError(
                "Could not export to ONNX since neither `input_sample` nor"
                " `model.example_input_array` attribute is set."
            )
        input_sample = self.example_input_array

    input_sample = self._apply_batch_transfer_handler(input_sample)

    try:
        if "example_outputs" not in kwargs:
            # reference outputs are computed with the module in eval mode
            self.eval()
            kwargs["example_outputs"] = self(input_sample)

        torch.onnx.export(self, input_sample, file_path, **kwargs)
    finally:
        # restore the caller's train/eval mode even if the forward pass or the export fails
        self.train(mode)
|
2020-07-31 10:27:57 +00:00
|
|
|
|
2020-12-12 10:17:03 +00:00
|
|
|
@torch.no_grad()
def to_torchscript(
    self,
    file_path: Optional[Union[str, Path]] = None,
    method: Optional[str] = 'script',
    example_inputs: Optional[Any] = None,
    **kwargs,
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
    """
    By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
    If you want to use tracing, please provide the argument `method='trace'` and make sure that either the
    example_inputs argument is provided, or the model has self.example_input_array set.
    If you would like to customize the modules that are scripted you should override this method.
    In case you want to return multiple modules, we recommend using a dictionary.

    Args:
        file_path: Path where to save the torchscript. Default: None (no file saved).
        method: Whether to use TorchScript's script or trace method. Default: 'script'
        example_inputs: An input to be used to do tracing when method is set to 'trace'.
            Default: None (Use self.example_input_array)
        **kwargs: Additional arguments that will be passed to the :func:`torch.jit.script` or
            :func:`torch.jit.trace` function.

    Raises:
        ValueError: If ``method`` is neither ``'script'`` nor ``'trace'``, or if tracing is requested
            without ``example_inputs`` and without ``self.example_input_array``.

    Note:
        - Requires the implementation of the
          :meth:`~pytorch_lightning.core.lightning.LightningModule.forward` method.
        - The exported script will be set to evaluation mode.
        - The module's own train/eval mode is restored afterwards, also when scripting or tracing fails.
        - It is recommended that you install the latest supported version of PyTorch
          to use this feature without limitations. See also the :mod:`torch.jit`
          documentation for supported features.

    Example:
        >>> class SimpleModel(LightningModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
        ...
        ...     def forward(self, x):
        ...         return torch.relu(self.l1(x.view(x.size(0), -1)))
        ...
        >>> model = SimpleModel()
        >>> torch.jit.save(model.to_torchscript(), "model.pt")  # doctest: +SKIP
        >>> os.path.isfile("model.pt")  # doctest: +SKIP
        >>> torch.jit.save(model.to_torchscript(file_path="model_trace.pt", method='trace', # doctest: +SKIP
        ...                                     example_inputs=torch.randn(1, 64)))  # doctest: +SKIP
        >>> os.path.isfile("model_trace.pt")  # doctest: +SKIP
        True

    Return:
        This LightningModule as a torchscript, regardless of whether file_path is
        defined or not.
    """
    if method == 'trace':
        # if no example inputs are provided, try to see if model has example_input_array set
        if example_inputs is None:
            if self.example_input_array is None:
                raise ValueError(
                    'Choosing method=`trace` requires either `example_inputs`'
                    ' or `model.example_input_array` to be defined.'
                )
            example_inputs = self.example_input_array

        # automatically send example inputs to the right device and use trace
        example_inputs = self._apply_batch_transfer_handler(example_inputs)
    elif method != 'script':
        raise ValueError(f"The 'method' parameter only supports 'script' or 'trace', but value given was: {method}")

    mode = self.training
    try:
        if method == 'script':
            torchscript_module = torch.jit.script(self.eval(), **kwargs)
        else:
            torchscript_module = torch.jit.trace(func=self.eval(), example_inputs=example_inputs, **kwargs)
    finally:
        # restore the caller's train/eval mode even if scripting or tracing fails
        self.train(mode)

    if file_path is not None:
        torch.jit.save(torchscript_module, file_path)

    return torchscript_module
|
2020-09-03 18:24:44 +00:00
|
|
|
|
2020-06-08 11:19:34 +00:00
|
|
|
@property
def hparams(self) -> Union[AttributeDict, dict, Namespace]:
    """
    The hyperparameters of this module.

    An empty ``AttributeDict`` is created and stored on first access when none were saved yet.
    """
    try:
        return self._hparams
    except AttributeError:
        self._hparams = AttributeDict()
        return self._hparams
|
|
|
|
|
2020-10-15 12:30:49 +00:00
|
|
|
@property
def hparams_initial(self) -> AttributeDict:
    """
    The hyperparameters as they were when saved by ``save_hyperparameters``.

    Returns a deep copy so callers cannot alter the stored snapshot, or an empty
    ``AttributeDict`` when no snapshot exists.
    """
    if hasattr(self, "_hparams_initial"):
        # prevent any change
        return copy.deepcopy(self._hparams_initial)
    return AttributeDict()
|
|
|
|
|
2021-02-11 12:04:57 +00:00
|
|
|
@property
def model_size(self) -> float:
    """
    Size of the serialized ``state_dict`` in megabytes (1 MB = 1e6 bytes).

    The state dict is serialized with :func:`torch.save` into an in-memory buffer, so nothing is
    written to (or can be left behind in) the working directory, and read-only working
    directories are not a problem.
    """
    import io

    buffer = io.BytesIO()
    torch.save(self.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6
|