1232 lines
48 KiB
Python
1232 lines
48 KiB
Python
import inspect
|
|
import os
|
|
from argparse import ArgumentParser, Namespace
|
|
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
|
|
|
|
import torch
|
|
import torch.distributed as torch_distrib
|
|
import torch.multiprocessing as mp
|
|
from torch.utils.data import DataLoader
|
|
|
|
from pytorch_lightning import _logger as log
|
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
|
|
from pytorch_lightning.core.lightning import LightningModule
|
|
from pytorch_lightning.core.memory import ModelSummary
|
|
from pytorch_lightning.loggers import LightningLoggerBase
|
|
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
|
|
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
|
|
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
|
|
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
|
|
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
|
|
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_9
|
|
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
|
|
from pytorch_lightning.trainer.distrib_parts import (
|
|
TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device, pick_multiple_gpus)
|
|
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
|
|
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
|
|
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
|
|
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
|
|
from pytorch_lightning.trainer.supporters import TensorRunningAccum
|
|
from pytorch_lightning.trainer.training_io import TrainerIOMixin
|
|
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
|
|
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
|
|
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info
|
|
|
|
# Optional back-ends. Each *_AVAILABLE flag records whether the corresponding
# import succeeded, so the rest of the module can branch on availability.
try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False

try:
    import horovod.torch as hvd
    HOROVOD_AVAILABLE = True
except ImportError:
    HOROVOD_AVAILABLE = False
