lightning/pytorch_lightning/trainer/trainer.py

1232 lines
48 KiB
Python

import inspect
import os
from argparse import ArgumentParser, Namespace
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
import torch
import torch.distributed as torch_distrib
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_9
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import (
TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device, pick_multiple_gpus)
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info
# Optional backends. Each ``*_AVAILABLE`` flag records whether the matching
# import succeeded, so later code can test the flag instead of re-importing.

# NVIDIA apex (mixed precision)
try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False

# torch_xla (TPU support)
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False

# horovod
try:
    import horovod.torch as hvd
    HOROVOD_AVAILABLE = True
except ImportError:
    HOROVOD_AVAILABLE = False
class Trainer(
TrainerIOMixin,
TrainerOptimizersMixin,
TrainerAMPMixin,
TrainerDPMixin,
TrainerDDPMixin,
TrainerLoggingMixin,
TrainerModelHooksMixin,
TrainerTrainingTricksMixin,
TrainerDataLoadingMixin,
TrainerEvaluationLoopMixin,
TrainerTrainLoopMixin,
TrainerCallbackConfigMixin,
TrainerCallbackHookMixin,
TrainerLRFinderMixin,
TrainerDeprecatedAPITillVer0_9,
):
DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict', 'num_tpu_cores')
def __init__(
    self,
    logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
    checkpoint_callback: Union[ModelCheckpoint, bool] = True,
    early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
    callbacks: Optional[List[Callback]] = None,
    default_root_dir: Optional[str] = None,
    gradient_clip_val: float = 0,
    process_position: int = 0,
    num_nodes: int = 1,
    num_processes: int = 1,
    gpus: Optional[Union[List[int], str, int]] = None,
    auto_select_gpus: bool = False,
    tpu_cores: Optional[Union[List[int], int]] = None,
    log_gpu_memory: Optional[str] = None,
    progress_bar_refresh_rate: int = 1,
    overfit_batches: Union[int, float] = 0.0,
    track_grad_norm: Union[int, float, str] = -1,
    check_val_every_n_epoch: int = 1,
    fast_dev_run: bool = False,
    accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
    max_epochs: int = 1000,
    min_epochs: int = 1,
    max_steps: Optional[int] = None,
    min_steps: Optional[int] = None,
    train_percent_check: float = 1.0,
    limit_val_batches: Union[int, float] = 1.0,
    limit_test_batches: Union[int, float] = 1.0,
    val_check_interval: float = 1.0,
    log_save_interval: int = 100,
    row_log_interval: int = 50,
    distributed_backend: Optional[str] = None,
    precision: int = 32,
    print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
    weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
    weights_save_path: Optional[str] = None,
    num_sanity_val_steps: int = 2,
    truncated_bptt_steps: Optional[int] = None,
    resume_from_checkpoint: Optional[str] = None,
    profiler: Optional[Union[BaseProfiler, bool]] = None,
    benchmark: bool = False,
    deterministic: bool = False,
    reload_dataloaders_every_epoch: bool = False,
    auto_lr_find: Union[bool, str] = False,
    replace_sampler_ddp: bool = True,
    terminate_on_nan: bool = False,
    auto_scale_batch_size: Union[str, bool] = False,
    prepare_data_per_node: bool = True,
    amp_level: str = 'O1',  # backward compatible, todo: remove in v1.0.0
    num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
    use_amp=None,  # backward compatible, todo: remove in v0.9.0
    show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
    val_percent_check: float = 1.0,  # backward compatible, todo: remove in v1.0.0
    test_percent_check: float = 1.0,  # backward compatible, todo: remove in v1.0.0
    overfit_pct: float = 0.0  # backward compatible, todo: remove in v1.0.0
):
    r"""
    Customize every aspect of training via flags.

    The constructor stores the flags on the instance, fires the ``on_init_start`` /
    ``on_init_end`` callback hooks around the setup, and configures the logger,
    profiler, early stopping, gradient accumulation, device selection,
    distributed mode, progress bar, data-usage limits and AMP.

    Args:
        logger: Logger (or iterable collection of loggers) for experiment tracking.
        checkpoint_callback: Callback for checkpointing.
        early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`):
        callbacks: Add a list of callbacks.
        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
            Falls back to the current working directory when ``None``.
        gradient_clip_val: 0 means don't clip.
        gradient_clip:
            .. warning:: .. deprecated:: 0.7.0
                Use `gradient_clip_val` instead. Will remove 0.9.0.
        process_position: orders the progress bar when running multiple models on same machine.
        num_nodes: number of GPU nodes for distributed training.
        nb_gpu_nodes:
            .. warning:: .. deprecated:: 0.7.0
                Use `num_nodes` instead. Will remove 0.9.0.
        gpus: Which GPUs to train on.
        auto_select_gpus:
            If enabled and `gpus` is an integer, pick available
            gpus automatically. This is especially useful when
            GPUs are configured to be in "exclusive mode", such
            that only one process at a time can access them.
        tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]
        num_tpu_cores: How many TPU cores to train on (1 or 8)
            .. warning:: .. deprecated:: 0.7.6. Will remove 0.9.0.
        log_gpu_memory: None, 'min_max', 'all'. Might slow performance
        show_progress_bar:
            .. warning:: .. deprecated:: 0.7.2
                Set `progress_bar_refresh_rate` to positive integer to enable. Will remove 0.9.0.
        progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
            Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.
        overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).
        overfit_pct:
            .. warning:: .. deprecated:: 0.8.0
                Use `overfit_batches` instead. Will remove 1.0.0.
        track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.
        check_val_every_n_epoch: Check val every n train epochs.
        fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
        accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
        max_epochs: Stop training once this number of epochs is reached.
        max_nb_epochs:
            .. warning:: .. deprecated:: 0.7.0
                Use `max_epochs` instead. Will remove 0.9.0.
        min_epochs: Force training for at least these many epochs
        min_nb_epochs:
            .. warning:: .. deprecated:: 0.7.0
                Use `min_epochs` instead. Will remove 0.9.0.
        max_steps: Stop training after this number of steps. Disabled by default (None).
        min_steps: Force training for at least these number of steps. Disabled by default (None).
        train_percent_check: How much of training dataset to check.
        limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)
        limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)
        val_percent_check:
            .. warning:: .. deprecated:: 0.8.0
                Use `limit_val_batches` instead. Will remove 1.0.0.
        test_percent_check:
            .. warning:: .. deprecated:: 0.8.0
                Use `limit_test_batches` instead. Will remove 1.0.0.
        val_check_interval: How often within one training epoch to check the validation set
        log_save_interval: Writes logs to disk this often
        row_log_interval: How often to add logging rows (does not write to disk)
        add_row_log_interval:
            .. warning:: .. deprecated:: 0.7.0
                Use `row_log_interval` instead. Will remove 0.9.0.
        distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn)
        use_amp:
            .. warning:: .. deprecated:: 0.7.0
                Use `precision` instead. Will remove 0.9.0.
        precision: Full precision (32), half precision (16).
        print_nan_grads:
            .. warning:: .. deprecated:: 0.7.2
                Has no effect. When detected, NaN grads will be printed automatically.
                Will remove 0.9.0.
        weights_summary: Prints a summary of the weights when training begins.
        weights_save_path: Where to save weights if specified. Will override default_root_dir
            for checkpoints only. Use this if for whatever reason you need the checkpoints
            stored in a different place than the logs written in `default_root_dir`.
        amp_level: The optimization level to use (O1, O2, etc...).
        num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
        truncated_bptt_steps: Truncated back propagation through time: performs backprop
            every k steps of a longer sequence.
        resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
            This can be a URL.
        profiler: To profile individual steps during training and assist in identifying
            bottlenecks. ``True`` selects :class:`SimpleProfiler`; any falsy value falls
            back to :class:`PassThroughProfiler`.
        reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch
        auto_lr_find: If set to True, will `initially` run a learning rate finder,
            trying to optimize initial learning for faster convergence. Sets learning
            rate in self.lr or self.learning_rate in the LightningModule.
            To use a different key, set a string instead of True with the key name.
        replace_sampler_ddp: Explicitly enables or disables sampler replacement.
            If not specified this will be toggled automatically when ddp is used.
        benchmark: If true enables cudnn.benchmark.
        deterministic: If true enables cudnn.deterministic
        terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
            end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
        auto_scale_batch_size: If set to True, will `initially` run a batch size
            finder trying to find the largest batch size that fits into memory.
            The result will be stored in self.batch_size in the LightningModule.
            Additionally, can be set to either `power` that estimates the batch size through
            a power search or `binsearch` that estimates the batch size through a binary search.
        prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
            Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

    Raises:
        MisconfigurationException: If ``track_grad_norm`` is neither a number nor ``'inf'``.
        AssertionError: If ``tpu_cores`` is not 1, 8, ``None`` or a one-element collection.
    """
    super().__init__()

    # NOTE: these cudnn switches are process-wide settings, not per-Trainer state
    self.deterministic = deterministic
    torch.backends.cudnn.deterministic = self.deterministic
    if self.deterministic:
        # fixing non-deterministic part of horovod
        # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
        os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

    # Init callbacks
    # ``self.callbacks`` is assigned before the ``on_init_start`` hook is fired
    self.prepare_data_per_node = prepare_data_per_node
    self.callbacks = callbacks or []
    self.on_init_start()

    # benchmarking
    self.benchmark = benchmark
    torch.backends.cudnn.benchmark = self.benchmark

    # Transfer params
    self.num_nodes = num_nodes
    self.log_gpu_memory = log_gpu_memory
    self.gradient_clip_val = gradient_clip_val
    self.check_val_every_n_epoch = check_val_every_n_epoch

    if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
        raise MisconfigurationException(
            "track_grad_norm can be an int, a float or 'inf' (infinity norm).")
    # float('inf') handles the string form, so the attribute is always a float
    self.track_grad_norm = float(track_grad_norm)

    # only flagged as on-GPU when GPUs were requested AND CUDA is actually usable
    self.on_gpu = True if (gpus and torch.cuda.is_available()) else False

    # tpu config
    if num_tpu_cores is not None:
        rank_zero_warn("Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6"
                       " and this argument will be removed in v0.9.0", DeprecationWarning)

    # the deprecated argument is only used when the new one was not given
    if tpu_cores is None:
        tpu_cores = num_tpu_cores
    self.on_tpu = tpu_cores is not None
    self.tpu_cores = tpu_cores
    assert self.tpu_cores in (1, 8, None) or (
        isinstance(self.tpu_cores, (list, tuple, set)) and len(self.tpu_cores) == 1
    ), '`tpu_cores` can only be 1, 8 or [<1-8>]'

    # NOTE(review): the assert above also admits a tuple or set, but only a list
    # produces a ``tpu_id`` here -- confirm whether that asymmetry is intended.
    self.tpu_id = tpu_cores[0] if isinstance(tpu_cores, list) else None

    if num_processes != 1 and distributed_backend != "ddp_cpu":
        rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
    self.num_processes = num_processes

    self.weights_summary = weights_summary

    self.max_epochs = max_epochs
    self.min_epochs = min_epochs
    self.max_steps = max_steps
    self.min_steps = min_steps

    self.num_sanity_val_steps = num_sanity_val_steps
    # Backward compatibility, TODO: remove in v0.9.0
    if print_nan_grads:
        rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
                       " NaN grads will be printed automatically when detected.", DeprecationWarning)

    self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

    self.auto_lr_find = auto_lr_find
    self.auto_scale_batch_size = auto_scale_batch_size
    self._is_data_prepared = False
    self.replace_sampler_ddp = replace_sampler_ddp

    self.truncated_bptt_steps = truncated_bptt_steps
    self.resume_from_checkpoint = resume_from_checkpoint
    self.terminate_on_nan = terminate_on_nan
    self.shown_warnings = set()

    # fast_dev_run overrides the sanity-check and epoch settings assigned above
    self.fast_dev_run = fast_dev_run
    if self.fast_dev_run:
        self.num_sanity_val_steps = 0
        self.max_epochs = 1
        rank_zero_info('Running in fast_dev_run mode: will run a full train,'
                       ' val and test loop using a single batch')

    # set default save path if user didn't provide one
    self.default_root_dir = default_root_dir

    if self.default_root_dir is None:
        self.default_root_dir = os.getcwd()

    # training bookkeeping
    self.total_batch_idx = 0
    self.running_loss = TensorRunningAccum(window_length=20)
    self.batch_idx = 0
    self.progress_bar_metrics = {}
    self.callback_metrics = {}
    self.num_val_batches = [0]
    self.num_training_batches = 0
    self.num_test_batches = [0]
    self.train_dataloader = None
    self.test_dataloaders = None
    self.val_dataloaders = None

    # training state
    self.model = None
    self.testing = False
    self.disable_validation = False
    self.lr_schedulers = []
    self.optimizers = None
    self.optimizer_frequencies = []
    self.global_step = 0
    self.current_epoch = 0
    self.interrupted = False

    # configure logger
    self.configure_logger(logger)

    # configure profiler
    if profiler is True:
        profiler = SimpleProfiler()
    self.profiler = profiler or PassThroughProfiler()

    # configure early stop callback
    # creates a default one if none passed in
    self.configure_early_stopping(early_stop_callback)

    # configure checkpoint callback
    self.checkpoint_callback = checkpoint_callback
    self.weights_save_path = weights_save_path

    # accumulated grads
    self.accumulate_grad_batches = accumulate_grad_batches
    self.configure_accumulated_gradients(accumulate_grad_batches)

    # for gpus allow int, string and gpu list
    if auto_select_gpus and isinstance(gpus, int):
        self.gpus = pick_multiple_gpus(gpus)
    else:
        self.gpus = gpus

    self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
    self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
    self.root_device = torch.device("cpu")

    # tpu state flags
    self.use_tpu = False
    self.tpu_local_core_rank = None
    self.tpu_global_core_rank = None

    # distributed backend choice
    self.distributed_backend = distributed_backend
    self.set_distributed_mode(distributed_backend)

    # override dist backend when using tpus
    if self.on_tpu:
        self.init_tpu()

    # init flags for SLURM+ddp to work
    self.world_size = 1
    self.interactive_ddp_procs = []
    self.configure_slurm_ddp(self.num_nodes)
    self.node_rank = self.determine_ddp_node_rank()
    self.local_rank = self.determine_local_rank()
    self.global_rank = 0

    # nvidia setup
    self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

    # backward compatibility
    if show_progress_bar is not None:
        self.show_progress_bar = show_progress_bar

    self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

    # logging
    self.log_save_interval = log_save_interval
    self.val_check_interval = val_check_interval
    self.row_log_interval = row_log_interval

    # how much of the data to use
    # TODO: remove in 1.0.0
    # the deprecated ``overfit_pct`` wins over ``overfit_batches`` when it is positive
    if overfit_pct > 0:
        overfit_batches = overfit_pct

    # convert floats to ints
    # (values above 1.0 are truncated to an int; values up to 1.0 pass through unchanged)
    overfit_batches = int(overfit_batches) if overfit_batches > 1.0 else overfit_batches
    self.overfit_batches = overfit_batches

    # TODO: remove in 1.0.0
    # the deprecated ``*_percent_check`` arguments override ``limit_*_batches`` when below 1.0
    if val_percent_check < 1.0:
        limit_val_batches = val_percent_check

    if test_percent_check < 1.0:
        limit_test_batches = test_percent_check

    limit_test_batches = int(limit_test_batches) if limit_test_batches > 1.0 else limit_test_batches
    limit_val_batches = int(limit_val_batches) if limit_val_batches > 1.0 else limit_val_batches

    # TODO: convert train_percent_check to limit_train_batches
    self.determine_data_use_amount(train_percent_check, limit_val_batches,
                                   limit_test_batches, overfit_batches)

    # AMP init
    # These are the only lines needed after v0.8.0
    # we wrap the user's forward with autocast and give it back at the end of fit
    self.autocast_original_forward = None
    # native AMP is detected by probing for ``torch.cuda.amp.autocast``
    self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
    self.precision = precision
    self.scaler = None

    self.amp_level = amp_level
    self.init_amp(use_amp)

    # truthy when either environment variable is set (holds the variable's value, not a bool)
    self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

    # Callback system
    self.on_init_end()
