2019-07-09 00:11:20 +00:00
|
|
|
import os
|
2019-11-21 18:26:24 +00:00
|
|
|
import sys
|
2019-08-05 21:57:39 +00:00
|
|
|
import warnings
|
2020-02-01 20:47:58 +00:00
|
|
|
import logging as log
|
2020-02-23 02:23:30 +00:00
|
|
|
from typing import Union, Optional, List, Dict, Tuple
|
2019-07-09 00:11:20 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
import torch
|
2019-07-09 00:11:20 +00:00
|
|
|
import torch.distributed as dist
|
2019-10-22 08:32:40 +00:00
|
|
|
import torch.multiprocessing as mp
|
2020-02-23 02:23:30 +00:00
|
|
|
from torch.utils.data import DataLoader
|
2020-01-26 15:19:09 +00:00
|
|
|
from tqdm.auto import tqdm
|
2019-08-15 15:31:56 +00:00
|
|
|
from torch.optim.optimizer import Optimizer
|
2019-07-09 00:11:20 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
|
|
|
from pytorch_lightning.loggers import LightningLoggerBase
|
|
|
|
from pytorch_lightning.profiler.profiler import BaseProfiler
|
2019-12-04 16:39:14 +00:00
|
|
|
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
|
|
|
|
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
|
|
|
|
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
|
|
|
|
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
|
|
|
|
from pytorch_lightning.trainer.distrib_parts import (
|
|
|
|
TrainerDPMixin,
|
2019-10-23 09:05:09 +00:00
|
|
|
parse_gpu_ids,
|
|
|
|
determine_root_gpu_device
|
|
|
|
)
|
2020-02-19 11:00:08 +00:00
|
|
|
from pytorch_lightning.core.lightning import LightningModule
|
2019-12-04 16:39:14 +00:00
|
|
|
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
|
|
|
|
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
|
|
|
|
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
|
2020-01-14 03:20:38 +00:00
|
|
|
from pytorch_lightning.trainer.training_io import TrainerIOMixin
|
2020-01-20 19:50:31 +00:00
|
|
|
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
|
2019-12-04 16:39:14 +00:00
|
|
|
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
|
2019-08-07 14:14:59 +00:00
|
|
|
from pytorch_lightning.utilities.debugging import MisconfigurationException
|
2020-02-07 03:01:21 +00:00
|
|
|
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
|
|
|
|
|
2019-10-04 19:35:02 +00:00
|
|
|
|
2019-05-14 00:40:07 +00:00
|
|
|
# Optional NVIDIA Apex mixed-precision support: bind `amp` when apex can be
# imported and record whether that succeeded in APEX_AVAILABLE.
try:
    from apex import amp
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2020-02-17 21:01:20 +00:00
|
|
|
# Optional TPU support: bind torch_xla and its model / multiprocessing helpers
# when the package can be imported, and record the outcome in XLA_AVAILABLE.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True
|
|
|
|
|
2019-07-09 00:12:27 +00:00
|
|
|
|
2019-10-22 01:16:51 +00:00
|
|
|
class Trainer(TrainerIOMixin,
|
|
|
|
TrainerDPMixin,
|
2019-12-04 15:57:32 +00:00
|
|
|
TrainerDDPMixin,
|
|
|
|
TrainerLoggingMixin,
|
|
|
|
TrainerModelHooksMixin,
|
|
|
|
TrainerTrainingTricksMixin,
|
2019-10-22 01:16:51 +00:00
|
|
|
TrainerDataLoadingMixin,
|
|
|
|
TrainerAMPMixin,
|
|
|
|
TrainerEvaluationLoopMixin,
|
|
|
|
TrainerTrainLoopMixin,
|
|
|
|
TrainerCallbackConfigMixin,
|
2019-12-04 15:57:32 +00:00
|
|
|
):
|
2020-02-09 22:39:10 +00:00
|
|
|
|
2019-12-04 11:57:10 +00:00
|
|
|
def __init__(
|
|
|
|
self,
|
2020-02-23 02:23:30 +00:00
|
|
|
logger: Union[LightningLoggerBase, bool] = True,
|
|
|
|
checkpoint_callback: Union[ModelCheckpoint, bool] = True,
|
|
|
|
early_stop_callback: Optional[Union[EarlyStopping, bool]] = None,
|
|
|
|
default_save_path: Optional[str] = None,
|
|
|
|
gradient_clip_val: float = 0,
|
2019-12-04 11:57:10 +00:00
|
|
|
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
|
2020-02-23 02:23:30 +00:00
|
|
|
process_position: int = 0,
|
2019-12-04 11:57:10 +00:00
|
|
|
nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0
|
2020-02-23 02:23:30 +00:00
|
|
|
num_nodes: int = 1,
|
|
|
|
gpus: Optional[Union[List[int], str, int]] = None,
|
|
|
|
num_tpu_cores: Optional[int] = None,
|
|
|
|
log_gpu_memory: Optional[str] = None,
|
|
|
|
show_progress_bar: bool = True,
|
|
|
|
overfit_pct: float = 0.0,
|
|
|
|
track_grad_norm: int = -1,
|
|
|
|
check_val_every_n_epoch: int = 1,
|
|
|
|
fast_dev_run: bool = False,
|
|
|
|
accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
|
2019-12-04 11:57:10 +00:00
|
|
|
max_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
|
|
|
|
min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
|
2020-02-23 02:23:30 +00:00
|
|
|
max_epochs: int = 1000,
|
|
|
|
min_epochs: int = 1,
|
|
|
|
max_steps: Optional[int] = None,
|
|
|
|
min_steps: Optional[int] = None,
|
|
|
|
train_percent_check: float = 1.0,
|
|
|
|
val_percent_check: float = 1.0,
|
|
|
|
test_percent_check: float = 1.0,
|
|
|
|
val_check_interval: Union[float] = 1.0,
|
|
|
|
log_save_interval: int = 100,
|
|
|
|
row_log_interval: int = 10,
|
2019-12-04 11:57:10 +00:00
|
|
|
add_row_log_interval=None, # backward compatible, todo: remove in v0.8.0
|
2020-02-23 02:23:30 +00:00
|
|
|
distributed_backend: Optional[str] = None,
|
2020-02-17 21:01:20 +00:00
|
|
|
use_amp=False, # backward compatible, todo: remove in v0.8.0
|
2020-02-23 02:23:30 +00:00
|
|
|
precision: int = 32,
|
|
|
|
print_nan_grads: bool = False,
|
|
|
|
weights_summary: str = 'full',
|
|
|
|
weights_save_path: Optional[str] = None,
|
|
|
|
amp_level: str = 'O1',
|
2019-12-04 11:57:10 +00:00
|
|
|
nb_sanity_val_steps=None, # backward compatible, todo: remove in v0.8.0
|
2020-02-23 02:23:30 +00:00
|
|
|
num_sanity_val_steps: int = 5,
|
|
|
|
truncated_bptt_steps: Optional[int] = None,
|
|
|
|
resume_from_checkpoint: Optional[str] = None,
|
|
|
|
profiler: Optional[BaseProfiler] = None,
|
2019-12-04 11:57:10 +00:00
|
|
|
):
|
2020-01-17 11:03:31 +00:00
|
|
|
r"""
|
|
|
|
|
|
|
|
Customize every aspect of training via flags
|
|
|
|
|
|
|
|
Args:
|
2020-02-23 02:23:30 +00:00
|
|
|
logger: Logger for experiment tracking.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-02-01 20:47:58 +00:00
|
|
|
from pytorch_lightning.loggers import TensorBoardLogger
|
2020-01-17 11:03:31 +00:00
|
|
|
|
|
|
|
# default logger used by trainer
|
|
|
|
logger = TensorBoardLogger(
|
|
|
|
save_dir=os.getcwd(),
|
|
|
|
version=self.slurm_job_id,
|
|
|
|
name='lightning_logs'
|
|
|
|
)
|
|
|
|
|
|
|
|
Trainer(logger=logger)
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
checkpoint_callback: Callback for checkpointing.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
|
|
|
|
|
|
# default used by the Trainer
|
|
|
|
checkpoint_callback = ModelCheckpoint(
|
|
|
|
filepath=os.getcwd(),
|
|
|
|
save_best_only=True,
|
|
|
|
verbose=True,
|
|
|
|
monitor='val_loss',
|
|
|
|
mode='min',
|
|
|
|
prefix=''
|
|
|
|
)
|
|
|
|
|
|
|
|
trainer = Trainer(checkpoint_callback=checkpoint_callback)
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
early_stop_callback: Callback for early stopping. If
|
2020-01-23 16:12:51 +00:00
|
|
|
set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
|
|
|
|
Will raise an error if ``'val_loss'`` is not found.
|
|
|
|
If set to ``False``, then early stopping will be disabled.
|
|
|
|
If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
|
|
|
|
If ``'val_loss'`` is not found will work as if early stopping is disabled.
|
|
|
|
Default: ``None``.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
from pytorch_lightning.callbacks import EarlyStopping
|
|
|
|
|
|
|
|
# default used by the Trainer
|
|
|
|
early_stop_callback = EarlyStopping(
|
|
|
|
monitor='val_loss',
|
|
|
|
patience=3,
|
2020-01-23 16:12:51 +00:00
|
|
|
strict=False,
|
|
|
|
verbose=False,
|
2020-01-17 11:03:31 +00:00
|
|
|
mode='min'
|
|
|
|
)
|
|
|
|
|
|
|
|
trainer = Trainer(early_stop_callback=early_stop_callback)
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
default_save_path: Default path for logs and weights when no logger/ckpt_callback passed
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(default_save_path=os.getcwd())
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
gradient_clip_val: 0 means don't clip.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(gradient_clip_val=0.0)
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
gradient_clip:
|
2020-02-17 21:01:20 +00:00
|
|
|
.. warning:: .. deprecated:: 0.5.0
|
2020-01-17 11:03:31 +00:00
|
|
|
Use `gradient_clip_val` instead. Will remove 0.8.0.
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
process_position: orders the tqdm bar when running multiple models on same machine.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(process_position=0)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
num_nodes: number of GPU nodes for distributed training.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(num_nodes=1)
|
|
|
|
|
|
|
|
# to train on 8 nodes
|
|
|
|
trainer = Trainer(num_nodes=8)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
nb_gpu_nodes:
|
2020-02-17 21:01:20 +00:00
|
|
|
.. warning:: .. deprecated:: 0.5.0
|
2020-01-17 11:03:31 +00:00
|
|
|
Use `num_nodes` instead. Will remove 0.8.0.
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
gpus: Which GPUs to train on.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer (ie: train on CPU)
|
|
|
|
trainer = Trainer(gpus=None)
|
|
|
|
|
|
|
|
# int: train on 2 gpus
|
|
|
|
trainer = Trainer(gpus=2)
|
|
|
|
|
|
|
|
# list: train on GPUs 1, 4 (by bus ordering)
|
|
|
|
trainer = Trainer(gpus=[1, 4])
|
|
|
|
trainer = Trainer(gpus='1, 4') # equivalent
|
|
|
|
|
|
|
|
# -1: train on all gpus
|
|
|
|
trainer = Trainer(gpus=-1)
|
|
|
|
trainer = Trainer(gpus='-1') # equivalent
|
|
|
|
|
|
|
|
# combine with num_nodes to train on multiple GPUs across nodes
|
|
|
|
trainer = Trainer(gpus=2, num_nodes=4) # uses 8 gpus in total
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
num_tpu_cores: How many TPU cores to train on (1 or 8).
|
2020-02-17 21:01:20 +00:00
|
|
|
A single TPU v2 or v3 has 8 cores. A TPU pod has
|
|
|
|
up to 2048 cores. A slice of a POD means you get as many cores
|
|
|
|
as you request.
|
|
|
|
|
|
|
|
You MUST use DistributedSampler with your dataloader for this
|
|
|
|
to work. Your effective batch size is batch_size * total tpu cores.
|
|
|
|
|
|
|
|
This parameter can be either 1 or 8.
|
|
|
|
|
|
|
|
Example::
|
|
|
|
|
|
|
|
# your_trainer_file.py
|
|
|
|
|
|
|
|
# default used by the Trainer (ie: train on CPU)
|
|
|
|
trainer = Trainer(num_tpu_cores=None)
|
|
|
|
|
|
|
|
# int: train on a single core
|
|
|
|
trainer = Trainer(num_tpu_cores=1)
|
|
|
|
|
|
|
|
# int: train on all 8 cores
|
|
|
|
trainer = Trainer(num_tpu_cores=8)
|
|
|
|
|
|
|
|
# for 8+ cores must submit via xla script with
|
|
|
|
# a max of 8 cores specified. The XLA script
|
|
|
|
# will duplicate script onto each TPU in the POD
|
|
|
|
trainer = Trainer(num_tpu_cores=8)
|
|
|
|
|
|
|
|
# -1: train on all available TPUs
|
|
|
|
trainer = Trainer(num_tpu_cores=-1)
|
|
|
|
|
|
|
|
To train on more than 8 cores (ie: a POD),
|
|
|
|
submit this script using the xla_dist script.
|
|
|
|
|
|
|
|
Example::
|
|
|
|
|
|
|
|
$ python -m torch_xla.distributed.xla_dist
|
|
|
|
--tpu=$TPU_POD_NAME
|
|
|
|
--conda-env=torch-xla-nightly
|
|
|
|
--env=XLA_USE_BF16=1
|
|
|
|
-- python your_trainer_file.py
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
log_gpu_memory: None, 'min_max', 'all'. Might slow performance
|
2020-01-17 11:03:31 +00:00
|
|
|
because it uses the output of nvidia-smi.
|
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(log_gpu_memory=None)
|
|
|
|
|
|
|
|
# log all the GPUs (on master node only)
|
|
|
|
trainer = Trainer(log_gpu_memory='all')
|
|
|
|
|
|
|
|
# log only the min and max memory on the master node
|
|
|
|
trainer = Trainer(log_gpu_memory='min_max')
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
show_progress_bar: If True, shows a tqdm progress bar.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(show_progress_bar=True)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
overfit_pct: uses this much data of all datasets.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(overfit_pct=0.0)
|
|
|
|
|
|
|
|
# use only 1% of the train, test, val datasets
|
|
|
|
trainer = Trainer(overfit_pct=0.01)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
track_grad_norm: -1 means no tracking. Otherwise tracks the gradient norm of that order (e.g. 2 for the 2-norm).
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(track_grad_norm=-1)
|
|
|
|
|
|
|
|
# track the 2-norm
|
|
|
|
trainer = Trainer(track_grad_norm=2)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
check_val_every_n_epoch: Check val every n train epochs.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(check_val_every_n_epoch=1)
|
|
|
|
|
|
|
|
# run val loop every 10 training epochs
|
|
|
|
trainer = Trainer(check_val_every_n_epoch=10)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(fast_dev_run=False)
|
|
|
|
|
|
|
|
# runs 1 train, val, test batch and program ends
|
|
|
|
trainer = Trainer(fast_dev_run=True)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer (no accumulation)
|
|
|
|
trainer = Trainer(accumulate_grad_batches=1)
|
|
|
|
|
|
|
|
# accumulate every 4 batches (effective batch size is batch*4)
|
|
|
|
trainer = Trainer(accumulate_grad_batches=4)
|
|
|
|
|
|
|
|
# no accumulation for epochs 1-4. accumulate 3 for epochs 5-10. accumulate 20 after that
|
|
|
|
trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
max_epochs: Stop training once this number of epochs is reached.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(max_epochs=1000)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
max_nb_epochs:
|
2020-02-17 21:01:20 +00:00
|
|
|
.. warning:: .. deprecated:: 0.5.0
|
2020-01-17 11:03:31 +00:00
|
|
|
Use `max_epochs` instead. Will remove 0.8.0.
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
min_epochs: Force training for at least this many epochs.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(min_epochs=1)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
min_nb_epochs:
|
2020-02-17 21:01:20 +00:00
|
|
|
.. warning:: .. deprecated:: 0.5.0
|
2020-01-17 11:03:31 +00:00
|
|
|
Use `min_epochs` instead. Will remove 0.8.0.
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
max_steps: Stop training after this number of steps. Disabled by default (None).
|
2020-02-18 16:23:22 +00:00
|
|
|
Training will stop as soon as either max_steps or max_epochs is reached (whichever comes first).
|
|
|
|
Example::
|
|
|
|
|
|
|
|
# Stop after 100 steps
|
|
|
|
trainer = Trainer(max_steps=100)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
min_steps: Force training for at least this number of steps. Disabled by default (None).
|
2020-02-18 16:23:22 +00:00
|
|
|
The Trainer will train the model until both min_steps and min_epochs are satisfied (whichever is reached later).
|
|
|
|
Example::
|
|
|
|
|
|
|
|
# Run at least for 100 steps (disable min_epochs)
|
|
|
|
trainer = Trainer(min_steps=100, min_epochs=0)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
train_percent_check: How much of training dataset to check.
|
2020-01-17 11:03:31 +00:00
|
|
|
Useful when debugging or testing something that happens at the end of an epoch.
|
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(train_percent_check=1.0)
|
|
|
|
|
|
|
|
# run through only 25% of the training set each epoch
|
|
|
|
trainer = Trainer(train_percent_check=0.25)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
val_percent_check: How much of validation dataset to check.
|
2020-01-17 11:03:31 +00:00
|
|
|
Useful when debugging or testing something that happens at the end of an epoch.
|
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(val_percent_check=1.0)
|
|
|
|
|
|
|
|
# run through only 25% of the validation set each epoch
|
|
|
|
trainer = Trainer(val_percent_check=0.25)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
test_percent_check: How much of test dataset to check.
|
2020-01-17 11:03:31 +00:00
|
|
|
Useful when debugging or testing something that happens at the end of an epoch.
|
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(test_percent_check=1.0)
|
|
|
|
|
|
|
|
# run through only 25% of the test set each epoch
|
|
|
|
trainer = Trainer(test_percent_check=0.25)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
val_check_interval: How often within one training epoch to check the validation set
|
2020-01-17 11:03:31 +00:00
|
|
|
If float, the fraction of a training epoch between checks. If int, check every n training batches.
|
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(val_check_interval=1.0)
|
|
|
|
|
|
|
|
# check validation set 4 times during a training epoch
|
|
|
|
trainer = Trainer(val_check_interval=0.25)
|
|
|
|
|
|
|
|
# check validation set every 1000 training batches
|
|
|
|
# use this when using an IterableDataset and your dataset has no length
|
|
|
|
# (ie: production cases with streaming data)
|
|
|
|
trainer = Trainer(val_check_interval=1000)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
log_save_interval: Writes logs to disk this often
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(log_save_interval=100)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
row_log_interval: How often to add logging rows (does not write to disk)
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(row_log_interval=10)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
add_row_log_interval:
|
2020-02-17 21:01:20 +00:00
|
|
|
.. warning:: .. deprecated:: 0.5.0
|
2020-01-17 11:03:31 +00:00
|
|
|
Use `row_log_interval` instead. Will remove 0.8.0.
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
distributed_backend: The distributed backend to use.
|
2020-01-17 11:03:31 +00:00
|
|
|
Options: 'dp', 'ddp', 'ddp2'.
|
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(distributed_backend=None)
|
|
|
|
|
|
|
|
# dp = DataParallel (split a batch onto k gpus on same machine).
|
|
|
|
trainer = Trainer(gpus=2, distributed_backend='dp')
|
|
|
|
|
|
|
|
# ddp = DistributedDataParallel
|
|
|
|
# Each gpu trains by itself on a subset of the data.
|
|
|
|
# Gradients sync across all gpus and all machines.
|
|
|
|
trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')
|
|
|
|
|
|
|
|
# ddp2 = DistributedDataParallel + dp
|
|
|
|
# behaves like dp on every node
|
|
|
|
# syncs gradients across nodes like ddp
|
|
|
|
# useful for things like increasing the number of negative samples
|
|
|
|
trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
use_amp:
|
2020-02-17 21:01:20 +00:00
|
|
|
.. warning:: .. deprecated:: 0.6.1
|
|
|
|
Use `precision` instead. Will remove 0.8.0.
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
precision: Full precision (32), half precision (16).
|
2020-02-17 21:01:20 +00:00
|
|
|
Can be used on CPU, GPU or TPUs.
|
|
|
|
|
|
|
|
If used on TPU will use torch.bfloat16 but tensor printing
|
|
|
|
will still show torch.float32.
|
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
2020-02-17 21:01:20 +00:00
|
|
|
trainer = Trainer(precision=32)
|
|
|
|
|
|
|
|
# 16-bit precision
|
|
|
|
trainer = Trainer(precision=16)
|
|
|
|
|
|
|
|
# one day
|
|
|
|
trainer = Trainer(precision=8|4|2)
|
2020-01-17 11:03:31 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
print_nan_grads: Prints gradients with nan values
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(print_nan_grads=False)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
weights_summary: Prints a summary of the weights when training begins.
|
2020-01-17 11:03:31 +00:00
|
|
|
Options: 'full', 'top', None.
|
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer (ie: print all weights)
|
|
|
|
trainer = Trainer(weights_summary='full')
|
|
|
|
|
|
|
|
# print only the top level modules
|
|
|
|
trainer = Trainer(weights_summary='top')
|
|
|
|
|
|
|
|
# don't print a summary
|
|
|
|
trainer = Trainer(weights_summary=None)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
weights_save_path: Where to save weights if specified.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(weights_save_path=os.getcwd())
|
|
|
|
|
|
|
|
# save to your custom path
|
|
|
|
trainer = Trainer(weights_save_path='my/path')
|
|
|
|
|
|
|
|
# if checkpoint callback used, then overrides the weights path
|
|
|
|
# **NOTE: this saves weights to some/path NOT my/path
|
|
|
|
checkpoint_callback = ModelCheckpoint(filepath='some/path')
|
|
|
|
trainer = Trainer(
|
|
|
|
checkpoint_callback=checkpoint_callback,
|
|
|
|
weights_save_path='my/path'
|
|
|
|
)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
amp_level: The optimization level to use (O1, O2, etc...).
|
2020-01-17 11:03:31 +00:00
|
|
|
Check nvidia docs for level (https://nvidia.github.io/apex/amp.html#opt-levels)
|
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(amp_level='O1')
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
|
2020-01-17 11:03:31 +00:00
|
|
|
This catches any bugs in your validation without having to wait for the first validation check.
|
|
|
|
The Trainer uses 5 steps by default. Turn it off or modify it here.
|
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(num_sanity_val_steps=5)
|
|
|
|
|
|
|
|
# turn it off
|
|
|
|
trainer = Trainer(num_sanity_val_steps=0)
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
nb_sanity_val_steps:
|
2020-02-17 21:01:20 +00:00
|
|
|
.. warning:: .. deprecated:: 0.5.0
|
2020-01-17 11:03:31 +00:00
|
|
|
Use `num_sanity_val_steps` instead. Will remove 0.8.0.
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
truncated_bptt_steps: Truncated back prop performs backprop every k steps of
|
2020-01-17 11:03:31 +00:00
|
|
|
a much longer sequence. If this is enabled, your batches will automatically get truncated
|
|
|
|
and the trainer will apply Truncated Backprop to it. Make sure your batches have a sequence
|
|
|
|
dimension. (`Williams et al. "An efficient gradient-based algorithm for on-line training of
|
|
|
|
recurrent network trajectories."
|
|
|
|
<http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.56.7941&rep=rep1&type=pdf>`_)
|
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer (ie: disabled)
|
|
|
|
trainer = Trainer(truncated_bptt_steps=None)
|
|
|
|
|
|
|
|
# backprop every 5 steps in a batch
|
|
|
|
trainer = Trainer(truncated_bptt_steps=5)
|
|
|
|
|
2020-02-01 20:51:42 +00:00
|
|
|
|
2020-02-11 04:55:22 +00:00
|
|
|
Lightning takes care to split your batch along the time-dimension.
|
|
|
|
|
|
|
|
.. note:: If you need to modify how the batch is split,
|
|
|
|
override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
|
|
|
|
|
|
|
|
.. note:: Using this feature requires updating your LightningModule's
|
|
|
|
:meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
|
2020-02-01 20:51:42 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
resume_from_checkpoint: To resume training from a specific checkpoint, pass in the path here.
|
2020-01-17 11:03:31 +00:00
|
|
|
Example::
|
2020-01-26 13:38:01 +00:00
|
|
|
|
2020-01-17 11:03:31 +00:00
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(resume_from_checkpoint=None)
|
|
|
|
|
|
|
|
# resume from a specific checkpoint
|
|
|
|
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
|
2020-02-23 02:23:30 +00:00
|
|
|
profiler: To profile individual steps during training and assist in
|
2020-02-07 03:01:21 +00:00
|
|
|
identifying bottlenecks.
|
|
|
|
Example::
|
|
|
|
|
|
|
|
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
|
|
|
|
|
|
|
|
# default used by the Trainer
|
|
|
|
trainer = Trainer(profiler=None)
|
|
|
|
|
|
|
|
# to profile standard training events
|
|
|
|
trainer = Trainer(profiler=True)
|
|
|
|
|
|
|
|
# equivalent to profiler=True
|
|
|
|
profiler = Profiler()
|
|
|
|
trainer = Trainer(profiler=profiler)
|
|
|
|
|
|
|
|
# advanced profiler for function-level stats
|
|
|
|
profiler = AdvancedProfiler()
|
|
|
|
trainer = Trainer(profiler=profiler)
|
2020-01-26 13:38:01 +00:00
|
|
|
|
|
|
|
.. warning:: Following arguments become deprecated and they will be removed in v0.8.0:
|
|
|
|
|
|
|
|
- `nb_sanity_val_steps`
|
|
|
|
|
2019-07-18 16:04:19 +00:00
|
|
|
"""
|
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
# Transfer params
|
2020-01-14 19:40:41 +00:00
|
|
|
# Backward compatibility
|
|
|
|
if nb_gpu_nodes is not None:
|
2019-12-04 11:57:10 +00:00
|
|
|
warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
|
2020-02-11 12:41:15 +00:00
|
|
|
" and this method will be removed in v0.8.0", DeprecationWarning)
|
2019-12-04 11:57:10 +00:00
|
|
|
if not num_nodes: # in case you did not set the proper value
|
|
|
|
num_nodes = nb_gpu_nodes
|
|
|
|
self.num_gpu_nodes = num_nodes
|
2020-01-14 19:40:41 +00:00
|
|
|
|
2019-09-04 14:43:46 +00:00
|
|
|
self.log_gpu_memory = log_gpu_memory
|
2020-01-14 19:40:41 +00:00
|
|
|
|
|
|
|
# Backward compatibility
|
|
|
|
if gradient_clip is not None:
|
2019-11-28 17:48:55 +00:00
|
|
|
warnings.warn("`gradient_clip` has renamed to `gradient_clip_val` since v0.5.0"
|
2020-02-11 12:41:15 +00:00
|
|
|
" and this method will be removed in v0.8.0", DeprecationWarning)
|
2019-12-04 11:57:10 +00:00
|
|
|
if not gradient_clip_val: # in case you did not set the proper value
|
|
|
|
gradient_clip_val = gradient_clip
|
2019-09-25 23:05:06 +00:00
|
|
|
self.gradient_clip_val = gradient_clip_val
|
2020-01-14 19:40:41 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
self.check_val_every_n_epoch = check_val_every_n_epoch
|
|
|
|
self.track_grad_norm = track_grad_norm
|
2019-11-30 19:50:50 +00:00
|
|
|
self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
|
2020-02-17 21:01:20 +00:00
|
|
|
|
|
|
|
# tpu config
|
|
|
|
self.on_tpu = num_tpu_cores is not None
|
|
|
|
self.num_tpu_cores = num_tpu_cores
|
|
|
|
assert num_tpu_cores in [1, 8, None], 'num_tpu_cores can only be 1 or 8'
|
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
self.process_position = process_position
|
2019-10-08 19:30:06 +00:00
|
|
|
self.weights_summary = weights_summary
|
2020-01-14 19:40:41 +00:00
|
|
|
|
|
|
|
# Backward compatibility
|
|
|
|
if max_nb_epochs is not None:
|
2019-12-07 13:50:21 +00:00
|
|
|
warnings.warn("`max_nb_epochs` has renamed to `max_epochs` since v0.5.0"
|
2020-02-11 12:41:15 +00:00
|
|
|
" and this method will be removed in v0.8.0", DeprecationWarning)
|
2019-12-07 13:50:21 +00:00
|
|
|
if not max_epochs: # in case you did not set the proper value
|
|
|
|
max_epochs = max_nb_epochs
|
|
|
|
self.max_epochs = max_epochs
|
2020-01-14 19:40:41 +00:00
|
|
|
|
|
|
|
# Backward compatibility
|
|
|
|
if min_nb_epochs is not None:
|
2019-12-07 13:50:21 +00:00
|
|
|
warnings.warn("`min_nb_epochs` has renamed to `min_epochs` since v0.5.0"
|
2020-02-11 12:41:15 +00:00
|
|
|
" and this method will be removed in v0.8.0", DeprecationWarning)
|
2019-12-07 13:50:21 +00:00
|
|
|
if not min_epochs: # in case you did not set the proper value
|
|
|
|
min_epochs = min_nb_epochs
|
|
|
|
self.min_epochs = min_epochs
|
2020-01-14 19:40:41 +00:00
|
|
|
|
2020-02-18 16:23:22 +00:00
|
|
|
self.max_steps = max_steps
|
|
|
|
self.min_steps = min_steps
|
|
|
|
|
2020-01-14 19:40:41 +00:00
|
|
|
# Backward compatibility
|
|
|
|
if nb_sanity_val_steps is not None:
|
2019-12-04 11:57:10 +00:00
|
|
|
warnings.warn("`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
|
2020-02-11 12:41:15 +00:00
|
|
|
" and this method will be removed in v0.8.0", DeprecationWarning)
|
2019-12-04 11:57:10 +00:00
|
|
|
if not num_sanity_val_steps: # in case you did not set the proper value
|
|
|
|
num_sanity_val_steps = nb_sanity_val_steps
|
2020-01-14 19:40:41 +00:00
|
|
|
|
2019-12-04 11:57:10 +00:00
|
|
|
self.num_sanity_val_steps = num_sanity_val_steps
|
2019-07-01 22:38:07 +00:00
|
|
|
self.print_nan_grads = print_nan_grads
|
2019-10-31 10:45:28 +00:00
|
|
|
self.truncated_bptt_steps = truncated_bptt_steps
|
2019-11-30 21:48:38 +00:00
|
|
|
self.resume_from_checkpoint = resume_from_checkpoint
|
2019-10-24 10:43:35 +00:00
|
|
|
self.shown_warnings = set()
|
2019-07-08 13:42:13 +00:00
|
|
|
|
2019-10-09 14:23:08 +00:00
|
|
|
self.fast_dev_run = fast_dev_run
|
|
|
|
if self.fast_dev_run:
|
2019-12-04 11:57:10 +00:00
|
|
|
self.num_sanity_val_steps = 1
|
2019-12-07 13:50:21 +00:00
|
|
|
self.max_epochs = 1
|
2019-10-09 14:23:08 +00:00
|
|
|
m = '''
|
|
|
|
Running in fast_dev_run mode: will run a full train,
|
|
|
|
val loop using a single batch
|
|
|
|
'''
|
2020-02-01 20:47:58 +00:00
|
|
|
log.info(m)
|
2019-10-09 14:23:08 +00:00
|
|
|
|
2019-10-04 23:48:57 +00:00
|
|
|
# set default save path if user didn't provide one
|
|
|
|
self.default_save_path = default_save_path
|
|
|
|
if self.default_save_path is None:
|
|
|
|
self.default_save_path = os.getcwd()
|
|
|
|
|
2019-07-24 14:42:01 +00:00
|
|
|
# training bookeeping
|
2019-12-04 11:57:10 +00:00
|
|
|
self.total_batch_idx = 0
|
2019-07-24 14:42:01 +00:00
|
|
|
self.running_loss = []
|
|
|
|
self.avg_loss = 0
|
2019-12-04 11:57:10 +00:00
|
|
|
self.batch_idx = 0
|
2019-07-24 14:42:01 +00:00
|
|
|
self.tqdm_metrics = {}
|
2019-10-08 20:21:00 +00:00
|
|
|
self.callback_metrics = {}
|
2019-12-04 11:57:10 +00:00
|
|
|
self.num_val_batches = 0
|
|
|
|
self.num_training_batches = 0
|
|
|
|
self.num_test_batches = 0
|
2019-10-04 19:35:02 +00:00
|
|
|
self.get_train_dataloader = None
|
|
|
|
self.get_test_dataloaders = None
|
|
|
|
self.get_val_dataloaders = None
|
2019-10-22 02:10:00 +00:00
|
|
|
self.is_iterable_train_dataloader = False
|
2019-09-06 04:29:38 +00:00
|
|
|
|
|
|
|
# training state
|
|
|
|
self.model = None
|
|
|
|
self.testing = False
|
2020-01-14 03:31:15 +00:00
|
|
|
self.disable_validation = False
|
2019-09-06 04:29:38 +00:00
|
|
|
self.lr_schedulers = []
|
|
|
|
self.optimizers = None
|
|
|
|
self.global_step = 0
|
|
|
|
self.current_epoch = 0
|
|
|
|
self.total_batches = 0
|
|
|
|
|
2020-01-26 14:42:57 +00:00
|
|
|
# configure logger
|
|
|
|
self.configure_logger(logger)
|
|
|
|
|
2020-02-07 03:01:21 +00:00
|
|
|
# configure profiler
|
|
|
|
if profiler is True:
|
|
|
|
profiler = Profiler()
|
|
|
|
self.profiler = profiler or PassThroughProfiler()
|
|
|
|
|
2019-09-06 21:01:03 +00:00
|
|
|
# configure early stop callback
|
2019-10-04 23:48:57 +00:00
|
|
|
# creates a default one if none passed in
|
2020-01-26 14:42:57 +00:00
|
|
|
self.configure_early_stopping(early_stop_callback)
|
2019-10-04 23:48:57 +00:00
|
|
|
|
2019-12-03 12:59:41 +00:00
|
|
|
self.reduce_lr_on_plateau_scheduler = None
|
|
|
|
|
2019-10-04 23:48:57 +00:00
|
|
|
# configure checkpoint callback
|
|
|
|
self.checkpoint_callback = checkpoint_callback
|
2019-10-09 21:46:27 +00:00
|
|
|
self.weights_save_path = weights_save_path
|
2019-09-06 04:29:38 +00:00
|
|
|
|
|
|
|
# accumulated grads
|
2019-10-22 01:16:51 +00:00
|
|
|
self.configure_accumulated_gradients(accumulate_grad_batches)
|
2019-09-06 04:29:38 +00:00
|
|
|
|
2019-09-08 19:36:58 +00:00
|
|
|
# allow int, string and gpu list
|
2019-10-23 09:05:09 +00:00
|
|
|
self.data_parallel_device_ids = parse_gpu_ids(gpus)
|
|
|
|
self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
|
2019-09-06 04:29:38 +00:00
|
|
|
|
2020-02-17 21:01:20 +00:00
|
|
|
# tpu state flags
|
|
|
|
self.use_tpu = False
|
|
|
|
self.tpu_local_core_rank = None
|
|
|
|
self.tpu_global_core_rank = None
|
|
|
|
|
2019-09-06 04:29:38 +00:00
|
|
|
# distributed backend choice
|
|
|
|
self.use_ddp = False
|
2019-10-04 19:07:54 +00:00
|
|
|
self.use_ddp2 = False
|
2019-09-06 04:29:38 +00:00
|
|
|
self.use_dp = False
|
|
|
|
self.single_gpu = False
|
2019-10-04 19:07:54 +00:00
|
|
|
self.distributed_backend = distributed_backend
|
2019-12-04 11:57:10 +00:00
|
|
|
self.set_distributed_mode(distributed_backend, num_nodes)
|
2019-09-06 04:29:38 +00:00
|
|
|
|
2020-02-17 21:01:20 +00:00
|
|
|
# override dist backend when using tpus
|
|
|
|
if self.on_tpu:
|
|
|
|
self.init_tpu()
|
|
|
|
self.current_tpu_idx = None
|
|
|
|
|
2019-09-06 04:29:38 +00:00
|
|
|
# init flags for SLURM+ddp to work
|
|
|
|
self.proc_rank = 0
|
|
|
|
self.world_size = 1
|
|
|
|
self.node_rank = 0
|
2019-12-04 11:57:10 +00:00
|
|
|
self.configure_slurm_ddp(num_nodes)
|
2019-09-08 19:36:58 +00:00
|
|
|
|
|
|
|
# nvidia setup
|
2019-10-22 01:16:51 +00:00
|
|
|
self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
|
2019-09-06 04:29:38 +00:00
|
|
|
|
|
|
|
# can't init progress bar here because starting a new process
|
2019-09-25 23:05:06 +00:00
|
|
|
# means the progress_bar won't survive pickling
|
2019-09-06 04:29:38 +00:00
|
|
|
self.show_progress_bar = show_progress_bar
|
|
|
|
|
|
|
|
# logging
|
|
|
|
self.log_save_interval = log_save_interval
|
|
|
|
self.val_check_interval = val_check_interval
|
2020-01-14 19:40:41 +00:00
|
|
|
|
|
|
|
# backward compatibility
|
2019-12-04 11:59:19 +00:00
|
|
|
if add_row_log_interval is not None:
|
|
|
|
warnings.warn("`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0"
|
2020-02-11 12:41:15 +00:00
|
|
|
" and this method will be removed in v0.8.0", DeprecationWarning)
|
2019-12-04 11:59:19 +00:00
|
|
|
if not row_log_interval: # in case you did not set the proper value
|
|
|
|
row_log_interval = add_row_log_interval
|
2019-09-25 23:05:06 +00:00
|
|
|
self.row_log_interval = row_log_interval
|
2019-09-06 04:29:38 +00:00
|
|
|
|
|
|
|
# how much of the data to use
|
2019-10-22 01:16:51 +00:00
|
|
|
self.determine_data_use_amount(train_percent_check, val_percent_check,
|
|
|
|
test_percent_check, overfit_pct)
|
2019-09-06 04:29:38 +00:00
|
|
|
|
|
|
|
# 16 bit mixed precision training using apex
|
|
|
|
self.amp_level = amp_level
|
2020-02-17 21:01:20 +00:00
|
|
|
self.precision = precision
|
|
|
|
if self.precision == 16:
|
|
|
|
use_amp = True
|
2019-10-22 01:16:51 +00:00
|
|
|
self.init_amp(use_amp)
|
2019-09-06 04:29:38 +00:00
|
|
|
|
2019-10-05 18:45:37 +00:00
|
|
|
@property
def slurm_job_id(self) -> Optional[int]:
    """SLURM job id of the current process, if any.

    Reads the ``SLURM_JOB_ID`` environment variable and converts it to ``int``.

    Returns:
        The job id as an ``int``, or ``None`` when the variable is unset or
        does not contain an integer.
    """
    try:
        return int(os.environ['SLURM_JOB_ID'])
    except (KeyError, ValueError):
        # KeyError: not running under SLURM; ValueError: non-numeric value
        return None