|
|
|
|
|
|
class Trainer(
|
|
TrainerIOMixin,
|
|
TrainerOptimizersMixin,
|
|
TrainerAMPMixin,
|
|
TrainerDPMixin,
|
|
TrainerDDPMixin,
|
|
TrainerLoggingMixin,
|
|
TrainerModelHooksMixin,
|
|
TrainerTrainingTricksMixin,
|
|
TrainerDataLoadingMixin,
|
|
TrainerEvaluationLoopMixin,
|
|
TrainerTrainLoopMixin,
|
|
TrainerCallbackConfigMixin,
|
|
TrainerCallbackHookMixin,
|
|
TrainerLRFinderMixin,
|
|
TrainerDeprecatedAPITillVer0_9,
|
|
):
|
|
# Names of deprecated arguments/attributes slated for removal in v0.9. Every class
# attribute whose name starts with "DEPRECATED" is collected by
# ``get_deprecated_arg_names`` and those names are skipped by ``add_argparse_args``.
DEPRECATED_IN_0_9 = ('use_amp', 'show_progress_bar', 'training_tqdm_dict', 'num_tpu_cores')
|
|
|
|
def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
        early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
        callbacks: Optional[List[Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        train_percent_check: float = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        val_check_interval: float = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 50,
        distributed_backend: Optional[str] = None,
        precision: int = 32,
        print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
        weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        amp_level: str = 'O1',  # backward compatible, todo: remove in v1.0.0
        num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
        use_amp=None,  # backward compatible, todo: remove in v0.9.0
        show_progress_bar=None,  # backward compatible, todo: remove in v0.9.0
        val_percent_check: float = 1.0,  # backward compatible, todo: remove in v1.0.0
        test_percent_check: float = 1.0,  # backward compatible, todo: remove in v1.0.0
        overfit_pct: float = 0.0  # backward compatible, todo: remove in v1.0.0
):
    r"""
    Customize every aspect of training via flags

    Args:
        logger: Logger (or iterable collection of loggers) for experiment tracking.

        checkpoint_callback: Callback for checkpointing.

        early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`): Callback for
            early stopping. A bool is accepted as well; it is interpreted by
            ``configure_early_stopping``.

        callbacks: Add a list of callbacks.

        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
            Falls back to the current working directory when ``None``.

        gradient_clip_val: 0 means don't clip.

        gradient_clip:
            .. warning:: .. deprecated:: 0.7.0

                Use `gradient_clip_val` instead. Will remove 0.9.0.

        process_position: orders the progress bar when running multiple models on same machine.

        num_nodes: number of GPU nodes for distributed training.

        nb_gpu_nodes:
            .. warning:: .. deprecated:: 0.7.0

                Use `num_nodes` instead. Will remove 0.9.0.

        num_processes: number of processes to spawn with ``distributed_backend="ddp_cpu"``.
            A warning is emitted when it is set for any other backend.

        gpus: Which GPUs to train on.

        auto_select_gpus:

            If enabled and `gpus` is an integer, pick available
            gpus automatically. This is especially useful when
            GPUs are configured to be in "exclusive mode", such
            that only one process at a time can access them.

        tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1].
            A single-element tuple or set is accepted and treated like the list form.

        num_tpu_cores: How many TPU cores to train on (1 or 8)
            .. warning:: .. deprecated:: 0.7.6. Will remove 0.9.0.

        log_gpu_memory: None, 'min_max', 'all'. Might slow performance

        show_progress_bar:
            .. warning:: .. deprecated:: 0.7.2

                Set `progress_bar_refresh_rate` to positive integer to enable. Will remove 0.9.0.

        progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
            Ignored when a custom callback is passed to :paramref:`~Trainer.callbacks`.

        overfit_batches: Overfit a percent of training data (float) or a set number of batches (int).
            Values greater than 1.0 are converted to an int number of batches.

        overfit_pct:
            .. warning:: .. deprecated:: 0.8.0

                Use `overfit_batches` instead. Will remove 1.0.0.

        track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.

        check_val_every_n_epoch: Check val every n train epochs.

        fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
            Also disables the sanity validation steps and forces ``max_epochs=1``.

        accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

        max_epochs: Stop training once this number of epochs is reached.

        max_nb_epochs:
            .. warning:: .. deprecated:: 0.7.0

                Use `max_epochs` instead. Will remove 0.9.0.

        min_epochs: Force training for at least these many epochs

        min_nb_epochs:
            .. warning:: .. deprecated:: 0.7.0

                Use `min_epochs` instead. Will remove 0.9.0.

        max_steps: Stop training after this number of steps. Disabled by default (None).

        min_steps: Force training for at least these number of steps. Disabled by default (None).

        train_percent_check: How much of training dataset to check.

        limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)

        limit_test_batches: How much of test dataset to check (floats = percent, int = num_batches)

        val_percent_check:
            .. warning:: .. deprecated:: 0.8.0

                Use `limit_val_batches` instead. Will remove 1.0.0.

        test_percent_check:
            .. warning:: .. deprecated:: 0.8.0

                Use `limit_test_batches` instead. Will remove 1.0.0.

        val_check_interval: How often within one training epoch to check the validation set

        log_save_interval: Writes logs to disk this often

        row_log_interval: How often to add logging rows (does not write to disk)

        add_row_log_interval:
            .. warning:: .. deprecated:: 0.7.0

                Use `row_log_interval` instead. Will remove 0.9.0.

        distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)

        use_amp:
            .. warning:: .. deprecated:: 0.7.0

                Use `precision` instead. Will remove 0.9.0.

        precision: Full precision (32), half precision (16).

        print_nan_grads:
            .. warning:: .. deprecated:: 0.7.2

                Has no effect. When detected, NaN grads will be printed automatically.
                Will remove 0.9.0.

        weights_summary: Prints a summary of the weights when training begins.

        weights_save_path: Where to save weights if specified. Will override default_root_dir
            for checkpoints only. Use this if for whatever reason you need the checkpoints
            stored in a different place than the logs written in `default_root_dir`.

        amp_level: The optimization level to use (O1, O2, etc...).

        num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.

        truncated_bptt_steps: Truncated back-propagation through time: performs backprop every
            k steps of a much longer sequence.

        resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
            This can be a URL.

        profiler: To profile individual steps during training and assist in identifying
            bottlenecks. ``True`` uses a :class:`SimpleProfiler`; ``None``/``False`` disables
            profiling (a :class:`PassThroughProfiler` is used).

        reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch

        auto_lr_find: If set to True, will `initially` run a learning rate finder,
            trying to optimize initial learning for faster convergence. Sets learning
            rate in self.lr or self.learning_rate in the LightningModule.
            To use a different key, set a string instead of True with the key name.

        replace_sampler_ddp: Explicitly enables or disables sampler replacement.
            If not specified this will be toggled automatically when ddp is used.

        benchmark: If true enables cudnn.benchmark.

        deterministic: If true enables cudnn.deterministic

        terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
            end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

        auto_scale_batch_size: If set to True, will `initially` run a batch size
            finder trying to find the largest batch size that fits into memory.
            The result will be stored in self.batch_size in the LightningModule.
            Additionally, can be set to either `power` that estimates the batch size through
            a power search or `binsearch` that estimates the batch size through a binary search.

        prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
            Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

    Raises:
        MisconfigurationException: if ``track_grad_norm`` is neither a number nor ``'inf'``.
        AssertionError: if ``tpu_cores`` is not 1, 8, ``None`` or a one-element sequence.
    """
    super().__init__()

    self.deterministic = deterministic
    torch.backends.cudnn.deterministic = self.deterministic
    if self.deterministic:
        # fixing non-deterministic part of horovod
        # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
        os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

    # Init callbacks
    self.prepare_data_per_node = prepare_data_per_node
    self.callbacks = callbacks or []
    self.on_init_start()

    # benchmarking
    self.benchmark = benchmark
    torch.backends.cudnn.benchmark = self.benchmark

    # Transfer params
    self.num_nodes = num_nodes
    self.log_gpu_memory = log_gpu_memory

    self.gradient_clip_val = gradient_clip_val
    self.check_val_every_n_epoch = check_val_every_n_epoch

    if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
        raise MisconfigurationException(
            "track_grad_norm can be an int, a float or 'inf' (infinity norm).")
    # float('inf') handles the string form
    self.track_grad_norm = float(track_grad_norm)

    self.on_gpu = bool(gpus and torch.cuda.is_available())

    # tpu config
    if num_tpu_cores is not None:
        rank_zero_warn("Argument `num_tpu_cores` is now set by `tpu_cores` since v0.7.6"
                       " and this argument will be removed in v0.9.0", DeprecationWarning)

    if tpu_cores is None:
        tpu_cores = num_tpu_cores
    # The check below accepts list, tuple or set; normalise to a list so the single-core
    # selection (``tpu_id``) handles every accepted form, not only lists.
    if isinstance(tpu_cores, (tuple, set)):
        tpu_cores = list(tpu_cores)
    self.on_tpu = tpu_cores is not None
    self.tpu_cores = tpu_cores
    # NOTE(review): ``assert`` is skipped under ``python -O``; kept as-is so callers still
    # see an AssertionError for invalid values.
    assert self.tpu_cores in (1, 8, None) or (
        isinstance(self.tpu_cores, (list, tuple, set)) and len(self.tpu_cores) == 1
    ), '`tpu_cores` can only be 1, 8 or [<1-8>]'

    # a one-element list selects a single, specific TPU core
    self.tpu_id = tpu_cores[0] if isinstance(tpu_cores, list) else None

    if num_processes != 1 and distributed_backend != "ddp_cpu":
        rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
    self.num_processes = num_processes

    self.weights_summary = weights_summary

    self.max_epochs = max_epochs
    self.min_epochs = min_epochs
    self.max_steps = max_steps
    self.min_steps = min_steps

    self.num_sanity_val_steps = num_sanity_val_steps
    # Backward compatibility, TODO: remove in v0.9.0
    if print_nan_grads:
        rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
                       " NaN grads will be printed automatically when detected.", DeprecationWarning)

    self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

    self.auto_lr_find = auto_lr_find
    self.auto_scale_batch_size = auto_scale_batch_size
    self._is_data_prepared = False
    self.replace_sampler_ddp = replace_sampler_ddp

    self.truncated_bptt_steps = truncated_bptt_steps
    self.resume_from_checkpoint = resume_from_checkpoint
    self.terminate_on_nan = terminate_on_nan
    self.shown_warnings = set()

    self.fast_dev_run = fast_dev_run
    if self.fast_dev_run:
        self.num_sanity_val_steps = 0
        self.max_epochs = 1
        rank_zero_info('Running in fast_dev_run mode: will run a full train,'
                       ' val and test loop using a single batch')

    # set default save path if user didn't provide one
    self.default_root_dir = default_root_dir

    if self.default_root_dir is None:
        self.default_root_dir = os.getcwd()

    # training bookeeping
    self.total_batch_idx = 0
    self.running_loss = TensorRunningAccum(window_length=20)
    self.batch_idx = 0
    self.progress_bar_metrics = {}
    self.callback_metrics = {}
    self.num_val_batches = [0]
    self.num_training_batches = 0
    self.num_test_batches = [0]
    self.train_dataloader = None
    self.test_dataloaders = None
    self.val_dataloaders = None

    # training state
    self.model = None
    self.testing = False
    self.disable_validation = False
    self.lr_schedulers = []
    self.optimizers = None
    self.optimizer_frequencies = []
    self.global_step = 0
    self.current_epoch = 0
    self.interrupted = False

    # configure logger
    self.configure_logger(logger)

    # configure profiler
    if profiler is True:
        profiler = SimpleProfiler()
    self.profiler = profiler or PassThroughProfiler()

    # configure early stop callback
    # creates a default one if none passed in
    self.configure_early_stopping(early_stop_callback)

    # configure checkpoint callback
    self.checkpoint_callback = checkpoint_callback
    self.weights_save_path = weights_save_path

    # accumulated grads
    self.accumulate_grad_batches = accumulate_grad_batches
    self.configure_accumulated_gradients(accumulate_grad_batches)

    # for gpus allow int, string and gpu list
    if auto_select_gpus and isinstance(gpus, int):
        self.gpus = pick_multiple_gpus(gpus)
    else:
        self.gpus = gpus

    self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
    self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
    self.root_device = torch.device("cpu")

    # tpu state flags
    self.use_tpu = False
    self.tpu_local_core_rank = None
    self.tpu_global_core_rank = None

    # distributed backend choice
    self.distributed_backend = distributed_backend
    self.set_distributed_mode(distributed_backend)

    # override dist backend when using tpus
    if self.on_tpu:
        self.init_tpu()

    # init flags for SLURM+ddp to work
    self.world_size = 1
    self.interactive_ddp_procs = []
    self.configure_slurm_ddp(self.num_nodes)
    self.node_rank = self.determine_ddp_node_rank()
    self.local_rank = self.determine_local_rank()
    self.global_rank = 0

    # nvidia setup
    self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)

    # backward compatibility
    if show_progress_bar is not None:
        self.show_progress_bar = show_progress_bar

    self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

    # logging
    self.log_save_interval = log_save_interval
    self.val_check_interval = val_check_interval

    self.row_log_interval = row_log_interval

    # how much of the data to use
    # TODO: remove in 1.0.0
    if overfit_pct > 0:
        overfit_batches = overfit_pct

    # convert floats to ints
    overfit_batches = int(overfit_batches) if overfit_batches > 1.0 else overfit_batches
    self.overfit_batches = overfit_batches

    # TODO: remove in 1.0.0
    if val_percent_check < 1.0:
        limit_val_batches = val_percent_check

    if test_percent_check < 1.0:
        limit_test_batches = test_percent_check

    # values above 1.0 mean "number of batches", fractions mean "percent of the dataset"
    limit_test_batches = int(limit_test_batches) if limit_test_batches > 1.0 else limit_test_batches
    limit_val_batches = int(limit_val_batches) if limit_val_batches > 1.0 else limit_val_batches

    # TODO: convert train_percent_check to limit_train_batches
    self.determine_data_use_amount(train_percent_check, limit_val_batches,
                                   limit_test_batches, overfit_batches)

    # AMP init
    # These are the only lines needed after v0.8.0
    # we wrap the user's forward with autocast and give it back at the end of fit
    self.autocast_original_forward = None
    self.use_native_amp = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
    self.precision = precision
    self.scaler = None

    self.amp_level = amp_level
    self.init_amp(use_amp)

    self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

    # Callback system
    self.on_init_end()