@property
def is_global_zero(self):
return self.global_rank == 0
@property
def slurm_job_id(self) -> Optional[int]:
try:
job_id = os.environ['SLURM_JOB_ID']
job_id = int(job_id)
# in interactive mode, don't make logs use the same job id
in_slurm_interactive_mode = os.environ['SLURM_JOB_NAME'] == 'bash'
if in_slurm_interactive_mode:
job_id = None
except Exception:
job_id = None
return job_id
@classmethod
def default_attributes(cls):
init_signature = inspect.signature(Trainer)
args = {}
for param_name in init_signature.parameters:
value = init_signature.parameters[param_name].default
args[param_name] = value
return args
@classmethod
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the Trainer signature and returns argument names, types and default values.
Returns:
List with tuples of 3 values:
(argument name, set with argument types, argument default value).
Examples:
>>> args = Trainer.get_init_arguments_and_types()
>>> import pprint
>>> pprint.pprint(sorted(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
[('accumulate_grad_batches',
(<class 'int'>, typing.Dict[int, int], typing.List[list]),
1),
...
('callbacks',
(typing.List[pytorch_lightning.callbacks.base.Callback],
<class 'NoneType'>),
None),
('check_val_every_n_epoch', (<class 'int'>,), 1),
...
('max_epochs', (<class 'int'>,), 1000),
...
('precision', (<class 'int'>,), 32),
('prepare_data_per_node', (<class 'bool'>,), True),
('print_nan_grads', (<class 'bool'>,), False),
('process_position', (<class 'int'>,), 0),
('profiler',
(<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
<class 'bool'>,
<class 'NoneType'>),
None),
...
"""
trainer_default_params = inspect.signature(cls).parameters
name_type_default = []
for arg in trainer_default_params:
arg_type = trainer_default_params[arg].annotation
arg_default = trainer_default_params[arg].default
try:
arg_types = tuple(arg_type.__args__)
except AttributeError:
arg_types = (arg_type,)
name_type_default.append((arg, arg_types, arg_default))
return name_type_default
@classmethod
def get_deprecated_arg_names(cls) -> List:
"""Returns a list with deprecated Trainer arguments."""
depr_arg_names = []
for name, val in cls.__dict__.items():
if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
depr_arg_names.extend(val)
return depr_arg_names
@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `Trainer` attributes.
Args:
parent_parser:
The custom cli arguments parser, which will be extended by
the Trainer default arguments.
Only arguments of the allowed types (str, float, int, bool) will
extend the `parent_parser`.
Examples:
>>> import argparse
>>> import pprint
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
{...
'check_val_every_n_epoch': 1,
'checkpoint_callback': True,
'default_root_dir': None,
'deterministic': False,
'distributed_backend': None,
'early_stop_callback': False,
...
'logger': True,
'max_epochs': 1000,
'max_steps': None,
'min_epochs': 1,
'min_steps': None,
...
'profiler': None,
'progress_bar_refresh_rate': 1,
...}
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False, )
blacklist = ['kwargs']
depr_arg_names = cls.get_deprecated_arg_names() + blacklist
allowed_types = (str, float, int, bool)
# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
if at[0] not in depr_arg_names):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?")
# if the only arg type is bool
if len(arg_types) == 1:
# redefine the type for ArgParser needed
def use_type(x):
return bool(parsing.str_to_bool(x))
else:
# filter out the bool as we need to use more general
use_type = [at for at in arg_types if at is not bool][0]
else:
use_type = arg_types[0]
if arg == 'gpus':
use_type = Trainer._allowed_type
arg_default = Trainer._arg_default
parser.add_argument(
f'--{arg}',
dest=arg,
default=arg_default,
type=use_type,
help='autogenerated by pl.Trainer',
**arg_kwargs,
)
return parser
def _allowed_type(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)
def _arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)
@staticmethod
def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
"""Parse CLI arguments, required for custom bool types."""
args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
args = {k: True if v is None else v for k, v in vars(args).items()}
return Namespace(**args)
@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
"""
Create an instance from CLI arguments.
Args:
args: The parser or namespace to take arguments from. Only known arguments will be
parsed and passed to the :class:`Trainer`.
**kwargs: Additional keyword arguments that may override ones in the parser or namespace.
These must be valid Trainer arguments.
Example:
>>> parser = ArgumentParser(add_help=False)
>>> parser = Trainer.add_argparse_args(parser)
>>> parser.add_argument('--my_custom_arg', default='something') # doctest: +SKIP
>>> args = Trainer.parse_argparser(parser.parse_args(""))
>>> trainer = Trainer.from_argparse_args(args, logger=False)
"""
if isinstance(args, ArgumentParser):
args = cls.parse_argparser(args)
params = vars(args)
# we only want to pass in valid Trainer args, the rest may be user specific
valid_kwargs = inspect.signature(cls.__init__).parameters
trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
trainer_kwargs.update(**kwargs)
return cls(**trainer_kwargs)
@property
def num_gpus(self) -> int:
gpus = self.data_parallel_device_ids
if gpus is None:
return 0
return len(gpus)
@property
def data_parallel(self) -> bool:
return self.use_dp or self.use_ddp or self.use_ddp2
@property
def progress_bar_callback(self):
return self._progress_bar_callback
@property
def progress_bar_dict(self) -> dict:
""" Read-only for progress bar metrics. """
ref_model = self.model if not self.data_parallel else self.model.module
return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics)
# -----------------------------
# MODEL TRAINING
# -----------------------------
def fit(
self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
):
r"""
Runs the full optimization routine.
Args:
model: Model to fit.
train_dataloader: A Pytorch
DataLoader with training samples. If the model has
a predefined train_dataloader method this will be skipped.
val_dataloaders: Either a single
Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped
Example::
# Option 1,
# Define the train_dataloader() and val_dataloader() fxs
# in the lightningModule
# RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
trainer = Trainer()
model = LightningModule()
trainer.fit(model)
# Option 2
# in production cases we might want to pass different datasets to the same model
# Recommended for PRODUCTION SYSTEMS
train, val = DataLoader(...), DataLoader(...)
trainer = Trainer()
model = LightningModule()
trainer.fit(model, train_dataloader=train, val_dataloaders=val)
# Option 1 & 2 can be mixed, for example the training set can be
# defined as part of the model, and validation can then be feed to .fit()
"""
# bind logger and other properties
model.logger = self.logger
self.copy_trainer_model_properties(model)
# clean hparams
if hasattr(model, 'hparams'):
parsing.clean_namespace(model.hparams)
# set up the passed in dataloaders (if needed)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
# check that model is configured correctly
self.check_model_configuration(model)
# callbacks
self.on_fit_start()
if self.is_function_implemented('on_fit_start'):
model.on_fit_start()
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
if self.can_prepare_data():
model.prepare_data()
self._is_data_prepared = True
# Run auto batch size scaling
if self.auto_scale_batch_size:
if isinstance(self.auto_scale_batch_size, bool):
self.auto_scale_batch_size = 'power'
self.scale_batch_size(model, mode=self.auto_scale_batch_size)
model.logger = self.logger # reset logger binding
# Run learning rate finder:
if self.auto_lr_find:
self._run_lr_finder_internally(model)
model.logger = self.logger # reset logger binding
# route to appropriate start method
# when using multi-node or DDP within a node start each module in a separate process
if self.use_ddp2:
if self.is_slurm_managing_tasks:
task = int(os.environ['SLURM_LOCALID'])
# torchelastic or general non_slurm ddp2
elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
task = int(os.environ['LOCAL_RANK'])
self.ddp_train(task, model)
elif self.use_ddp:
if self.is_slurm_managing_tasks:
task = int(os.environ['SLURM_LOCALID'])
self.ddp_train(task, model)
# torchelastic or general non_slurm ddp
elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
task = int(os.environ['LOCAL_RANK'])
self.ddp_train(task, model)
elif self.distributed_backend == 'cpu_ddp':
self.__set_random_port()
self.model = model
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))
elif self.distributed_backend == 'ddp_spawn':
model.share_memory()
# spin up peers
mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))
elif self.distributed_backend == 'ddp':
self.spawn_ddp_children(model)
# 1 gpu or dp option triggers training using DP module
# easier to avoid NCCL issues
elif self.use_dp:
self.dp_train(model)
elif self.use_horovod:
self.horovod_train(model)
elif self.single_gpu:
self.single_gpu_train(model)
elif self.use_tpu: # pragma: no-cover
rank_zero_info(f'training on {self.tpu_cores} TPU cores')
# COLAB_GPU is an env var available by default in Colab environments.
start_method = 'fork' if self.on_colab_kaggle else 'spawn'
# track for predict
self.model = model
# train
if self.tpu_id is not None:
self.tpu_train(self.tpu_id, model)
else:
xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method)
# load weights if not interrupted
self.load_spawn_weights(model)
self.model = model
# ON CPU
else:
# run through amp wrapper
if self.use_amp:
raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
self.run_pretrain_routine(model)
# callbacks
self.on_fit_end()
# model hooks
if self.is_function_implemented('on_fit_end'):
model.on_fit_end()
# return 1 when finished
# used for testing or when we need to know that training succeeded
return 1
def can_prepare_data(self):
if self.prepare_data_per_node:
return self.local_rank == 0
else:
return self.node_rank == 0 and self.local_rank == 0
def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
if train_dataloader is not None:
model.train_dataloader = _PatchDataLoader(train_dataloader)
if val_dataloaders is not None:
model.val_dataloader = _PatchDataLoader(val_dataloaders)
if test_dataloaders is not None:
model.test_dataloader = _PatchDataLoader(test_dataloaders)
def run_pretrain_routine(self, model: LightningModule):
    """Sanity check a few things before starting actual training.

    In order: links the trainer and the underlying model, creates the native
    AMP grad scaler when applicable, logs hyper-parameters, synchronises the
    distributed processes, registers the SLURM signal handlers, prints the
    model summary, configures checkpointing and restores weights. When
    ``self.testing`` is set it then runs the test evaluation and returns;
    otherwise it optionally runs the validation sanity check, clears the CUDA
    cache and starts the training loop.

    Args:
        model: The model to run sanity test on.

    Raises:
        MisconfigurationException: If ``weights_summary`` is not ``None`` and
            not one of ``ModelSummary.MODES``.
    """
    # in data-parallel modes ``model`` is a wrapper; the user's module sits at ``.module``
    ref_model = model
    if self.data_parallel:
        ref_model = model.module

    # give model convenience properties
    ref_model.trainer = self

    # set local properties on the model
    self.copy_trainer_model_properties(ref_model)

    # init amp. Must be done here instead of __init__ to allow ddp to work
    if self.use_native_amp and self.precision == 16:
        self.scaler = torch.cuda.amp.GradScaler()

    # log hyper-parameters
    if self.logger is not None:
        # save exp to get started
        self.logger.log_hyperparams(ref_model.hparams)
        self.logger.save()

    if self.use_ddp or self.use_ddp2:
        torch_distrib.barrier()

    # wait for all models to restore weights
    if self.on_tpu and XLA_AVAILABLE:
        # wait for all processes to catch up
        torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")

    elif self.use_horovod:
        # wait for all processes to catch up
        hvd.join()

    # register auto-resubmit when on SLURM
    self.register_slurm_signal_handlers()

    # print model summary
    # (only on the global-zero process, and never while testing)
    if self.is_global_zero and self.weights_summary is not None and not self.testing:
        if self.weights_summary in ModelSummary.MODES:
            ref_model.summarize(mode=self.weights_summary)
        else:
            raise MisconfigurationException(
                "weights_summary can be None, " + ", ".join(ModelSummary.MODES)
            )

    # track model now.
    # if cluster resets state, the model will update with the saved weights
    self.model = model

    # set up checkpoint callback
    self.configure_checkpoint_callback()

    # restore training and model before hpc call
    self.restore_weights(model)

    # when testing requested only run test and return
    if self.testing:
        # only load test dataloader for testing
        # self.reset_test_dataloader(ref_model)
        self.run_evaluation(test_mode=True)
        return

    # check if we should run validation during training
    # (fast_dev_run keeps validation enabled even when the first condition fails)
    self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
        and not self.fast_dev_run

    # run tiny validation (if validation defined)
    # to make sure program won't crash during val
    if not self.disable_validation and self.num_sanity_val_steps > 0:
        self.reset_val_dataloader(ref_model)

        # hook and callback
        ref_model.on_sanity_check_start()
        self.on_sanity_check_start()

        # cap every validation dataloader at ``num_sanity_val_steps`` batches
        num_loaders = len(self.val_dataloaders)
        max_batches = [self.num_sanity_val_steps] * num_loaders
        eval_results = self._evaluate(model,
                                      self.val_dataloaders,
                                      max_batches,
                                      False)
        # only the callback metrics (4th element) are needed here
        _, _, _, callback_metrics, _ = self.process_output(eval_results)

        self.on_sanity_check_end()

        # verify that early stop has conditioned on a metric that exists
        if self.enable_early_stop:
            self.early_stop_callback._validate_condition_metric(callback_metrics)

    # clear cache before training
    if self.on_gpu and self.root_gpu is not None:
        # use context because of:
        # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
        with torch.cuda.device(f'cuda:{self.root_gpu}'):
            torch.cuda.empty_cache()

    # CORE TRAINING LOOP
    self.train()
def test(
        self,
        model: Optional[LightningModule] = None,
        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best'
):
    r"""
    Run the test loop. Kept separate from ``fit`` so the test set is never touched until
    it is explicitly requested.

    Args:
        model: The model to test. When given, ``ckpt_path`` is ignored and this model is
            tested as-is.
        test_dataloaders: Either a single PyTorch DataLoader or a list of them,
            specifying test samples.
        ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
            If ``None``, use the weights from the last epoch to test. Default to ``best``.

    Raises:
        MisconfigurationException: If ``model`` is ``None`` and ``ckpt_path`` is ``'best'``
            but there is no checkpoint callback, or its ``save_top_k`` is ``<= 0``.

    Example::

        # Option 1
        # run test with the best checkpoint from ``ModelCheckpoint`` after fitting.
        test = DataLoader(...)
        trainer = Trainer()
        model = LightningModule()

        trainer.fit(model)
        trainer.test(test_dataloaders=test)

        # Option 2
        # run test with the specified checkpoint after fitting
        test = DataLoader(...)
        trainer = Trainer()
        model = LightningModule()

        trainer.fit(model)
        trainer.test(test_dataloaders=test, ckpt_path='path/to/checkpoint.ckpt')

        # Option 3
        # run test with the weights from the end of training after fitting
        test = DataLoader(...)
        trainer = Trainer()
        model = LightningModule()

        trainer.fit(model)
        trainer.test(test_dataloaders=test, ckpt_path=None)

        # Option 4
        # run test from a loaded model. ``ckpt_path`` is ignored in this case.
        test = DataLoader(...)
        model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
        trainer = Trainer()

        trainer.test(model, test_dataloaders=test)
    """
    if model is None and ckpt_path == 'best':
        # guard against a missing/disabled checkpoint callback (None/False) so the user
        # gets the configuration error instead of an opaque AttributeError
        checkpoint_callback = self.checkpoint_callback
        if not checkpoint_callback or checkpoint_callback.save_top_k <= 0:
            raise MisconfigurationException(
                'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')

    # if model is not given (None), ckpt_path is given,
    # load the given checkpoint for testing
    if model is None and ckpt_path is not None:
        # ckpt_path is 'best' so load the best model
        if ckpt_path == 'best':
            ckpt_path = self.checkpoint_callback.best_model_path
        model = self.get_model().load_from_checkpoint(ckpt_path)

    self.testing = True
    try:
        if test_dataloaders is not None:
            # explicit ``is not None``: never rely on the truthiness of a module object
            target_model = model if model is not None else self.model
            self.__attach_dataloaders(target_model, test_dataloaders=test_dataloaders)

        if model is not None:
            self.model = model
            self.fit(model)

        # on tpu, .spawn means we don't have a trained model
        # TODO: remove TPU spawn
        elif self.use_tpu:  # pragma: no-cover
            # attempt to load weights from a spawn
            path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
            test_model = self.model
            if os.path.exists(path):
                test_model = self.load_spawn_weights(self.model)

            self.fit(test_model)
        else:
            self.run_evaluation(test_mode=True)
    finally:
        # always leave test mode, even if the run above raised, so that a later
        # ``fit()`` on this trainer is not silently treated as a test-only run
        self.testing = False
def check_model_configuration(self, model: LightningModule):
    r"""
    Checks that the model is configured correctly before training or testing is started.

    Args:
        model: The model to check the configuration.

    Raises:
        MisconfigurationException: If a required training hook is missing, or if a
            dataloader hook and its matching ``*_step`` hook are not defined together.
    """
    if not self.testing:
        # Check training_step, train_dataloader, configure_optimizers methods
        # (checked in this order; the first missing hook is the one reported)
        for required_hook in ('training_step', 'train_dataloader', 'configure_optimizers'):
            if not self.is_overridden(required_hook, model):
                raise MisconfigurationException(
                    f'No `{required_hook}()` method defined. Lightning `Trainer` expects as minimum a'
                    ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')

        # Check val_dataloader, validation_step and validation_epoch_end
        if self.is_overridden('val_dataloader', model):
            if not self.is_overridden('validation_step', model):
                raise MisconfigurationException('You have passed in a `val_dataloader()`'
                                                ' but have not defined `validation_step()`.')
            elif not self.is_overridden('validation_epoch_end', model):
                rank_zero_warn(
                    'You have defined a `val_dataloader()` and have defined a `validation_step()`,'
                    ' you may also want to define `validation_epoch_end()` for accumulating stats.',
                    RuntimeWarning
                )
        elif self.is_overridden('validation_step', model):
            raise MisconfigurationException('You have defined `validation_step()`,'
                                            ' but have not passed in a `val_dataloader()`.')

    # Check test_dataloader, test_step and test_epoch_end
    if self.is_overridden('test_dataloader', model):
        if not self.is_overridden('test_step', model):
            raise MisconfigurationException('You have passed in a `test_dataloader()`'
                                            ' but have not defined `test_step()`.')
        elif not self.is_overridden('test_epoch_end', model):
            rank_zero_warn(
                'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to'
                ' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
            )
    elif self.testing and self.is_overridden('test_step', model):
        # a lone ``test_step`` is only an error when a test run was actually requested
        raise MisconfigurationException('You have defined `test_step()` but did not'
                                        ' implement `test_dataloader` nor passed in `.test(test_dataloader)`.')
class _PatchDataLoader:
    r"""
    Pickle-compatible callable used to patch dataloaders passed into trainer.fit().

    An instance is meant to replace a ``model.*_dataloader()`` hook; calling it
    returns the wrapped dataloader(s) unchanged.

    Args:
        dataloader: Dataloader object to return when called.

    """

    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
        # ``__code__`` itself cannot be pickled, so it cannot be used directly to verify
        # that a dataloader method was overwritten by this patch. As a workaround the
        # string representation of the code object is stored instead.
        call_code = self.__call__.__code__
        self.patch_loader_code = str(call_code)
        self.dataloader = dataloader

    def __call__(self) -> Union[List[DataLoader], DataLoader]:
        wrapped = self.dataloader
        return wrapped