|
|
|
|
|
2019-09-06 04:29:38 +00:00
|
|
|
def __parse_gpu_ids(self, gpus):
    """Parse the user-supplied ``gpus`` argument.

    Args:
        gpus: ``None``, an ``int``, a list of GPU ids, or a string. The string
            ``'-1'`` (surrounding whitespace allowed) selects every visible CUDA
            device; any other string is read as comma-separated ids, e.g. ``'0, 2'``.

    Returns:
        ``None`` or the given int/list unchanged; for strings, a list of ints.

    Raises:
        ValueError: if ``gpus`` has any other type, or a string entry is not
            an integer.
    """
    # None, lists and ints pass through untouched
    if gpus is None or isinstance(gpus, (list, int)):
        return gpus

    if isinstance(gpus, str):
        # strip first so that e.g. ' -1' means "all devices" rather than [-1]
        if gpus.strip() == '-1':
            return list(range(torch.cuda.device_count()))
        return [int(x.strip()) for x in gpus.split(',')]

    raise ValueError('`gpus` has to be a string, int or list of ints')
|
|
|
|
|
2019-09-11 11:52:36 +00:00
|
|
|
def __set_root_gpu(self, gpus):
    """Choose the device id that serves as the root GPU.

    Returns ``None`` when ``gpus`` is ``None``, the first entry when ``gpus``
    is a list, and ``0`` for any other value (such as an int GPU count).
    """
    if gpus is None:
        return None
    # an explicit id list names the root first; otherwise default to device 0
    return gpus[0] if isinstance(gpus, list) else 0