|
|
|
|
@property
def is_global_zero(self):
    """``True`` only for the process whose ``global_rank`` is 0."""
    rank = self.global_rank
    return rank == 0
|
|
|
|
@property
def slurm_job_id(self) -> Optional[int]:
    """SLURM job id of the current run, or ``None``.

    ``None`` is returned when ``SLURM_JOB_ID`` is unset or not an integer, when
    ``SLURM_JOB_NAME`` is unset, or when the job is an interactive ``bash``
    session (so interactive runs do not reuse the job id in their logs).
    """
    try:
        job_id = int(os.environ['SLURM_JOB_ID'])
        job_name = os.environ['SLURM_JOB_NAME']
    except (KeyError, ValueError):
        # missing variable or a non-integer job id
        return None

    # in interactive mode, don't make logs use the same job id
    if job_name == 'bash':
        return None
    return job_id
|
|
|
|
@classmethod
def default_attributes(cls):
    """Map every constructor argument name of ``cls`` to its default value.

    Uses ``cls`` (not a hard-coded ``Trainer``) so subclasses report their own
    signature, consistent with ``get_init_arguments_and_types``. Arguments
    without a default map to ``inspect.Parameter.empty``.
    """
    init_signature = inspect.signature(cls)
    return {name: param.default for name, param in init_signature.parameters.items()}
