# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import warnings
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.distributed as torch_distrib
from torch.utils.data import DataLoader

from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.step_result import EvalResult
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import (TrainerDPMixin, _parse_gpu_ids, _parse_tpu_cores,
                                                     determine_root_gpu_device, pick_multiple_gpus)
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.training_loop_temp import TrainLoop
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.model_utils import is_overridden

# warnings to ignore in trainer
warnings.filterwarnings(
    'ignore', message='torch.distributed.reduce_op is deprecated, ' 'please use torch.distributed.ReduceOp instead'
)

try:
    from apex import amp
except ImportError:
    amp = None

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True

try:
    import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True


class Trainer(
    TrainerIOMixin,
    TrainerCallbackHookMixin,
    TrainerModelHooksMixin,
    TrainerOptimizersMixin,
    TrainerAMPMixin,
    TrainerDPMixin,
    TrainerDDPMixin,
    TrainerLoggingMixin,
    TrainerTrainingTricksMixin,
    TrainerDataLoadingMixin,
    TrainerEvaluationLoopMixin,
    TrainerTrainLoopMixin,
    TrainerCallbackConfigMixin,
    TrainerLRFinderMixin,
    TrainerDeprecatedAPITillVer0_10,
):
    """
    Example:

        >>> import torch
        >>> from torch.nn import functional as F
        >>> from torch.utils.data import Dataset, DataLoader

        >>> # Define model
        >>> class SimpleModel(LightningModule):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.l1 = torch.nn.Linear(in_features=64, out_features=4)
        ...
        ...     def forward(self, x):
        ...         return torch.relu(self.l1(x.view(x.size(0), -1)))
        ...
        ...     def training_step(self, batch, batch_nb):
        ...         x, y = batch
        ...         loss = F.cross_entropy(self(x), y)
        ...         return {'loss': loss, 'log': {'train_loss': loss}}
        ...
        ...     def test_step(self, batch, batch_nb):
        ...         x, y = batch
        ...         loss = F.cross_entropy(self(x), y)
        ...         return {'loss': loss, 'log': {'test_loss': loss}}
        ...
        ...     def configure_optimizers(self):
        ...         return torch.optim.Adam(self.parameters(), lr=0.02)
        ...
        >>> # Define dataset
        >>> class SimpleDataset(Dataset):
        ...     def __init__(self, num_samples=200):
        ...         self.input_seq = torch.randn(num_samples, 64)
        ...         self.output_seq = torch.randint(0, 4, (num_samples,))
        ...
        ...     def __len__(self):
        ...         return len(self.input_seq)
        ...
        ...     def __getitem__(self, item):
        ...         return self.input_seq[item], self.output_seq[item]
        ...
        >>> train_loader = DataLoader(SimpleDataset(), batch_size=8)
        >>> model = SimpleModel()
        >>> # Define Trainer and fit model
        >>> trainer = Trainer(max_epochs=1, progress_bar_refresh_rate=0)
        >>> trainer.fit(model, train_loader)
        1
        >>> test_outputs = trainer.test(model, train_loader, verbose=False)
        >>> len(test_outputs)
        25
    """

    def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
        early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
        callbacks: Optional[List[Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        limit_train_batches: Union[int, float] = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        val_check_interval: Union[int, float] = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 50,
        distributed_backend: Optional[str] = None,
        sync_batchnorm: bool = False,
        precision: int = 32,
        weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        amp_backend: str = 'native',
        amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
        val_percent_check: float = None,  # backward compatible, todo: remove in v0.10.0
        test_percent_check: float = None,  # backward compatible, todo: remove in v0.10.0
        train_percent_check: float = None,  # backward compatible, todo: remove in v0.10.0
        overfit_pct: float = None,  # backward compatible, todo: remove in v1.0.0
    ):
r"""
|
|
|
|
|
|
|
|
Customize every aspect of training via flags
|
|
|
|
|
|
|
|
Args:
|
2020-02-25 19:52:39 +00:00
|
|
|
logger: Logger (or iterable collection of loggers) for experiment tracking.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
checkpoint_callback: Callback for checkpointing.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-27 21:07:51 +00:00
|
|
|
early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`):
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-26 04:17:27 +00:00
|
|
|
callbacks: Add a list of callbacks.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-07-27 16:53:11 +00:00
|
|
|
default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
|
|
|
|
Default: ``os.getcwd()``.
|
2020-08-09 22:38:43 +00:00
|
|
|
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
|
2020-04-10 16:02:59 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
gradient_clip_val: 0 means don't clip.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-04-24 00:46:18 +00:00
|
|
|
process_position: orders the progress bar when running multiple models on same machine.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
num_nodes: number of GPU nodes for distributed training.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-09-02 15:57:56 +00:00
|
|
|
gpus: number of gpus to train on (int) or which GPUs to train on (list or str) applied per node
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-04-10 15:45:29 +00:00
|
|
|
auto_select_gpus:
|
|
|
|
|
|
|
|
If enabled and `gpus` is an integer, pick available
|
|
|
|
gpus automatically. This is especially useful when
|
|
|
|
GPUs are configured to be in "exclusive mode", such
|
|
|
|
that only one process at a time can access them.
|
|
|
|
|
2020-05-17 20:30:54 +00:00
|
|
|
tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
log_gpu_memory: None, 'min_max', 'all'. Might slow performance
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-04-02 22:53:00 +00:00
|
|
|
progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
|
2020-04-24 00:46:18 +00:00
|
|
|
Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-06-23 15:19:38 +00:00
|
|
|
overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0
|
2020-06-17 12:03:28 +00:00
|
|
|
|
|
|
|
overfit_pct:
|
|
|
|
.. warning:: .. deprecated:: 0.8.0
|
|
|
|
|
2020-06-23 15:19:38 +00:00
|
|
|
Use `overfit_batches` instead. Will be removed in 0.10.0.
|
2020-03-24 18:49:11 +00:00
|
|
|
|
2020-06-02 22:51:09 +00:00
|
|
|
track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
check_val_every_n_epoch: Check val every n train epochs.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-07-27 21:56:55 +00:00
|
|
|
fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
max_epochs: Stop training once this number of epochs is reached.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
min_epochs: Force training for at least these many epochs
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
max_steps: Stop training after this number of steps. Disabled by default (None).
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
min_steps: Force training for at least these number of steps. Disabled by default (None).
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-06-23 15:21:24 +00:00
|
|
|
limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-06-17 12:03:28 +00:00
|
|
|
limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)
|
|
|
|
|
|
|
|
limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)
|
|
|
|
|
2020-06-17 17:42:28 +00:00
|
|
|
train_percent_check:
|
|
|
|
.. warning:: .. deprecated:: 0.8.0
|
|
|
|
|
|
|
|
Use `limit_train_batches` instead. Will remove v0.10.0.
|
|
|
|
|
2020-06-17 12:03:28 +00:00
|
|
|
val_percent_check:
|
|
|
|
.. warning:: .. deprecated:: 0.8.0
|
|
|
|
|
2020-06-17 17:42:28 +00:00
|
|
|
Use `limit_val_batches` instead. Will remove v0.10.0.
|
2020-06-17 12:03:28 +00:00
|
|
|
|
|
|
|
test_percent_check:
|
|
|
|
.. warning:: .. deprecated:: 0.8.0
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-06-17 17:42:28 +00:00
|
|
|
Use `limit_test_batches` instead. Will remove v0.10.0.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-08-03 13:13:05 +00:00
|
|
|
val_check_interval: How often to check the validation set. Use float to check within a training epoch,
|
|
|
|
use int to check every n steps (batches).
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
log_save_interval: Writes logs to disk this often
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
row_log_interval: How often to add logging rows (does not write to disk)
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-06-27 01:38:25 +00:00
|
|
|
distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-08-05 23:12:11 +00:00
|
|
|
sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
|
2020-08-05 17:29:05 +00:00
|
|
|
|
2020-08-03 19:57:21 +00:00
|
|
|
precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
weights_summary: Prints a summary of the weights when training begins.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-04-10 16:02:59 +00:00
|
|
|
weights_save_path: Where to save weights if specified. Will override default_root_dir
|
|
|
|
for checkpoints only. Use this if for whatever reason you need the checkpoints
|
|
|
|
stored in a different place than the logs written in `default_root_dir`.
|
2020-08-09 22:38:43 +00:00
|
|
|
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
|
2020-07-27 16:53:11 +00:00
|
|
|
Defaults to `default_root_dir`.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-08-13 21:25:56 +00:00
|
|
|
amp_backend: The mixed precision backend to use ("native" or "apex")
|
2020-08-09 22:38:43 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
amp_level: The optimization level to use (O1, O2, etc...).
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-07-23 11:07:03 +00:00
|
|
|
num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
|
|
|
|
Set it to `-1` to run all batches in all validation dataloaders. Default: 2
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-08-03 19:57:21 +00:00
|
|
|
truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of much longer
|
|
|
|
sequence.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-04-03 05:35:09 +00:00
|
|
|
resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
|
2020-06-11 21:12:48 +00:00
|
|
|
This can be a URL.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-08-03 19:57:21 +00:00
|
|
|
profiler: To profile individual steps during training and assist in identifying bottlenecks.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-08-03 19:57:21 +00:00
|
|
|
reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.
|
2020-03-05 18:11:06 +00:00
|
|
|
|
2020-04-10 18:34:23 +00:00
|
|
|
auto_lr_find: If set to True, will `initially` run a learning rate finder,
|
|
|
|
trying to optimize initial learning for faster convergence. Sets learning
|
2020-05-24 22:59:08 +00:00
|
|
|
rate in self.lr or self.learning_rate in the LightningModule.
|
2020-04-10 18:34:23 +00:00
|
|
|
To use a different key, set a string instead of True with the key name.
|
|
|
|
|
2020-08-03 19:57:21 +00:00
|
|
|
replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
|
|
|
|
will toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
|
|
|
|
train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
|
2020-08-08 05:52:35 +00:00
|
|
|
you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
|
2020-04-19 20:58:57 +00:00
|
|
|
|
2020-03-05 23:52:17 +00:00
|
|
|
benchmark: If true enables cudnn.benchmark.
|
2020-04-13 18:06:25 +00:00
|
|
|
|
2020-08-03 19:57:21 +00:00
|
|
|
deterministic: If true enables cudnn.deterministic.
|
2020-05-12 11:53:20 +00:00
|
|
|
|
2020-04-13 18:06:25 +00:00
|
|
|
terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
|
|
|
|
end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
|
2020-05-09 12:28:36 +00:00
|
|
|
|
|
|
|
auto_scale_batch_size: If set to True, will `initially` run a batch size
|
|
|
|
finder trying to find the largest batch size that fits into memory.
|
2020-05-24 22:59:08 +00:00
|
|
|
The result will be stored in self.batch_size in the LightningModule.
|
2020-05-09 12:28:36 +00:00
|
|
|
Additionally, can be set to either `power` that estimates the batch size through
|
|
|
|
a power search or `binsearch` that estimates the batch size through a binary search.
|
2020-06-13 16:00:14 +00:00
|
|
|
|
|
|
|
prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
|
|
|
|
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
|
2019-07-18 16:04:19 +00:00
|
|
|
"""
|
2020-05-17 13:14:54 +00:00
|
|
|
        super().__init__()

        self.deterministic = deterministic
        torch.backends.cudnn.deterministic = self.deterministic
        if self.deterministic:
            # fixing non-deterministic part of horovod
            # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
            os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

        # init the default rank if exists
        # we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
        # this way we only show it on rank 0
        if 'LOCAL_RANK' in os.environ:
            rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

        # training bookkeeping
        self.total_batch_idx = 0
        self.running_loss = TensorRunningAccum(window_length=20)
        self.batch_idx = 0
        self.progress_bar_metrics = {}
        self.callback_metrics = {}
        self.logged_metrics = {}
        self.num_training_batches = 0
        self.num_val_batches = []
        self.num_sanity_val_batches = []
        self.num_test_batches = []
        self.train_dataloader = None
        self.test_dataloaders = None
        self.val_dataloaders = None

        # when true, prints test results
        self.verbose_test = True

        # when .test() is called, it sets this
        self.tested_ckpt_path = None

        # training state
        self.model = None
        self.datamodule = None
        self.testing = False
        self.prepare_data_per_node = prepare_data_per_node
        self.lr_schedulers = []
        self.optimizers = None
        self.optimizer_frequencies = []
        self.global_step = 0
        self.current_epoch = 0
        self.interrupted = False
        self.should_stop = False
        self.running_sanity_check = False
        self._state = TrainerState.INITIALIZING

        self._default_root_dir = default_root_dir or os.getcwd()
        self._weights_save_path = weights_save_path or self._default_root_dir

        # init callbacks
        self.callbacks = callbacks or []

        # configure early stop callback
        # creates a default one if none passed in
        early_stop_callback = self.configure_early_stopping(early_stop_callback)
        if early_stop_callback:
            self.callbacks.append(early_stop_callback)

        # configure checkpoint callback
        # it is important that this is the last callback to run
        # pass through the required args to figure out defaults
        checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
        if checkpoint_callback:
            self.callbacks.append(checkpoint_callback)

        # TODO refactor codebase (tests) to not directly reach into these callbacks
        self.checkpoint_callback = checkpoint_callback
        self.early_stop_callback = early_stop_callback

        self.on_init_start()

        # benchmarking
        self.benchmark = benchmark
        torch.backends.cudnn.benchmark = self.benchmark

        # Transfer params
        self.num_nodes = num_nodes
        self.log_gpu_memory = log_gpu_memory

        # sync-bn backend
        self.sync_batchnorm = sync_batchnorm

        self.gradient_clip_val = gradient_clip_val
        self.check_val_every_n_epoch = check_val_every_n_epoch

        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
            raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
        self.track_grad_norm = float(track_grad_norm)

        self.tpu_cores = _parse_tpu_cores(tpu_cores)
        self.on_tpu = self.tpu_cores is not None

        self.tpu_id = self.tpu_cores[0] if isinstance(self.tpu_cores, list) else None

        if num_processes != 1 and distributed_backend != "ddp_cpu":
            rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
        self.num_processes = num_processes

        self.weights_summary = weights_summary

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.max_steps = max_steps
        self.min_steps = min_steps

        if num_sanity_val_steps == -1:
            self.num_sanity_val_steps = float('inf')
        else:
            self.num_sanity_val_steps = num_sanity_val_steps

        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

        self.auto_lr_find = auto_lr_find
        self.auto_scale_batch_size = auto_scale_batch_size
        self._is_data_prepared = False
        self.replace_sampler_ddp = replace_sampler_ddp

        self.truncated_bptt_steps = truncated_bptt_steps
        self.resume_from_checkpoint = resume_from_checkpoint
        self.terminate_on_nan = terminate_on_nan
        self.shown_warnings = set()

        self.fast_dev_run = fast_dev_run
        if self.fast_dev_run:
            limit_train_batches = 1
            limit_val_batches = 1
            limit_test_batches = 1
            self.num_sanity_val_steps = 0
            self.max_epochs = 1
            rank_zero_info(
                'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
            )

        # configure profiler
        if profiler is True:
            profiler = SimpleProfiler()
        self.profiler = profiler or PassThroughProfiler()

        # accumulated grads
        self.accumulate_grad_batches = accumulate_grad_batches
        self.configure_accumulated_gradients(accumulate_grad_batches)

        # override with environment flag
        gpus = os.environ.get('PL_TRAINER_GPUS', gpus)

        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.gpus = pick_multiple_gpus(gpus)
        else:
            self.gpus = gpus

        self.data_parallel_device_ids = _parse_gpu_ids(self.gpus)
        self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
        self.root_device = torch.device("cpu")

        self.on_gpu = True if (self.data_parallel_device_ids and torch.cuda.is_available()) else False

        # tpu state flags
        self.use_tpu = False
        self.tpu_local_core_rank = None
        self.tpu_global_core_rank = None

        # distributed backend choice
        self.distributed_backend = distributed_backend
        self.set_distributed_mode(distributed_backend)

        # override dist backend when using tpus
        if self.on_tpu:
            self.distributed_backend = 'tpu'
            self.init_tpu()

        # init flags for SLURM+DDP to work
        self.world_size = 1
        self.interactive_ddp_procs = []
        self.configure_slurm_ddp(self.num_nodes)
        self.node_rank = self.determine_ddp_node_rank()
        self.local_rank = self.determine_local_rank()
        self.global_rank = 0

        # NVIDIA setup
        self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

        self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

        # logging
        self.configure_logger(logger)
        self.log_save_interval = log_save_interval
        self.row_log_interval = row_log_interval

        # how much of the data to use
        # TODO: remove in 0.10.0
        if overfit_pct is not None:
            rank_zero_warn(
                "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            overfit_batches = overfit_pct

        # TODO: remove in 0.10.0
        if val_percent_check is not None:
            rank_zero_warn(
                "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_val_batches = val_percent_check

        # TODO: remove in 0.10.0
        if test_percent_check is not None:
            rank_zero_warn(
                "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_test_batches = test_percent_check

        # TODO: remove in 0.10.0
        if train_percent_check is not None:
            rank_zero_warn(
                "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
                " and this argument will be removed in v0.10.0",
                DeprecationWarning,
            )
            limit_train_batches = train_percent_check

        self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
        self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
        self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
        self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
        self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
        self.determine_data_use_amount(self.overfit_batches)

        # AMP init
        # These are the only lines needed after v0.8.0
        # we wrap the user's forward with autocast and give it back at the end of fit
        self.autocast_original_forward = None
        self.precision = precision
        self.scaler = None

        self.amp_level = amp_level
        self.init_amp(amp_backend)

        self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

        # tracks internal state for debugging
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.accelerator_backend = None

        # loops
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)

        # Callback system
        self.on_init_end()

    @property
    def state(self) -> TrainerState:
        return self._state

    @property
    def is_global_zero(self) -> bool:
        return self.global_rank == 0

    @property
    def slurm_job_id(self) -> Optional[int]:
        try:
            job_id = os.environ['SLURM_JOB_ID']
            job_id = int(job_id)

            # in interactive mode, don't make logs use the same job id
            in_slurm_interactive_mode = os.environ['SLURM_JOB_NAME'] == 'bash'
            if in_slurm_interactive_mode:
                job_id = None

        except Exception:
            job_id = None
        return job_id

    @classmethod
    def default_attributes(cls):
        init_signature = inspect.signature(Trainer)

        args = {}
        for param_name in init_signature.parameters:
            value = init_signature.parameters[param_name].default
            args[param_name] = value

        return args

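    # Usage sketch (illustrative only, values mirror the __init__ defaults above):
    #
    #     defaults = Trainer.default_attributes()
    #     defaults['max_epochs']   # -> 1000
    #     defaults['precision']    # -> 32
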
    @classmethod
    def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
        r"""Scans the Trainer signature and returns argument names, types and default values.

        Returns:
            List with tuples of 3 values:
            (argument name, set with argument types, argument default value).

        Examples:
            >>> args = Trainer.get_init_arguments_and_types()
            >>> import pprint
            >>> pprint.pprint(sorted(args))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            [('accumulate_grad_batches',
              (<class 'int'>, typing.Dict[int, int], typing.List[list]),
              1),
             ...
             ('callbacks',
              (typing.List[pytorch_lightning.callbacks.base.Callback],
               <class 'NoneType'>),
              None),
             ('check_val_every_n_epoch', (<class 'int'>,), 1),
             ...
             ('max_epochs', (<class 'int'>,), 1000),
             ...
             ('precision', (<class 'int'>,), 32),
             ('prepare_data_per_node', (<class 'bool'>,), True),
             ('process_position', (<class 'int'>,), 0),
             ('profiler',
              (<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
               <class 'bool'>,
               <class 'NoneType'>),
              None),
             ...
        """
        trainer_default_params = inspect.signature(cls).parameters
        name_type_default = []
        for arg in trainer_default_params:
            arg_type = trainer_default_params[arg].annotation
            arg_default = trainer_default_params[arg].default
            try:
                arg_types = tuple(arg_type.__args__)
            except AttributeError:
                arg_types = (arg_type,)

            name_type_default.append((arg, arg_types, arg_default))

        return name_type_default

    @classmethod
    def get_deprecated_arg_names(cls) -> List:
        """Returns a list with deprecated Trainer arguments."""
        depr_arg_names = []
        for name, val in cls.__dict__.items():
            if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
                depr_arg_names.extend(val)
        return depr_arg_names

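    # Deprecated names are collected from class attributes that follow the
    # ``DEPRECATED*`` naming convention; e.g. a mixin could declare (hypothetical name):
    #
    #     DEPRECATED_IN_0_10 = ('val_percent_check', 'test_percent_check')
    #
    # and those entries would then be skipped by add_argparse_args below.
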
    @classmethod
    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
        r"""Extends existing argparse by default `Trainer` attributes.

        Args:
            parent_parser:
                The custom cli arguments parser, which will be extended by
                the Trainer default arguments.

        Only arguments of the allowed types (str, float, int, bool) will
        extend the `parent_parser`.

        Examples:
            >>> import argparse
            >>> import pprint
            >>> parser = argparse.ArgumentParser()
            >>> parser = Trainer.add_argparse_args(parser)
            >>> args = parser.parse_args([])
            >>> pprint.pprint(vars(args))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            {...
             'check_val_every_n_epoch': 1,
             'checkpoint_callback': True,
             'default_root_dir': None,
             'deterministic': False,
             'distributed_backend': None,
             'early_stop_callback': False,
             ...
             'logger': True,
             'max_epochs': 1000,
             'max_steps': None,
             'min_epochs': 1,
             'min_steps': None,
             ...
             'profiler': None,
             'progress_bar_refresh_rate': 1,
             ...}

        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False,)

        blacklist = ['kwargs']
        depr_arg_names = cls.get_deprecated_arg_names() + blacklist

        allowed_types = (str, int, float, bool)

        # TODO: get "help" from docstring :)
        for arg, arg_types, arg_default in (
            at for at in cls.get_init_arguments_and_types() if at[0] not in depr_arg_names
        ):
            arg_types = [at for at in allowed_types if at in arg_types]
            if not arg_types:
                # skip argument with not supported type
                continue
            arg_kwargs = {}
            if bool in arg_types:
                arg_kwargs.update(nargs="?", const=True)
                # if the only arg type is bool
                if len(arg_types) == 1:
                    use_type = parsing.str_to_bool
                # if only two args (str, bool)
                elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
                    use_type = parsing.str_to_bool_or_str
                else:
                    # filter out the bool as we need to use the more general type
                    use_type = [at for at in arg_types if at is not bool][0]
            else:
                use_type = arg_types[0]

            if arg == 'gpus' or arg == 'tpu_cores':
                use_type = Trainer._gpus_allowed_type
                arg_default = Trainer._gpus_arg_default

            # hack for types in (int, float)
            if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
                use_type = Trainer._int_or_float_type

            # hack for track_grad_norm
            if arg == 'track_grad_norm':
                use_type = float

            parser.add_argument(
                f'--{arg}',
                dest=arg,
                default=arg_default,
                type=use_type,
                help='autogenerated by pl.Trainer',
                **arg_kwargs,
            )

        return parser

    def _gpus_allowed_type(x) -> Union[int, str]:
        if ',' in x:
            return str(x)
        else:
            return int(x)

    def _gpus_arg_default(x) -> Union[int, str]:
        if ',' in x:
            return str(x)
        else:
            return int(x)

    def _int_or_float_type(x) -> Union[int, float]:
        if '.' in str(x):
            return float(x)
        else:
            return int(x)

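    # Illustrative behaviour of the coercion helpers above (sketch, not a doctest):
    #
    #     _gpus_allowed_type('0,1')   # -> '0,1' (kept as a string of device ids)
    #     _gpus_allowed_type('2')     # -> 2     (a single count becomes an int)
    #     _int_or_float_type('0.25')  # -> 0.25
    #     _int_or_float_type('3')     # -> 3
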
    @classmethod
    def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
        """Parse CLI arguments, required for custom bool types."""
        args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser

        types_default = {
            arg: (arg_types, arg_default) for arg, arg_types, arg_default in cls.get_init_arguments_and_types()
        }

        modified_args = {}
        for k, v in vars(args).items():
            if k in types_default and v is None:
                # We need to figure out if the None is due to using nargs="?" or if it comes from the default value
                arg_types, arg_default = types_default[k]
                if bool in arg_types and isinstance(arg_default, bool):
                    # Value has been passed as a flag => It is currently None, so we need to set it to True
                    # We always set to True, regardless of the default value.
                    # Users must pass False directly, but when passing nothing True is assumed.
                    # i.e. the only way to disable something that defaults to True is to use the long form:
                    # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
                    # which then becomes True here.

                    v = True

            modified_args[k] = v
        return Namespace(**modified_args)
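
    # Sketch of the resulting CLI semantics (argument names are just examples):
    #
    #     parser = Trainer.add_argparse_args(ArgumentParser(add_help=False))
    #     args = Trainer.parse_argparser(parser.parse_args(['--benchmark']))
    #     # the bare flag resolves to True, while an explicit "--benchmark False"
    #     # is converted by parsing.str_to_bool and resolves to a falsy value.
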
    @classmethod
    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
        """
        Create an instance from CLI arguments.

        Args:
            args: The parser or namespace to take arguments from. Only known arguments will be
                parsed and passed to the :class:`Trainer`.
            **kwargs: Additional keyword arguments that may override ones in the parser or namespace.
                These must be valid Trainer arguments.

        Example:
            >>> parser = ArgumentParser(add_help=False)
            >>> parser = Trainer.add_argparse_args(parser)
            >>> parser.add_argument('--my_custom_arg', default='something')  # doctest: +SKIP
            >>> args = Trainer.parse_argparser(parser.parse_args(""))
            >>> trainer = Trainer.from_argparse_args(args, logger=False)
        """
        if isinstance(args, ArgumentParser):
            args = cls.parse_argparser(args)
        params = vars(args)

        # we only want to pass in valid Trainer args, the rest may be user specific
        valid_kwargs = inspect.signature(cls.__init__).parameters
        trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
        trainer_kwargs.update(**kwargs)

        return cls(**trainer_kwargs)

    @property
    def num_gpus(self) -> int:
        gpus = self.data_parallel_device_ids
        if gpus is None:
            return 0
        return len(gpus)

    @property
    def data_parallel(self) -> bool:
        return self.use_dp or self.use_ddp or self.use_ddp2

    @property
    def progress_bar_callback(self):
        return self._progress_bar_callback

    @property
    def progress_bar_dict(self) -> dict:
        """ Read-only for progress bar metrics. """
        ref_model = self.model if not self.data_parallel else self.model.module
        return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics)

    @property
    def disable_validation(self) -> bool:
        """ Check if validation is disabled during training. """
        return not self.enable_validation

    @property
    def enable_validation(self) -> bool:
        """ Check if we should run validation during training. """
        val_loop_enabled = is_overridden('validation_step', self.get_model()) and self.limit_val_batches > 0
        return val_loop_enabled or self.fast_dev_run

    @property
    def default_root_dir(self) -> str:
        """
        The default location to save artifacts of loggers, checkpoints etc.
        It is used as a fallback if logger or checkpoint callback do not define specific save paths.
        """
        if get_filesystem(self._default_root_dir).protocol == "file":
            return os.path.normpath(self._default_root_dir)
        return self._default_root_dir

    @property
    def weights_save_path(self) -> str:
        """
        The default root location to save weights (checkpoints), e.g., when the
        :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
        """
        if get_filesystem(self._weights_save_path).protocol == "file":
            return os.path.normpath(self._weights_save_path)
        return self._weights_save_path
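
    # e.g. (sketch): a local path such as "runs/../checkpoints" is normalized by the two
    # properties above, while a remote path such as "s3://my-bucket/checkpoints" is
    # returned untouched because its filesystem protocol is not "file".
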
    def tune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        # TODO: temporary, need to decide if tune or separate object

        # setup data, etc...
        self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.call_hook('on_fit_start', model)

        # hook
        self.data_connector.prepare_data(model)

        # Run auto batch size scaling
        if self.auto_scale_batch_size:
            if isinstance(self.auto_scale_batch_size, bool):
                self.auto_scale_batch_size = 'power'
            self.scale_batch_size(
                model,
                mode=self.auto_scale_batch_size,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule,
            )
            model.logger = self.logger  # reset logger binding

        # Run learning rate finder:
        if self.auto_lr_find:
            self._run_lr_finder_internally(model)
            model.logger = self.logger  # reset logger binding

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs the full optimization routine.

        Args:
            model: Model to fit.

            train_dataloader: A Pytorch
                DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            val_dataloaders: Either a single
                Pytorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped

        Example::

            # Option 1,
            # Define the train_dataloader() and val_dataloader() functions
            # in the LightningModule
            # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model)

            # Option 2
            # in production cases we might want to pass different datasets to the same model
            # Recommended for PRODUCTION SYSTEMS
            train, val = DataLoader(...), DataLoader(...)
            trainer = Trainer()
            model = LightningModule()
            trainer.fit(model, train_dataloader=train, val_dataloaders=val)

            # Option 1 & 2 can be mixed, for example the training set can be
            # defined as part of the model, and validation can then be fed to .fit()

        """
        results = None

        # setup data, etc...
        self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.call_hook('on_fit_start', model)

        # hook
        self.data_connector.prepare_data(model)

        # set testing if set in environ
        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

        # -------------------------
        # TRAIN
        # -------------------------
        self.accelerator_backend = self.accelerator_connector.select_accelerator()
        self.accelerator_backend.setup(model)
        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()

        # -------------------------
        # POST-Training
        # -------------------------
        # hook
        self.call_hook('on_fit_end')

        # hook
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return results or 1

    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
        # bind logger and other properties
        self.copy_trainer_model_properties(model)

        # clean hparams
        if hasattr(model, 'hparams'):
            parsing.clean_namespace(model.hparams)

        # links data to the trainer
        self.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)

        # check that model is configured correctly
        self.config_validator.verify_loop_configurations(model)

def setup_training(self, model: LightningModule):
|
2019-12-04 11:57:10 +00:00
|
|
|
"""Sanity check a few things before starting actual training.
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
Args:
|
|
|
|
model: The model to run sanity test on.
|
2019-07-03 19:09:49 +00:00
|
|
|
"""
|
2020-08-31 22:06:11 +00:00
|
|
|
# --------------------------
|
|
|
|
# Setup??
|
|
|
|
# --------------------------
|
2019-07-08 21:38:57 +00:00
|
|
|
ref_model = model
|
2019-07-14 02:21:17 +00:00
|
|
|
if self.data_parallel:
|
2019-07-08 21:38:57 +00:00
|
|
|
ref_model = model.module
|
|
|
|
|
2019-08-30 22:56:09 +00:00
|
|
|
# give model convenience properties
|
2019-07-08 22:55:05 +00:00
|
|
|
ref_model.trainer = self
|
|
|
|
|
2019-07-08 21:15:26 +00:00
|
|
|
# set local properties on the model
|
2019-10-22 01:16:51 +00:00
|
|
|
self.copy_trainer_model_properties(ref_model)
|
2019-07-08 21:15:26 +00:00
|
|
|
|
2020-05-12 04:25:06 +00:00
|
|
|
# init AMP. Must be done here instead of __init__ to allow DDP to work
|
2020-08-13 14:03:13 +00:00
|
|
|
if self.amp_backend == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
|
2020-05-12 04:25:06 +00:00
|
|
|
self.scaler = torch.cuda.amp.GradScaler()
|
|
|
|
|
2020-02-27 20:54:06 +00:00
|
|
|
# log hyper-parameters
|
2019-10-10 19:16:19 +00:00
|
|
|
if self.logger is not None:
|
|
|
|
# save the experiment to get started
|
2020-06-08 11:19:34 +00:00
|
|
|
self.logger.log_hyperparams(ref_model.hparams)
|
2020-08-19 23:08:46 +00:00
|
|
|
self.logger.log_graph(ref_model)
|
2019-10-10 19:16:19 +00:00
|
|
|
self.logger.save()
|
|
|
|
|
|
|
|
if self.use_ddp or self.use_ddp2:
|
2020-03-17 00:50:36 +00:00
|
|
|
torch_distrib.barrier()
|
2019-10-10 19:16:19 +00:00
|
|
|
|
2020-02-22 01:39:12 +00:00
|
|
|
# wait for all models to restore weights
|
|
|
|
if self.on_tpu and XLA_AVAILABLE:
|
|
|
|
# wait for all processes to catch up
|
2020-08-31 22:06:11 +00:00
|
|
|
torch_xla.core.xla_model.rendezvous("pl.Trainer.setup_training")
|
2020-02-22 01:39:12 +00:00
|
|
|
|
2020-05-04 17:02:57 +00:00
|
|
|
elif self.use_horovod:
|
|
|
|
# wait for all processes to catch up
|
|
|
|
hvd.join()
|
|
|
|
|
2019-09-06 15:54:51 +00:00
|
|
|
# register auto-resubmit when on SLURM
|
|
|
|
self.register_slurm_signal_handlers()
|
|
|
|
|
2020-08-31 22:06:11 +00:00
|
|
|
# --------------------------
|
|
|
|
# Pre-train
|
|
|
|
# --------------------------
|
2020-08-07 13:29:57 +00:00
|
|
|
# on pretrain routine start
|
|
|
|
self.on_pretrain_routine_start(ref_model)
|
|
|
|
if self.is_function_implemented('on_pretrain_routine_start'):
|
|
|
|
ref_model.on_pretrain_routine_start()
|
|
|
|
|
2019-07-08 21:15:26 +00:00
|
|
|
# print model summary
|
2020-06-13 16:00:14 +00:00
|
|
|
if self.is_global_zero and self.weights_summary is not None and not self.testing:
|
2020-06-15 21:05:58 +00:00
|
|
|
if self.weights_summary in ModelSummary.MODES:
|
2019-10-08 21:11:47 +00:00
|
|
|
ref_model.summarize(mode=self.weights_summary)
|
|
|
|
else:
|
2020-07-24 15:42:15 +00:00
|
|
|
raise MisconfigurationException("weights_summary can be None or one of: " + ", ".join(ModelSummary.MODES))
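# illustrative sketch of the flag (assuming the usual ModelSummary modes 'top' and 'full'):
#   Trainer(weights_summary='top')   # summarize only the top-level children
#   Trainer(weights_summary='full')  # summarize every submodule
#   Trainer(weights_summary=None)    # skip the summary entirely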
|
2019-07-08 21:15:26 +00:00
|
|
|
|
2019-07-27 02:57:49 +00:00
|
|
|
# track model now.
|
|
|
|
# if cluster resets state, the model will update with the saved weights
|
|
|
|
self.model = model
|
|
|
|
|
2020-07-07 16:24:56 +00:00
|
|
|
# restore training state and model weights before HPC is called
|
2019-09-06 15:54:51 +00:00
|
|
|
self.restore_weights(model)
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2020-08-31 22:06:11 +00:00
|
|
|
# on pretrain routine end
|
|
|
|
self.on_pretrain_routine_end(ref_model)
|
|
|
|
if self.is_function_implemented('on_pretrain_routine_end'):
|
|
|
|
ref_model.on_pretrain_routine_end()
|
|
|
|
|
2020-09-06 21:50:47 +00:00
|
|
|
def train(self):
|
|
|
|
self.run_sanity_check(self.get_model())
|
|
|
|
|
|
|
|
# enable train mode
|
|
|
|
model = self.get_model()
|
|
|
|
model.train()
|
|
|
|
torch.set_grad_enabled(True)
|
|
|
|
|
|
|
|
# reload data when needed
|
|
|
|
self.train_loop.reset_train_val_dataloaders(model)
|
|
|
|
|
|
|
|
# hook
|
|
|
|
self.train_loop.on_train_start()
|
|
|
|
|
|
|
|
try:
|
|
|
|
# run all epochs
|
|
|
|
for epoch in range(self.current_epoch, self.max_epochs):
|
|
|
|
|
|
|
|
# reset train dataloader
|
|
|
|
if self.reload_dataloaders_every_epoch:
|
|
|
|
self.reset_train_dataloader(model)
|
|
|
|
|
|
|
|
# hook
|
|
|
|
self.train_loop.on_train_epoch_start(epoch)
|
|
|
|
|
|
|
|
# run train epoch
|
|
|
|
self.run_training_epoch()
|
|
|
|
|
|
|
|
if self.max_steps and self.max_steps <= self.global_step:
|
|
|
|
|
|
|
|
# hook
|
|
|
|
self.train_loop.on_train_end()
|
|
|
|
return
|
|
|
|
|
|
|
|
# update LR schedulers
|
|
|
|
self.update_learning_rates(interval='epoch')
|
|
|
|
|
|
|
|
# early stopping
|
|
|
|
met_min_epochs = epoch >= self.min_epochs - 1
|
|
|
|
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
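# e.g. (sketch) with min_epochs=3 and min_steps=None: met_min_epochs becomes True
# from epoch index 2 onwards (epochs are zero-indexed) and met_min_steps is always True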
|
|
|
|
|
|
|
|
if self.should_stop:
|
|
|
|
if met_min_epochs and met_min_steps:
|
|
|
|
self.train_loop.on_train_end()
|
|
|
|
return
|
|
|
|
else:
|
|
|
|
log.info('Trainer was signaled to stop but the required minimum epochs'
|
|
|
|
f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
|
|
|
|
' not been met. Training will continue...')
|
|
|
|
|
|
|
|
# hook
|
|
|
|
self.train_loop.on_train_end()
|
|
|
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
|
|
|
|
|
|
|
|
# the user could press Ctrl+C many times... only shut down once
|
|
|
|
if not self.interrupted:
|
|
|
|
self.interrupted = True
|
|
|
|
self._state = TrainerState.INTERRUPTED
|
|
|
|
self.on_keyboard_interrupt()
|
|
|
|
|
|
|
|
# hook
|
|
|
|
self.train_loop.on_train_end()
|
|
|
|
|
2020-09-01 00:36:52 +00:00
|
|
|
def run_test(self):
|
|
|
|
# only load test dataloader for testing
|
|
|
|
# self.reset_test_dataloader(ref_model)
|
|
|
|
eval_loop_results, _ = self.run_evaluation(test_mode=True)
|
2020-07-07 16:24:56 +00:00
|
|
|
|
2020-09-01 00:36:52 +00:00
|
|
|
if len(eval_loop_results) == 0:
|
|
|
|
return 1
|
2019-08-30 22:56:09 +00:00
|
|
|
|
2020-09-01 00:36:52 +00:00
|
|
|
# remove the tensors from the eval results
|
|
|
|
for i, result in enumerate(eval_loop_results):
|
|
|
|
if isinstance(result, dict):
|
|
|
|
for k, v in result.items():
|
|
|
|
if isinstance(v, torch.Tensor):
|
|
|
|
result[k] = v.cpu().item()
|
2020-07-14 18:20:45 +00:00
|
|
|
|
2020-09-01 00:36:52 +00:00
|
|
|
return eval_loop_results
|
2020-07-14 18:20:45 +00:00
|
|
|
|
2020-09-01 00:36:52 +00:00
|
|
|
def train_or_test(self):
|
|
|
|
if self.testing:
|
|
|
|
results = self.run_test()
|
|
|
|
else:
|
|
|
|
results = self.train()
|
|
|
|
return results
|
2020-07-14 18:20:45 +00:00
|
|
|
|
2020-09-01 00:36:52 +00:00
|
|
|
def run_sanity_check(self, ref_model):
|
|
|
|
using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
|
2020-07-25 16:57:40 +00:00
|
|
|
should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
|
2020-07-23 11:07:03 +00:00
|
|
|
|
2019-08-30 22:56:09 +00:00
|
|
|
# run tiny validation (if validation defined)
|
|
|
|
# to make sure the program won't crash during validation
|
2020-07-23 11:07:03 +00:00
|
|
|
if should_sanity_check:
|
2020-02-26 21:55:18 +00:00
|
|
|
self.reset_val_dataloader(ref_model)
|
2020-08-21 18:11:31 +00:00
|
|
|
self.num_sanity_val_batches = [
|
|
|
|
min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
|
|
|
|
]
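# e.g. (sketch) num_sanity_val_steps=2 with num_val_batches=[5, 1]
# yields num_sanity_val_batches=[2, 1]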
|
2020-04-24 00:46:18 +00:00
|
|
|
|
|
|
|
# hook and callback
|
2020-07-22 17:53:10 +00:00
|
|
|
self.running_sanity_check = True
|
2020-04-24 00:46:18 +00:00
|
|
|
self.on_sanity_check_start()
|
2019-08-24 01:23:27 +00:00
|
|
|
|
2020-08-26 16:28:14 +00:00
|
|
|
# run eval step
|
|
|
|
_, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)
|
2020-07-01 11:38:00 +00:00
|
|
|
|
|
|
|
# allow no returns from eval
|
|
|
|
if eval_results is not None and len(eval_results) > 0:
|
2020-07-14 18:20:45 +00:00
|
|
|
# when we get a list back, use only the last item
|
|
|
|
if isinstance(eval_results, list):
|
|
|
|
eval_results = eval_results[-1]
|
2020-07-22 17:53:10 +00:00
|
|
|
|
|
|
|
if isinstance(eval_results, EvalResult):
|
|
|
|
callback_metrics = eval_results.callback_metrics
|
|
|
|
else:
|
|
|
|
_, _, _, callback_metrics, _ = self.process_output(eval_results)
|
2020-07-01 11:38:00 +00:00
|
|
|
self.callback_metrics = callback_metrics
|
2019-08-07 11:51:55 +00:00
|
|
|
|
2020-04-24 00:46:18 +00:00
|
|
|
self.on_sanity_check_end()
|
2020-07-22 17:53:10 +00:00
|
|
|
self.running_sanity_check = False
|
2019-11-03 10:42:53 +00:00
|
|
|
|
2020-08-09 10:24:09 +00:00
|
|
|
@trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
|
2020-05-04 12:24:34 +00:00
|
|
|
def test(
|
2020-07-24 15:42:15 +00:00
|
|
|
self,
|
|
|
|
model: Optional[LightningModule] = None,
|
|
|
|
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
|
|
|
|
ckpt_path: Optional[str] = 'best',
|
|
|
|
verbose: bool = True,
|
|
|
|
datamodule: Optional[LightningDataModule] = None,
|
2020-05-04 12:24:34 +00:00
|
|
|
):
|
2020-01-17 11:03:31 +00:00
|
|
|
r"""
|
|
|
|
|
|
|
|
This is separate from ``fit`` to make sure you never run on your test set until you want to.
|
|
|
|
|
|
|
|
Args:
|
2020-03-20 19:49:01 +00:00
|
|
|
model: The model to test.
|
2020-01-17 11:03:31 +00:00
|
|
|
|
2020-04-10 15:44:03 +00:00
|
|
|
test_dataloaders: Either a single
|
|
|
|
PyTorch DataLoader or a list of them, specifying test samples.
|
|
|
|
|
2020-06-15 12:02:37 +00:00
|
|
|
ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
|
|
|
|
If ``None``, use the weights from the last epoch to test. Defaults to ``best``.
|
|
|
|
|
2020-07-14 18:20:45 +00:00
|
|
|
verbose: If True, prints the test results
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
The final test result dictionary. If no ``test_epoch_end`` is defined, returns a list of dictionaries
|
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
|
|
|
|
|
|
|
# Option 1
|
2020-06-15 12:02:37 +00:00
|
|
|
# run test with the best checkpoint from ``ModelCheckpoint`` after fitting.
|
2020-04-10 15:44:03 +00:00
|
|
|
test = DataLoader(...)
|
2020-01-17 11:03:31 +00:00
|
|
|
trainer = Trainer()
|
|
|
|
model = LightningModule()
|
|
|
|
|
2020-04-10 15:44:03 +00:00
|
|
|
trainer.fit(model)
|
|
|
|
trainer.test(test_dataloaders=test)
|
2020-01-17 11:03:31 +00:00
|
|
|
|
|
|
|
# Option 2
|
2020-06-15 12:02:37 +00:00
|
|
|
# run test with the specified checkpoint after fitting
|
|
|
|
test = DataLoader(...)
|
|
|
|
trainer = Trainer()
|
|
|
|
model = LightningModule()
|
|
|
|
|
|
|
|
trainer.fit(model)
|
|
|
|
trainer.test(test_dataloaders=test, ckpt_path='path/to/checkpoint.ckpt')
|
|
|
|
|
|
|
|
# Option 3
|
|
|
|
# run test with the weights from the end of training after fitting
|
|
|
|
test = DataLoader(...)
|
|
|
|
trainer = Trainer()
|
|
|
|
model = LightningModule()
|
|
|
|
|
|
|
|
trainer.fit(model)
|
|
|
|
trainer.test(test_dataloaders=test, ckpt_path=None)
|
|
|
|
|
|
|
|
# Option 4
|
|
|
|
# run test from a loaded model. ``ckpt_path`` is ignored in this case.
|
2020-04-10 15:44:03 +00:00
|
|
|
test = DataLoader(...)
|
2020-01-17 11:03:31 +00:00
|
|
|
model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
|
|
|
|
trainer = Trainer()
|
2020-04-10 15:44:03 +00:00
|
|
|
trainer.test(model, test_dataloaders=test)
|
2020-01-17 11:03:31 +00:00
|
|
|
"""
|
2020-07-07 16:24:56 +00:00
|
|
|
# --------------------
|
|
|
|
# SETUP HOOK
|
|
|
|
# --------------------
|
2020-07-14 18:20:45 +00:00
|
|
|
self.verbose_test = verbose
|
|
|
|
|
2020-07-09 22:36:36 +00:00
|
|
|
if self.global_rank != 0:
|
|
|
|
return
|
|
|
|
|
2020-07-24 15:42:15 +00:00
|
|
|
# If you supply a datamodule you can't supply test_dataloaders
|
|
|
|
if test_dataloaders and datamodule:
|
|
|
|
raise MisconfigurationException(
|
|
|
|
'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
|
|
|
|
)
|
|
|
|
|
|
|
|
# Attach datamodule to get setup/prepare_data added to model before the call to it below
|
2020-08-31 15:08:22 +00:00
|
|
|
self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')
|
2020-07-09 22:36:36 +00:00
|
|
|
|
|
|
|
if model is not None:
|
|
|
|
results = self.__test_given_model(model, test_dataloaders)
|
|
|
|
else:
|
|
|
|
results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
|
|
|
|
|
|
|
|
self.teardown('test')
|
|
|
|
|
|
|
|
return results
|
|
|
|
|
|
|
|
def __test_using_best_weights(self, ckpt_path, test_dataloaders):
|
|
|
|
model = self.get_model()
|
2020-06-17 23:49:58 +00:00
|
|
|
|
2020-07-07 16:24:56 +00:00
|
|
|
# if user requests the best checkpoint but we don't have it, error
|
2020-07-09 22:36:36 +00:00
|
|
|
if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
|
2020-06-15 12:02:37 +00:00
|
|
|
raise MisconfigurationException(
|
2020-07-24 15:42:15 +00:00
|
|
|
'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
|
|
|
|
)
|
2020-06-15 12:02:37 +00:00
|
|
|
|
2020-07-09 22:36:36 +00:00
|
|
|
# load best weights
|
|
|
|
if ckpt_path is not None:
|
2020-06-15 12:02:37 +00:00
|
|
|
# ckpt_path is 'best' so load the best model
|
|
|
|
if ckpt_path == 'best':
|
|
|
|
ckpt_path = self.checkpoint_callback.best_model_path
|
2020-03-03 04:38:47 +00:00
|
|
|
|
2020-07-10 01:28:11 +00:00
|
|
|
if len(ckpt_path) == 0:
|
2020-07-24 15:42:15 +00:00
|
|
|
rank_zero_warn(
|
|
|
|
f'.test() found no path for the best weights, {ckpt_path}. Please '
|
|
|
|
'specify a path for a checkpoint: .test(ckpt_path=PATH)'
|
|
|
|
)
|
2020-07-10 01:28:11 +00:00
|
|
|
return {}
|
|
|
|
|
2020-07-09 22:36:36 +00:00
|
|
|
ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
|
|
|
|
model.load_state_dict(ckpt['state_dict'])
|
2020-04-10 15:44:03 +00:00
|
|
|
|
2020-07-09 22:36:36 +00:00
|
|
|
# attach dataloaders
|
2020-04-10 15:44:03 +00:00
|
|
|
if test_dataloaders is not None:
|
2020-08-31 15:08:22 +00:00
|
|
|
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
|
2020-04-10 15:44:03 +00:00
|
|
|
|
2020-07-09 22:36:36 +00:00
|
|
|
# run tests
|
|
|
|
self.tested_ckpt_path = ckpt_path
|
2020-07-07 16:24:56 +00:00
|
|
|
self.testing = True
|
2020-07-09 22:36:36 +00:00
|
|
|
os.environ['PL_TESTING_MODE'] = '1'
|
2020-07-07 16:24:56 +00:00
|
|
|
self.model = model
|
|
|
|
results = self.fit(model)
|
2020-03-06 11:57:14 +00:00
|
|
|
self.testing = False
|
2020-07-09 22:36:36 +00:00
|
|
|
del os.environ['PL_TESTING_MODE']
|
2020-03-06 11:57:14 +00:00
|
|
|
|
2020-07-09 22:36:36 +00:00
|
|
|
# teardown
|
2020-06-17 23:49:58 +00:00
|
|
|
if self.is_function_implemented('teardown'):
|
2020-06-25 15:10:17 +00:00
|
|
|
model_ref = self.get_model()
|
|
|
|
model_ref.teardown('test')
|
2020-06-17 23:49:58 +00:00
|
|
|
|
2020-07-07 16:24:56 +00:00
|
|
|
return results
|
|
|
|
|
2020-07-09 22:36:36 +00:00
|
|
|
def __test_given_model(self, model, test_dataloaders):
|
|
|
|
|
|
|
|
# attach data
|
|
|
|
if test_dataloaders is not None:
|
2020-08-31 15:08:22 +00:00
|
|
|
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
|
2020-07-09 22:36:36 +00:00
|
|
|
|
|
|
|
# run test
|
|
|
|
# set up testing so we short-circuit to eval
|
|
|
|
self.testing = True
|
|
|
|
self.model = model
|
|
|
|
results = self.fit(model)
|
|
|
|
self.testing = False
|
|
|
|
|
|
|
|
# teardown
|
|
|
|
if self.is_function_implemented('teardown'):
|
|
|
|
model.teardown('test')
|
|
|
|
|
|
|
|
return results
|
|
|
|
|
2020-06-19 00:15:02 +00:00
|
|
|
def barrier(self, name):
|
|
|
|
if self.use_ddp or self.use_ddp2:
|
2020-07-07 16:24:56 +00:00
|
|
|
pass
|
|
|
|
# torch_distrib.barrier()
|
2020-06-19 00:15:02 +00:00
|
|
|
|
|
|
|
if self.on_tpu and XLA_AVAILABLE:
|
|
|
|
# wait for all processes to catch up
|
|
|
|
torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')
|
|
|
|
|
2020-08-02 00:17:57 +00:00
|
|
|
def call_setup_hook(self, model):
|
|
|
|
# call setup after the DDP process has connected
|
|
|
|
stage_name = 'test' if self.testing else 'fit'
|
|
|
|
if self.datamodule is not None:
|
|
|
|
called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
|
|
|
|
if not called:
|
|
|
|
self.datamodule.setup(stage_name)
|
|
|
|
self.setup(stage_name)
|
|
|
|
model.setup(stage_name)
|
|
|
|
|
2020-08-13 14:03:13 +00:00
|
|
|
def init_amp(self, amp_type: str):
|
|
|
|
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
|
|
|
|
self.amp_backend = None
|
|
|
|
self._setup_amp_backend(amp_type)
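# hedged usage sketch: the backend is driven by the Trainer flags, e.g.
#   Trainer(precision=16, amp_backend='native')  # torch.cuda.amp
#   Trainer(precision=16, amp_backend='apex')    # NVIDIA apex, if installed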
|
|
|
|
|
2020-08-24 17:46:46 +00:00
|
|
|
def call_hook(self, hook_name, *args, **kwargs):
|
2020-08-24 19:48:14 +00:00
|
|
|
# always profile hooks
|
|
|
|
with self.profiler.profile(hook_name):
|
|
|
|
|
|
|
|
# first call trainer hook
|
|
|
|
if hasattr(self, hook_name):
|
|
|
|
trainer_hook = getattr(self, hook_name)
|
|
|
|
trainer_hook(*args, **kwargs)
|
|
|
|
|
|
|
|
# next, call the hook in the LightningModule
|
|
|
|
output = None
|
2020-08-31 16:12:02 +00:00
|
|
|
model_ref = self.get_model()
|
|
|
|
if is_overridden(hook_name, model_ref):
|
2020-08-24 17:46:46 +00:00
|
|
|
hook_fx = getattr(model_ref, hook_name)
|
|
|
|
output = hook_fx(*args, **kwargs)
|
|
|
|
|
2020-08-24 21:50:47 +00:00
|
|
|
# if the PL module doesn't have the hook, then call the accelerator
|
|
|
|
# used to auto-reduce things for the user with the Result object
|
|
|
|
elif hasattr(self.accelerator_backend, hook_name):
|
|
|
|
accelerator_hook = getattr(self.accelerator_backend, hook_name)
|
|
|
|
output = accelerator_hook(*args, **kwargs)
|
|
|
|
|
2020-08-24 19:48:14 +00:00
|
|
|
return output
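# dispatch sketch: a call such as self.call_hook('on_train_epoch_start', epoch) is profiled,
# runs the Trainer-level hook if one exists, then the LightningModule override (if any),
# otherwise the accelerator backend's hook, and returns that output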
|
2020-08-24 17:46:46 +00:00
|
|
|
|
2020-02-26 04:17:27 +00:00
|
|
|
|
2020-08-07 11:02:36 +00:00
|
|
|
def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
|
2020-06-17 17:42:28 +00:00
|
|
|
if 0 <= batches <= 1:
|
|
|
|
return batches
|
|
|
|
elif batches > 1 and batches % 1.0 == 0:
|
|
|
|
return int(batches)
|
|
|
|
else:
|
|
|
|
raise MisconfigurationException(
|
2020-08-07 11:02:36 +00:00
|
|
|
f'You have passed an invalid value {batches} for {name}; it has to be in [0.0, 1.0] or an int.'
|
2020-06-17 17:42:28 +00:00
|
|
|
)
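# behaviour sketch (illustrative values):
#   _determine_batch_limits(0.25, 'limit_val_batches')  -> 0.25  (fraction of the batches)
#   _determine_batch_limits(7, 'limit_val_batches')     -> 7     (absolute batch count)
#   _determine_batch_limits(2.5, 'limit_val_batches')   -> raises MisconfigurationException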
|