|
|
|
|
|
2019-09-08 19:36:58 +00:00
|
|
|
@property
def num_gpus(self) -> int:
    """Count of ids in ``self.data_parallel_device_ids`` (0 when it is ``None``)."""
    device_ids = self.data_parallel_device_ids
    return 0 if device_ids is None else len(device_ids)
|
2019-09-08 19:36:58 +00:00
|
|
|
|
2019-07-18 15:08:48 +00:00
|
|
|
@property
def data_parallel(self) -> bool:
    """Whether a data-parallel backend (DP, DDP or DDP2) is enabled.

    When it is, the trainer reaches the user's LightningModule through
    ``self.model.module``.
    """
    return (self.use_dp
            or self.use_ddp
            or self.use_ddp2)
|
2019-07-18 15:08:48 +00:00
|
|
|
|
2019-10-22 01:16:51 +00:00
|
|
|
@property
def training_tqdm_dict(self) -> dict:
    """Read-only view of the metrics shown in the training progress bar.

    Merges the LightningModule's ``get_tqdm_dict()`` with the trainer's own
    ``tqdm_metrics``. When both define the same key, the trainer's value wins.

    Returns:
        A new dict; mutating it does not affect either source.
    """
    # unwrap DP/DDP wrappers to reach the user's LightningModule
    ref_model = self.model.module if self.data_parallel else self.model
    # dict display instead of dict(**a, **b): shared keys must override,
    # not raise "got multiple values for keyword argument"
    return {**ref_model.get_tqdm_dict(), **self.tqdm_metrics}
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2019-10-21 06:16:55 +00:00
|
|
|
@property
def tng_tqdm_dic(self):
    """Deprecated alias that forwards to ``training_tqdm_dict``.

    :return: dictionary

    .. warning:: .. deprecated:: 0.5.0
        Use `training_tqdm_dict` instead. Will remove 0.8.0.
    """
    message = ("`tng_tqdm_dic` has renamed to `training_tqdm_dict` since v0.5.0"
               " and this method will be removed in v0.8.0")
    warnings.warn(message, DeprecationWarning)
    return self.training_tqdm_dict