|
|
|
|
@classmethod
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
    r"""Scans the Trainer signature and returns argument names, types and default values.

    Returns:
        List with tuples of 3 values:
        (argument name, set with argument types, argument default value).

    Examples:
        >>> args = Trainer.get_init_arguments_and_types()
        >>> import pprint
        >>> pprint.pprint(sorted(args))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
        [('accumulate_grad_batches',
          (<class 'int'>, typing.Dict[int, int], typing.List[list]),
          1),
         ...
         ('callbacks',
          (typing.List[pytorch_lightning.callbacks.base.Callback],
           <class 'NoneType'>),
           None),
         ('check_val_every_n_epoch', (<class 'int'>,), 1),
         ...
         ('max_epochs', (<class 'int'>,), 1000),
         ...
         ('precision', (<class 'int'>,), 32),
         ('prepare_data_per_node', (<class 'bool'>,), True),
         ('print_nan_grads', (<class 'bool'>,), False),
         ('process_position', (<class 'int'>,), 0),
         ('profiler',
          (<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
           <class 'bool'>,
           <class 'NoneType'>),
          None),
          ...
    """
    def _annotation_members(annotation):
        # ``Union``/``Optional`` annotations expose their member types via ``__args__``;
        # plain classes (and ``inspect.Parameter.empty``) do not and are wrapped as-is.
        try:
            return tuple(annotation.__args__)
        except AttributeError:
            return (annotation,)

    return [
        (name, _annotation_members(param.annotation), param.default)
        for name, param in inspect.signature(cls).parameters.items()
    ]
|
|
|
|
@classmethod
def get_deprecated_arg_names(cls) -> List:
    """Returns a list with deprecated Trainer arguments.

    The names are gathered, in definition order, from every attribute defined
    directly on ``cls`` whose name starts with ``DEPRECATED`` and whose value is
    a tuple or list.
    """
    return [
        arg_name
        for attr_name, attr_value in cls.__dict__.items()
        if attr_name.startswith('DEPRECATED') and isinstance(attr_value, (tuple, list))
        for arg_name in attr_value
    ]