|
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
# -----------------------------
|
|
|
|
# MODEL TRAINING
|
|
|
|
# -----------------------------
|
2020-02-23 02:23:30 +00:00
|
|
|
def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloader: Optional[DataLoader] = None,
        test_dataloader: Optional[DataLoader] = None
):
    r"""
    Runs the full optimization routine.

    Args:
        model: Model to fit.

        train_dataloader: A Pytorch DataLoader with training samples. If the model
            already defines its own ``train_dataloader`` method, this argument is
            ignored (with a warning).

        val_dataloader: Either a single Pytorch DataLoader or a list of them,
            specifying validation samples. If the model already defines its own
            ``val_dataloader`` method, this argument is ignored (with a warning).

        test_dataloader: Either a single Pytorch DataLoader or a list of them,
            specifying test samples. If the model already defines its own
            ``test_dataloader`` method, this argument is ignored (with a warning).

    Returns:
        ``1`` once the selected training routine has returned (used for testing
        or when the caller needs to know that training succeeded).

    Example::

        # Option 1,
        # Define the train_dataloader(), test_dataloader() and val_dataloader() fxs
        # in the lightningModule
        # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
        trainer = Trainer()
        model = LightningModule()
        trainer.fit(model)

        # Option 2
        # in production cases we might want to pass different datasets to the same model
        # Recommended for PRODUCTION SYSTEMS
        train, val, test = DataLoader(...), DataLoader(...), DataLoader(...)
        trainer = Trainer()
        model = LightningModule()
        trainer.fit(model, train_dataloader=train,
                    val_dataloader=val, test_dataloader=test)

        # Option 1 & 2 can be mixed, for example the training set can be
        # defined as part of the model, and validation/test can then be
        # fed to .fit()

    """
    # Update the dataloader attributes of the model with the ones supplied here,
    # if they are not already defined in model
    _set_dataloader(model, train_dataloader, 'train_dataloader')
    _set_dataloader(model, val_dataloader, 'val_dataloader')
    _set_dataloader(model, test_dataloader, 'test_dataloader')

    # when using multi-node or DDP within a node start each module in a separate process
    if self.use_ddp2:
        task = int(os.environ['SLURM_LOCALID'])
        self.ddp_train(task, model)

    elif self.use_ddp:
        if self.is_slurm_managing_tasks:
            task = int(os.environ['SLURM_LOCALID'])
            self.ddp_train(task, model)
        else:
            mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))

    # 1 gpu or dp option triggers training using DP module
    # easier to avoid NCCL issues
    elif self.use_dp:
        self.dp_train(model)

    elif self.single_gpu:
        self.single_gpu_train(model)

    elif self.use_tpu:
        log.info(f'training on {self.num_tpu_cores} TPU cores')

        # COLAB_GPU is an env var available by default in Colab environments.
        start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'
        xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)

    # ON CPU
    else:
        # run through amp wrapper
        if self.use_amp:
            raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())

        self.run_pretrain_routine(model)

    # return 1 when finished
    # used for testing or when we need to know that training succeeded
    return 1
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
def init_optimizers(
        self,
        optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]]
) -> Tuple[List, List]:
    """Normalise the value returned by ``configure_optimizers``.

    Accepted forms:
        * a single ``Optimizer``;
        * a 2-element list/tuple ``[optimizers, lr_schedulers]`` whose first
          element is a list;
        * a list or tuple of optimizers.

    For the two-list form, ``configure_schedulers`` separates any
    ``ReduceLROnPlateau`` scheduler, which is stored on
    ``self.reduce_lr_on_plateau_scheduler``.

    Returns:
        ``(optimizers, lr_schedulers)``.

    Raises:
        TypeError: if ``optimizers`` is none of the accepted forms (previously
            such values fell through and returned ``None`` implicitly).
    """
    # single optimizer
    if isinstance(optimizers, Optimizer):
        return [optimizers], []

    if not isinstance(optimizers, (list, tuple)):
        raise TypeError(
            '`configure_optimizers` must return an Optimizer, a list/tuple of '
            'Optimizers or a pair of lists ([optimizers], [lr_schedulers]); '
            f'got {type(optimizers).__name__}')

    # two lists: [optimizers], [lr_schedulers]
    if len(optimizers) == 2 and isinstance(optimizers[0], list):
        optimizers, lr_schedulers = optimizers
        lr_schedulers, self.reduce_lr_on_plateau_scheduler = self.configure_schedulers(lr_schedulers)
        return optimizers, lr_schedulers

    # single list or tuple of optimizers
    return optimizers, []
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
def configure_schedulers(self, schedulers: list):
    """Separate a ``ReduceLROnPlateau`` scheduler from the other schedulers.

    Only the first ``ReduceLROnPlateau`` instance found is separated out.
    The input sequence is never modified.

    Args:
        schedulers: list (or tuple) of learning-rate schedulers.

    Returns:
        ``(remaining_schedulers, plateau_scheduler)``, where the second item is
        ``None`` if no ``ReduceLROnPlateau`` is present (in which case the input
        is returned as-is).
    """
    for i, scheduler in enumerate(schedulers):
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # slice instead of pop(): do not mutate the caller's list, and
            # also works for tuples
            return schedulers[:i] + schedulers[i + 1:], scheduler
    return schedulers, None
|
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
def run_pretrain_routine(self, model: LightningModule):
    """Sanity check a few things before starting actual training.

    Attaches the trainer (and the logger, if any) to the model, synchronises
    distributed / TPU processes, configures checkpointing and SLURM signal
    handlers, fetches the dataloaders and restores weights. It then either
    runs only the test loop (when ``self.testing``) or runs a short validation
    sanity check followed by the core training loop.

    Args:
        model: The model to run sanity test on. When a data-parallel backend is
            active this is the wrapper, and the LightningModule is ``model.module``.

    Raises:
        MisconfigurationException: if ``weights_summary`` is anything other
            than ``None``, ``'full'`` or ``'top'``.
    """
    ref_model = model
    if self.data_parallel:
        ref_model = model.module

    # give model convenience properties
    ref_model.trainer = self

    # set local properties on the model
    self.copy_trainer_model_properties(ref_model)

    # link up experiment object
    if self.logger is not None:
        ref_model.logger = self.logger

        # save exp to get started
        if hasattr(ref_model, "hparams"):
            self.logger.log_hyperparams(ref_model.hparams)

        self.logger.save()

    if self.use_ddp or self.use_ddp2:
        dist.barrier()

    # wait for all TPU processes to catch up
    if self.on_tpu and XLA_AVAILABLE:
        # xla_model.rendezvous() takes a mandatory tag naming the sync point;
        # calling it with no arguments raises TypeError
        torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")

    # set up checkpoint callback
    self.configure_checkpoint_callback()

    # register auto-resubmit when on SLURM
    self.register_slurm_signal_handlers()

    # transfer data loaders from model
    self.get_dataloaders(ref_model)

    # print model summary
    if self.proc_rank == 0 and self.weights_summary is not None:
        if self.weights_summary in ['full', 'top']:
            ref_model.summarize(mode=self.weights_summary)
        else:
            m = "weights_summary can be None, 'full' or 'top'"
            raise MisconfigurationException(m)

    # track model now.
    # if cluster resets state, the model will update with the saved weights
    self.model = model

    # restore training and model before hpc call
    self.restore_weights(model)

    # when testing requested only run test and return
    if self.testing:
        self.run_evaluation(test=True)
        return

    # check if we should run validation during training
    self.disable_validation = ((self.num_val_batches == 0 or
                                not self.is_overriden('validation_step')) and
                               not self.fast_dev_run)

    # run tiny validation (if validation defined)
    # to make sure program won't crash during val
    ref_model.on_sanity_check_start()
    ref_model.on_train_start()
    if not self.disable_validation and self.num_sanity_val_steps > 0:
        # init progress bars for validation sanity check
        pbar = tqdm(desc='Validation sanity check',
                    total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
                    leave=False, position=2 * self.process_position,
                    disable=not self.show_progress_bar, dynamic_ncols=True)
        self.main_progress_bar = pbar
        # dummy validation progress bar
        self.val_progress_bar = tqdm(disable=True)

        eval_results = self.evaluate(model, self.get_val_dataloaders(),
                                     self.num_sanity_val_steps, False)
        _, _, _, callback_metrics, _ = self.process_output(eval_results)

        # close progress bars
        self.main_progress_bar.close()
        self.val_progress_bar.close()

        if self.enable_early_stop:
            self.early_stop_callback.check_metrics(callback_metrics)

    # init progress bar
    pbar = tqdm(leave=True, position=2 * self.process_position,
                disable=not self.show_progress_bar, dynamic_ncols=True,
                file=sys.stdout)
    self.main_progress_bar = pbar

    # clear cache before training
    if self.on_gpu:
        torch.cuda.empty_cache()

    # CORE TRAINING LOOP
    self.train()