|
|
|
|
@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
    r"""Extends existing argparse by default `Trainer` attributes.

    Args:
        parent_parser:
            The custom cli arguments parser, which will be extended by
            the Trainer default arguments.

    Only arguments of the allowed types (str, float, int, bool) will
    extend the `parent_parser`. Arguments accepting ``bool`` may be passed as a
    bare flag (``--flag``), which yields ``True``.

    Examples:
        >>> import argparse
        >>> import pprint
        >>> parser = argparse.ArgumentParser()
        >>> parser = Trainer.add_argparse_args(parser)
        >>> args = parser.parse_args([])
        >>> pprint.pprint(vars(args))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
        {...
         'check_val_every_n_epoch': 1,
         'checkpoint_callback': True,
         'default_root_dir': None,
         'deterministic': False,
         'distributed_backend': None,
         'early_stop_callback': False,
         ...
         'logger': True,
         'max_epochs': 1000,
         'max_steps': None,
         'min_epochs': 1,
         'min_steps': None,
         ...
         'profiler': None,
         'progress_bar_refresh_rate': 1,
         ...}

    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False, )

    blacklist = ['kwargs']
    depr_arg_names = cls.get_deprecated_arg_names() + blacklist

    allowed_types = (str, float, int, bool)

    def _str_to_bool_type(x):
        # argparse ``type`` for bool-only arguments (parses strings such as "true"/"false")
        return bool(parsing.str_to_bool(x))

    # TODO: get "help" from docstring :)
    for arg, arg_types, arg_default in cls.get_init_arguments_and_types():
        if arg in depr_arg_names:
            continue

        arg_types = [at for at in allowed_types if at in arg_types]
        if not arg_types:
            # skip argument with not supported type
            continue

        arg_kwargs = {}
        if bool in arg_types:
            # the value is optional: a bare ``--flag`` stores ``const`` (True);
            # argparse does not pass ``const`` through ``type``
            arg_kwargs.update(nargs="?", const=True)
            if len(arg_types) == 1:
                # the only arg type is bool: redefine the type for ArgParser
                use_type = _str_to_bool_type
            else:
                # filter out the bool as we need to use more general
                use_type = [at for at in arg_types if at is not bool][0]
        else:
            use_type = arg_types[0]

        if arg == 'gpus':
            # accept both "2" (int) and "0,1" (str). The signature default (None) is kept:
            # a function object as default would leak into ``args.gpus`` when the flag is
            # absent, because argparse only applies ``type`` to string defaults.
            use_type = Trainer._allowed_type

        parser.add_argument(
            f'--{arg}',
            dest=arg,
            default=arg_default,
            type=use_type,
            help='autogenerated by pl.Trainer',
            **arg_kwargs,
        )

    return parser
|
|
|
|
@staticmethod
def _allowed_type(x) -> Union[int, str]:
    """argparse ``type`` for ``--gpus``.

    A comma-separated value (e.g. ``"0,1"``) is kept as a string; anything else
    is converted with ``int`` (which raises ``ValueError`` for non-integers).
    Declared ``@staticmethod`` so it also works when reached through an instance.
    """
    if ',' in x:
        return str(x)
    return int(x)
|
|
|
|
@staticmethod
def _arg_default(x) -> Union[int, str]:
    """Convert a ``--gpus`` string like ``_allowed_type`` does.

    A comma-separated value is kept as a string; anything else goes through
    ``int``. Declared ``@staticmethod`` so it also works when reached through an
    instance.
    """
    if ',' in x:
        return str(x)
    return int(x)
|
|
|
|
@staticmethod
def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
    """Parse CLI arguments, required for custom bool types.

    Trainer arguments that accept ``bool`` are declared with ``nargs="?"``, so a
    bare ``--flag`` can leave the value as ``None``. Such ``None`` values are
    turned into ``True``, but only for Trainer arguments that accept ``bool`` and
    have a ``bool`` default. Genuine ``None`` defaults (``default_root_dir``,
    ``max_steps``, ``distributed_backend``, ...) and user-defined arguments are
    left untouched.

    Args:
        arg_parser: a parser (then parsed from ``sys.argv``) or an already parsed namespace.

    Returns:
        A new namespace with the bool flags resolved.
    """
    args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser

    bool_flags = {
        name
        for name, arg_types, arg_default in Trainer.get_init_arguments_and_types()
        if bool in arg_types and isinstance(arg_default, bool)
    }
    args = {k: True if (v is None and k in bool_flags) else v for k, v in vars(args).items()}
    return Namespace(**args)
|
|
|
|
@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
    """
    Create an instance from CLI arguments.

    Args:
        args: The parser or namespace to take arguments from. Only known arguments will be
            parsed and passed to the :class:`Trainer`.
        **kwargs: Additional keyword arguments that may override ones in the parser or namespace.
            These must be valid Trainer arguments.

    Example:
        >>> parser = ArgumentParser(add_help=False)
        >>> parser = Trainer.add_argparse_args(parser)
        >>> parser.add_argument('--my_custom_arg', default='something')  # doctest: +SKIP
        >>> args = Trainer.parse_argparser(parser.parse_args(""))
        >>> trainer = Trainer.from_argparse_args(args, logger=False)
    """
    if isinstance(args, ArgumentParser):
        args = cls.parse_argparser(args)
    parsed = vars(args)

    # keep only names the constructor understands; anything else is user specific
    accepted = inspect.signature(cls.__init__).parameters
    trainer_kwargs = {name: parsed[name] for name in accepted if name in parsed}

    # explicit keyword arguments win over parsed values
    trainer_kwargs.update(kwargs)
    return cls(**trainer_kwargs)
|
|
|
|
@property
def num_gpus(self) -> int:
    """Number of entries in ``data_parallel_device_ids`` (0 when it is ``None``)."""
    device_ids = self.data_parallel_device_ids
    return 0 if device_ids is None else len(device_ids)
|
|
|
|
@property
def data_parallel(self) -> bool:
    """Whether any of the DP, DDP or DDP2 modes is active.

    Evaluated left to right with short-circuiting; returns the first truthy flag
    (or the last flag when none is set).
    """
    return (
        self.use_dp
        or self.use_ddp
        or self.use_ddp2
    )
|
|
|
|
@property
def progress_bar_callback(self):
    """Progress bar callback created by ``configure_progress_bar`` during ``__init__`` (read-only)."""
    callback = self._progress_bar_callback
    return callback
|
|
|
|
@property
def progress_bar_dict(self) -> dict:
    """Read-only for progress bar metrics.

    Merges the LightningModule's ``get_progress_bar_dict()`` with
    ``progress_bar_metrics``. When both define the same key, the value from
    ``progress_bar_metrics`` is used (instead of raising ``TypeError``).
    When a data-parallel wrapper is active, the wrapped module (``.module``) is queried.
    """
    ref_model = self.model.module if self.data_parallel else self.model
    return {**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics}
|
|
|
|
# -----------------------------
|
|
# MODEL TRAINING
|
|
# -----------------------------
|
|
def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
):
    r"""
    Runs the full optimization routine.

    Args:
        model: Model to fit.

        train_dataloader: A Pytorch DataLoader with training samples. When given, it
            replaces the model's ``train_dataloader`` method.

        val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying
            validation samples. When given, they replace the model's ``val_dataloader`` method.

    Returns:
        ``1`` once training has finished.

    Raises:
        MisconfigurationException: if AMP is requested while training on CPU, or if
            ``ddp2`` cannot determine the local rank (neither SLURM nor the
            ``WORLD_SIZE``/``LOCAL_RANK`` environment variables are available).

    Example::

        # Option 1,
        # Define the train_dataloader() and val_dataloader() fxs
        # in the lightningModule
        # RECOMMENDED FOR MOST RESEARCH AND APPLICATIONS TO MAINTAIN READABILITY
        trainer = Trainer()
        model = LightningModule()
        trainer.fit(model)

        # Option 2
        # in production cases we might want to pass different datasets to the same model
        # Recommended for PRODUCTION SYSTEMS
        train, val = DataLoader(...), DataLoader(...)
        trainer = Trainer()
        model = LightningModule()
        trainer.fit(model, train_dataloader=train, val_dataloaders=val)

        # Option 1 & 2 can be mixed, for example the training set can be
        # defined as part of the model, and validation can then be fed to .fit()

    """
    # bind logger and other properties
    model.logger = self.logger
    self.copy_trainer_model_properties(model)

    # clean hparams
    if hasattr(model, 'hparams'):
        parsing.clean_namespace(model.hparams)

    # set up the passed in dataloaders (if needed)
    self.__attach_dataloaders(model, train_dataloader, val_dataloaders)

    # check that model is configured correctly
    self.check_model_configuration(model)

    # callbacks
    self.on_fit_start()
    if self.is_function_implemented('on_fit_start'):
        model.on_fit_start()

    # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
    # or in the case where each node needs to do its own manipulation in which case just local_rank=0
    if self.can_prepare_data():
        model.prepare_data()
        self._is_data_prepared = True

    # Run auto batch size scaling
    if self.auto_scale_batch_size:
        if isinstance(self.auto_scale_batch_size, bool):
            self.auto_scale_batch_size = 'power'
        self.scale_batch_size(model, mode=self.auto_scale_batch_size)
        model.logger = self.logger  # reset logger binding

    # Run learning rate finder:
    if self.auto_lr_find:
        self._run_lr_finder_internally(model)
        model.logger = self.logger  # reset logger binding

    # route to appropriate start method
    # when using multi-node or DDP within a node start each module in a separate process
    if self.use_ddp2:
        if self.is_slurm_managing_tasks:
            task = int(os.environ['SLURM_LOCALID'])

        # torchelastic or general non_slurm ddp2
        elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
            task = int(os.environ['LOCAL_RANK'])

        else:
            # previously this fell through to an UnboundLocalError on ``task``
            raise MisconfigurationException(
                'ddp2 could not determine the local rank: run under SLURM or set the WORLD_SIZE,'
                ' LOCAL_RANK and GROUP_RANK/NODE_RANK environment variables.')

        self.ddp_train(task, model)

    elif self.use_ddp:
        if self.is_slurm_managing_tasks:
            task = int(os.environ['SLURM_LOCALID'])
            self.ddp_train(task, model)

        # torchelastic or general non_slurm ddp
        elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
            task = int(os.environ['LOCAL_RANK'])
            self.ddp_train(task, model)

        # the backend is spelled "ddp_cpu" (see the num_processes check in __init__);
        # the legacy "cpu_ddp" spelling is still accepted
        elif self.distributed_backend in ('ddp_cpu', 'cpu_ddp'):
            self.__set_random_port()
            self.model = model
            mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))

        elif self.distributed_backend == 'ddp_spawn':
            model.share_memory()

            # spin up peers
            mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model, ))

        elif self.distributed_backend == 'ddp':
            self.spawn_ddp_children(model)

    # 1 gpu or dp option triggers training using DP module
    # easier to avoid NCCL issues
    elif self.use_dp:
        self.dp_train(model)

    elif self.use_horovod:
        self.horovod_train(model)

    elif self.single_gpu:
        self.single_gpu_train(model)

    elif self.use_tpu:  # pragma: no-cover
        rank_zero_info(f'training on {self.tpu_cores} TPU cores')

        # COLAB_GPU is an env var available by default in Colab environments.
        start_method = 'fork' if self.on_colab_kaggle else 'spawn'

        # track for predict
        self.model = model

        # train
        if self.tpu_id is not None:
            self.tpu_train(self.tpu_id, model)
        else:
            xmp.spawn(self.tpu_train, args=(model,), nprocs=self.tpu_cores, start_method=start_method)

        # load weights if not interrupted
        self.load_spawn_weights(model)
        self.model = model

    # ON CPU
    else:
        # run through amp wrapper
        if self.use_amp:
            raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        self.run_pretrain_routine(model)

    # callbacks
    self.on_fit_end()

    # model hooks
    if self.is_function_implemented('on_fit_end'):
        model.on_fit_end()

    # return 1 when finished
    # used for testing or when we need to know that training succeeded
    return 1
|
|
|
|
def can_prepare_data(self):
    """Tell whether this process is the one allowed to prepare data.

    Returns:
        With ``prepare_data_per_node`` set: ``True`` for local rank 0 on every node.
        Otherwise: ``True`` only for local rank 0 on node rank 0.
    """
    is_local_leader = self.local_rank == 0
    if not self.prepare_data_per_node:
        # a single process across the whole job qualifies
        return self.node_rank == 0 and is_local_leader
    return is_local_leader
|
|
|
|
def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
    """Override the model's dataloader hooks with loaders handed directly to the trainer.

    Each argument that is not ``None`` replaces the matching ``*_dataloader``
    attribute on ``model`` with a ``_PatchDataLoader`` that returns it. ``None``
    arguments leave the model's own hook untouched.

    Args:
        model: The model whose dataloader hooks are patched in place.
        train_dataloader: Loader to serve from ``model.train_dataloader()``.
        val_dataloaders: Loader(s) to serve from ``model.val_dataloader()``.
        test_dataloaders: Loader(s) to serve from ``model.test_dataloader()``.
    """
    overrides = (
        ('train_dataloader', train_dataloader),
        ('val_dataloader', val_dataloaders),
        ('test_dataloader', test_dataloaders),
    )
    for hook_name, loaders in overrides:
        if loaders is None:
            continue
        setattr(model, hook_name, _PatchDataLoader(loaders))
|
|
|
|
def run_pretrain_routine(self, model: LightningModule) -> None:
    """Sanity check a few things before starting actual training.

    Attaches the trainer to the model, sets up AMP scaling, logging, process
    synchronisation, the model summary, checkpointing and weight restoration.
    Then it does one of two things:

    * when ``self.testing`` is set, it runs the test evaluation and returns;
    * otherwise it runs an optional sanity-validation pass and enters the core
      training loop via ``self.train()``.

    Args:
        model: The model to run sanity test on.

    Raises:
        MisconfigurationException: on the global-zero process (and only when not
            testing), if ``weights_summary`` is neither ``None`` nor one of
            ``ModelSummary.MODES``.
    """
    # When data-parallel wrapped, the inner ``model.module`` is the object that
    # receives trainer properties and hooks below.
    ref_model = model
    if self.data_parallel:
        ref_model = model.module

    # give model convenience properties
    ref_model.trainer = self

    # set local properties on the model
    self.copy_trainer_model_properties(ref_model)

    # init amp. Must be done here instead of __init__ to allow ddp to work
    if self.use_native_amp and self.precision == 16:
        self.scaler = torch.cuda.amp.GradScaler()

    # log hyper-parameters
    if self.logger is not None:
        # save exp to get started
        self.logger.log_hyperparams(ref_model.hparams)

        self.logger.save()

    # DDP: every process blocks here until all have reached this point
    if self.use_ddp or self.use_ddp2:
        torch_distrib.barrier()

    # wait for all models to restore weights
    if self.on_tpu and XLA_AVAILABLE:
        # wait for all processes to catch up
        torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")

    elif self.use_horovod:
        # wait for all processes to catch up
        hvd.join()

    # register auto-resubmit when on SLURM
    self.register_slurm_signal_handlers()

    # print model summary
    if self.is_global_zero and self.weights_summary is not None and not self.testing:
        if self.weights_summary in ModelSummary.MODES:
            ref_model.summarize(mode=self.weights_summary)
        else:
            raise MisconfigurationException(
                "weights_summary can be None, " + ", ".join(ModelSummary.MODES)
            )

    # track model now.
    # if cluster resets state, the model will update with the saved weights
    self.model = model

    # set up checkpoint callback
    self.configure_checkpoint_callback()

    # restore training and model before hpc call
    self.restore_weights(model)

    # when testing requested only run test and return
    if self.testing:
        # only load test dataloader for testing
        # self.reset_test_dataloader(ref_model)
        self.run_evaluation(test_mode=True)
        return

    # check if we should run validation during training
    # Validation is disabled when ``validation_step`` is not overridden or
    # ``limit_val_batches`` is not positive -- unless ``fast_dev_run`` is on,
    # in which case it is never disabled.
    self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
        and not self.fast_dev_run

    # run tiny validation (if validation defined)
    # to make sure program won't crash during val
    if not self.disable_validation and self.num_sanity_val_steps > 0:
        self.reset_val_dataloader(ref_model)

        # hook and callback
        ref_model.on_sanity_check_start()
        self.on_sanity_check_start()

        # cap every validation loader at ``num_sanity_val_steps`` batches
        num_loaders = len(self.val_dataloaders)
        max_batches = [self.num_sanity_val_steps] * num_loaders
        # NOTE(review): the trailing positional ``False`` is presumably
        # ``test_mode=False`` -- confirm against ``_evaluate``'s signature.
        eval_results = self._evaluate(model,
                                      self.val_dataloaders,
                                      max_batches,
                                      False)
        _, _, _, callback_metrics, _ = self.process_output(eval_results)

        self.on_sanity_check_end()

        # verify that early stop has conditioned on a metric that exists
        # (``callback_metrics`` is only bound inside this sanity-check block)
        if self.enable_early_stop:
            self.early_stop_callback._validate_condition_metric(callback_metrics)

    # clear cache before training
    if self.on_gpu and self.root_gpu is not None:
        # use context because of:
        # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
        with torch.cuda.device(f'cuda:{self.root_gpu}'):
            torch.cuda.empty_cache()

    # CORE TRAINING LOOP
    self.train()
|
|
|
|
def test(
        self,
        model: Optional[LightningModule] = None,
        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best'
):
    r"""

    Separates from fit to make sure you never run on your test set until you want to.

    Args:
        model: The model to test. When given, ``ckpt_path`` is ignored.

        test_dataloaders: Either a single
            Pytorch Dataloader or a list of them, specifying test samples.

        ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
            If ``None``, use the weights from the last epoch to test. Default to ``best``.

    Raises:
        MisconfigurationException: If ``model`` is ``None`` and ``ckpt_path`` is ``'best'``, and either
            no checkpoint callback with ``save_top_k > 0`` is configured or it has not recorded a
            best checkpoint path yet.

    Example::

        # Option 1
        # run test with the best checkpoint from ``ModelCheckpoint`` after fitting.
        test = DataLoader(...)
        trainer = Trainer()
        model = LightningModule()

        trainer.fit(model)
        trainer.test(test_dataloaders=test)

        # Option 2
        # run test with the specified checkpoint after fitting
        test = DataLoader(...)
        trainer = Trainer()
        model = LightningModule()

        trainer.fit(model)
        trainer.test(test_dataloaders=test, ckpt_path='path/to/checkpoint.ckpt')

        # Option 3
        # run test with the weights from the end of training after fitting
        test = DataLoader(...)
        trainer = Trainer()
        model = LightningModule()

        trainer.fit(model)
        trainer.test(test_dataloaders=test, ckpt_path=None)

        # Option 4
        # run test from a loaded model. ``ckpt_path`` is ignored in this case.
        test = DataLoader(...)
        model = LightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')
        trainer = Trainer()
        trainer.test(model, test_dataloaders=test)
    """
    if model is None and ckpt_path == 'best':
        # ``checkpoint_callback`` may not be a ModelCheckpoint here (presumably None/False when
        # checkpointing is disabled -- confirm in ``configure_checkpoint_callback``). Treat a
        # missing ``save_top_k`` like "not saving a best model" instead of hitting AttributeError.
        if getattr(self.checkpoint_callback, 'save_top_k', 0) <= 0:
            raise MisconfigurationException(
                'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.')

    # if model is not given (None), ckpt_path is given,
    # load the given checkpoint for testing
    if model is None and ckpt_path is not None:
        # ckpt_path is 'best' so load the best model
        if ckpt_path == 'best':
            ckpt_path = self.checkpoint_callback.best_model_path
            # an empty path presumably means no checkpoint was saved yet (e.g. `test` before `fit`);
            # fail with a clear message instead of an opaque error from the checkpoint loader
            if not ckpt_path:
                raise MisconfigurationException(
                    'ckpt_path is "best", but no best checkpoint has been saved yet.'
                    ' Call `fit` first or pass an explicit `ckpt_path`.')
        model = self.get_model().load_from_checkpoint(ckpt_path)

    self.testing = True
    # always leave test mode, even if evaluation raises, so later fit/test calls are not affected
    try:
        if test_dataloaders is not None:
            target_model = model if model is not None else self.model
            self.__attach_dataloaders(target_model, test_dataloaders=test_dataloaders)

        if model is not None:
            self.model = model
            self.fit(model)

        # on tpu, .spawn means we don't have a trained model
        # TODO: remove TPU spawn
        elif self.use_tpu:  # pragma: no-cover
            # attempt to load weights from a spawn
            path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
            test_model = self.model
            if os.path.exists(path):
                test_model = self.load_spawn_weights(self.model)

            self.fit(test_model)
        else:
            self.run_evaluation(test_mode=True)
    finally:
        self.testing = False
|
|
|
|
def check_model_configuration(self, model: LightningModule):
    r"""
    Checks that the model is configured correctly before training or testing is started.

    Args:
        model: The model to check the configuration.

    Raises:
        MisconfigurationException: When not testing: if ``training_step()``, ``train_dataloader()``
            or ``configure_optimizers()`` is missing, or if only one of ``val_dataloader()`` /
            ``validation_step()`` is defined. In any mode: if ``test_dataloader()`` is defined
            without ``test_step()``. While testing: if ``test_step()`` is defined without
            ``test_dataloader()``.
    """
    # Check training_step, train_dataloader, configure_optimizer methods
    if not self.testing:
        # checked in this order so the first missing hook is the one reported
        for hook_name in ('training_step', 'train_dataloader', 'configure_optimizers'):
            if not self.is_overridden(hook_name, model):
                raise MisconfigurationException(
                    f'No `{hook_name}()` method defined. Lightning `Trainer` expects as minimum a'
                    ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')

        # Check val_dataloader, validation_step and validation_epoch_end
        if self.is_overridden('val_dataloader', model):
            if not self.is_overridden('validation_step', model):
                raise MisconfigurationException('You have passed in a `val_dataloader()`'
                                                ' but have not defined `validation_step()`.')
            if not self.is_overridden('validation_epoch_end', model):
                rank_zero_warn(
                    'You have defined a `val_dataloader()` and have defined a `validation_step()`,'
                    ' you may also want to define `validation_epoch_end()` for accumulating stats.',
                    RuntimeWarning
                )
        elif self.is_overridden('validation_step', model):
            raise MisconfigurationException('You have defined `validation_step()`,'
                                            ' but have not passed in a `val_dataloader()`.')

    # Check test_dataloader, test_step and test_epoch_end
    if self.is_overridden('test_dataloader', model):
        if not self.is_overridden('test_step', model):
            raise MisconfigurationException('You have passed in a `test_dataloader()`'
                                            ' but have not defined `test_step()`.')
        if not self.is_overridden('test_epoch_end', model):
            rank_zero_warn(
                'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to'
                ' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
            )
    elif self.testing and self.is_overridden('test_step', model):
        # a missing test loader only matters when a test run was actually requested
        raise MisconfigurationException('You have defined `test_step()` but did not'
                                        ' implement `test_dataloader` nor passed in `.test(test_dataloaders=...)`.')
|
|
|
|
|
|
class _PatchDataLoader:
    r"""
    Pickle-compatible callable that stands in for a ``model.*_dataloader()`` hook.

    Instances replace e.g. ``model.train_dataloader`` when loaders are passed
    straight to ``trainer.fit()``; calling the instance returns the stored
    loader(s) unchanged.

    Args:
        dataloader: Dataloader object (or list of them) to return when called.
    """

    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
        self.dataloader = dataloader
        # Code objects cannot be pickled, so keep the string form of ``__call__``'s
        # code instead; it serves as a marker that a dataloader method has been
        # overwritten by this patch.
        call_code = type(self).__call__.__code__
        self.patch_loader_code = str(call_code)

    def __call__(self) -> Union[List[DataLoader], DataLoader]:
        return self.dataloader
|