|
2019-10-05 17:35:20 +00:00
|
|
|
|
2020-02-23 02:23:30 +00:00
|
|
|
def test(self, model: Optional[LightningModule] = None):
    r"""

    Separates from fit to make sure you never run on your test set until you want to.

    Sets ``self.testing`` to ``True``. If ``model`` is given, ``self.fit(model)``
    is called; otherwise ``self.run_evaluation(test=True)`` is called directly
    on the trainer's current state.

    Args:
        model: The model to test. If omitted, the trainer's current model and
            dataloaders are used.

    Example::

        # Option 1
        # run test after fitting
        trainer = Trainer()
        model = LightningModule()

        trainer.fit(model)
        trainer.test()

        # Option 2
        # run test from a loaded model
        model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
        trainer = Trainer()
        trainer.test(model)
    """
    self.testing = True
    if model is not None:
        self.fit(model)
    else:
        self.run_evaluation(test=True)
|
2020-02-19 11:00:08 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _set_dataloader(model, dataloader, attribute):
    r'''
    Check dataloaders passed to .fit() method if they are pytorch DataLoader
    objects and whether or not we should overwrite the corresponding dataloader
    in the model

    Args:
        model (LightningModule): The model to check

        dataloader: If a pytorch dataloader (or a list of pytorch dataloaders)
            is passed, it will be incorporated into the model as model.attribute.
            If the model already defines its own ``attribute`` method, the passed
            value is ignored and a warning is issued. If it is neither None nor
            a dataloader (list), an error is raised.

        attribute (str): The attribute to save the dataloader under

    Raises:
        ValueError: if ``attribute`` still has the LightningModule default and
            ``dataloader`` is neither None, a DataLoader, nor a list of them.
    '''
    # Only replace the hook while the model still uses the LightningModule
    # default. A prefix match is used rather than a substring test, so a user
    # class such as ``MyLightningModule`` is not mistaken for the base class.
    hook_qualname = getattr(model, attribute).__qualname__
    if hook_qualname.startswith(LightningModule.__qualname__ + '.'):
        # Val and test should be list of dataloaders
        if attribute != 'train_dataloader' and not isinstance(dataloader, list):
            dataloader = [dataloader]

        # Check we are given valid dataloaders
        is_dataloader = isinstance(dataloader, torch.utils.data.DataLoader)
        is_dataloader_list = isinstance(dataloader, list) and all(
            isinstance(d, torch.utils.data.DataLoader) for d in dataloader)

        if is_dataloader or is_dataloader_list:
            # Overwrite abstract methods
            def dl():
                return dataloader
            dl.__name__ = attribute
            setattr(model, attribute, dl)

        elif dataloader is not None and dataloader != [None]:
            # f-string instead of %-formatting: '%r' % a_tuple would itself fail
            raise ValueError(f'`{attribute}` needs to be an instance of '
                             '`torch.utils.data.DataLoader` or a list of '
                             f'DataLoaders, instead got {dataloader!r}')

    # `is not None` rather than truthiness: bool(DataLoader) calls len(), which
    # raises for IterableDataset-backed loaders
    elif dataloader is not None:  # if default (None) is passed, do not warn the user
        warnings.warn(f'Model has predefined `{attribute}`,'
                      f' will skip `{attribute}={dataloader}` passed to fit method.')
|