lightning/pytorch_lightning/trainer/trainer.py

"""
The trainer handles all the logic for running a val loop, training loop, distributing, etc.
"""
import os
import sys
import warnings
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import tqdm
from torch.optim.optimizer import Optimizer
from pytorch_lightning.trainer.amp_mixin import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config_mixin import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.data_loading_mixin import TrainerDataLoadingMixin
from pytorch_lightning.trainer.ddp_mixin import TrainerDDPMixin
from pytorch_lightning.trainer.dp_mixin import TrainerDPMixin
from pytorch_lightning.trainer.dp_mixin import (
parse_gpu_ids,
determine_root_gpu_device
)
from pytorch_lightning.trainer.evaluation_loop_mixin import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging_mixin import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks_mixin import TrainerModelHooksMixin
from pytorch_lightning.trainer.train_loop_mixin import TrainerTrainLoopMixin
from pytorch_lightning.trainer.trainer_io import TrainerIOMixin
from pytorch_lightning.trainer.training_tricks_mixin import TrainerTrainingTricksMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
try:
from apex import amp
APEX_AVAILABLE = True
except ImportError:
APEX_AVAILABLE = False
class Trainer(TrainerIOMixin,
TrainerDDPMixin,
TrainerDPMixin,
TrainerDataLoadingMixin,
TrainerAMPMixin,
TrainerEvaluationLoopMixin,
TrainerTrainLoopMixin,
TrainerLoggingMixin,
TrainerTrainingTricksMixin,
TrainerCallbackConfigMixin,
TrainerModelHooksMixin):
def __init__(
self,
logger=True,
checkpoint_callback=True,
early_stop_callback=True,
default_save_path=None,
gradient_clip_val=0,
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
process_position=0,
nb_gpu_nodes=None, # backward compatible, todo: remove in v0.8.0
num_nodes=1,
gpus=None,
log_gpu_memory=None,
show_progress_bar=True,
overfit_pct=0.0,
track_grad_norm=-1,
check_val_every_n_epoch=1,
fast_dev_run=False,
accumulate_grad_batches=1,
max_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
max_num_epochs=1000,
min_num_epochs=1,
train_percent_check=1.0,
val_percent_check=1.0,
test_percent_check=1.0,
val_check_interval=1.0,
log_save_interval=100,
row_log_interval=10,
add_row_log_interval=None, # backward compatible, todo: remove in v0.8.0
distributed_backend=None,
use_amp=False,
print_nan_grads=False,
weights_summary='full',
weights_save_path=None,
amp_level='O1',
nb_sanity_val_steps=None, # backward compatible, todo: remove in v0.8.0
num_sanity_val_steps=5,
truncated_bptt_steps=None,
resume_from_checkpoint=None,
):
"""
:param logger: Logger for experiment tracking
:param checkpoint_callback: Callback for checkpointing
:param early_stop_callback: Callback for early stopping
:param str default_save_path: Default path for logs+weights if no logger/ckpt_callback passed
        :param float gradient_clip_val: 0 means don't clip.
        :param float gradient_clip: 0 means don't clip. Deprecated.
        :param int process_position: position of this trainer's tqdm bar (useful when running several trainers)
        :param int num_nodes: number of GPU nodes
        :param list|str|int gpus: which GPUs to train on. An int gives the number of GPUs (ie: 2 gpus),
            a list or comma-separated string selects specific devices ([0, 1] or '0,1'),
            and -1 / '-1' uses all available gpus
        :param str log_gpu_memory: None, 'min_max', 'all'
        :param bool show_progress_bar: If True, shows the tqdm progress bar
        :param float overfit_pct: use only this fraction of the train/val/test data (handy for debugging)
        :param int track_grad_norm: -1 means no tracking; otherwise tracks the gradient norm of that order
        :param int check_val_every_n_epoch: check val every n train epochs
        :param bool fast_dev_run: runs a single batch through the full train and val loops to find bugs
        :param int accumulate_grad_batches: accumulate gradients over k batches before stepping the optimizer
        :param int max_num_epochs: stop training once this number of epochs is reached
        :param int min_num_epochs: force training for at least this many epochs
        :param float train_percent_check: how much of the training set to check
        :param float val_percent_check: how much of the validation set to check
        :param float test_percent_check: how much of the test set to check
        :param float|int val_check_interval: if float, fraction of the training epoch; if int, check every n batches
:param int log_save_interval: Writes logs to disk this often
:param int row_log_interval: How often to add logging rows
:param int add_row_log_interval: How often to add logging rows. Deprecated.
:param str distributed_backend: Options: 'dp', 'ddp', 'ddp2'.
        :param bool use_amp: If True, uses NVIDIA apex for 16-bit precision
        :param bool print_nan_grads: prints gradients that contain NaN values
        :param str weights_summary: Options: 'full', 'top', None to not print.
        :param str weights_save_path: Where to save weights if on cluster
        :param str amp_level: apex AMP optimization level (e.g. 'O1', 'O2'); check the NVIDIA docs
        :param int num_sanity_val_steps: How many val steps before a full train loop.
        :param int truncated_bptt_steps: truncated back-propagation through time: splits each batch into
            chunks of this length and runs a backward pass per chunk
        .. warning:: The following arguments are deprecated and will be removed in v0.8.0:
- `gradient_clip`,
- `nb_gpu_nodes`,
- `max_nb_epochs`,
- `min_nb_epochs`,
- `add_row_log_interval`,
- `nb_sanity_val_steps`
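
        Example (a minimal sketch; ``CoolModel`` stands in for any user-defined LightningModule)::

            from pytorch_lightning import Trainer

            model = CoolModel()

            # train on 1 GPU for up to 10 epochs, otherwise default settings
            trainer = Trainer(gpus=1, max_num_epochs=10)
            trainer.fit(model)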
"""
# Transfer params
if nb_gpu_nodes is not None: # Backward compatibility
warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
if not num_nodes: # in case you did not set the proper value
num_nodes = nb_gpu_nodes
self.num_gpu_nodes = num_nodes
self.log_gpu_memory = log_gpu_memory
if gradient_clip is not None: # Backward compatibility
warnings.warn("`gradient_clip` has renamed to `gradient_clip_val` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
if not gradient_clip_val: # in case you did not set the proper value
gradient_clip_val = gradient_clip
self.gradient_clip_val = gradient_clip_val
self.check_val_every_n_epoch = check_val_every_n_epoch
self.track_grad_norm = track_grad_norm
        self.on_gpu = bool(gpus) and torch.cuda.is_available()
self.process_position = process_position
self.weights_summary = weights_summary
if max_nb_epochs is not None: # Backward compatibility
warnings.warn("`max_nb_epochs` has renamed to `max_num_epochs` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
if not max_num_epochs: # in case you did not set the proper value
max_num_epochs = max_nb_epochs
self.max_num_epochs = max_num_epochs
if min_nb_epochs is not None: # Backward compatibility
warnings.warn("`min_nb_epochs` has renamed to `min_num_epochs` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
if not min_num_epochs: # in case you did not set the proper value
min_num_epochs = min_nb_epochs
self.min_num_epochs = min_num_epochs
if nb_sanity_val_steps is not None: # Backward compatibility
warnings.warn("`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
if not num_sanity_val_steps: # in case you did not set the proper value
num_sanity_val_steps = nb_sanity_val_steps
self.num_sanity_val_steps = num_sanity_val_steps
self.print_nan_grads = print_nan_grads
self.truncated_bptt_steps = truncated_bptt_steps
self.resume_from_checkpoint = resume_from_checkpoint
self.shown_warnings = set()
self.fast_dev_run = fast_dev_run
if self.fast_dev_run:
self.num_sanity_val_steps = 1
self.max_num_epochs = 1
            m = 'Running in fast_dev_run mode: will run a full train and val loop using a single batch'
            logging.info(m)
# set default save path if user didn't provide one
self.default_save_path = default_save_path
if self.default_save_path is None:
self.default_save_path = os.getcwd()
        # training bookkeeping
self.total_batch_idx = 0
self.running_loss = []
self.avg_loss = 0
self.batch_idx = 0
self.tqdm_metrics = {}
self.callback_metrics = {}
self.num_val_batches = 0
self.num_training_batches = 0
self.num_test_batches = 0
self.get_train_dataloader = None
self.get_test_dataloaders = None
self.get_val_dataloaders = None
self.is_iterable_train_dataloader = False
# training state
self.model = None
self.testing = False
self.lr_schedulers = []
self.optimizers = None
self.global_step = 0
self.current_epoch = 0
self.total_batches = 0
# configure early stop callback
# creates a default one if none passed in
self.early_stop_callback = None
self.configure_early_stopping(early_stop_callback, logger)
self.reduce_lr_on_plateau_scheduler = None
# configure checkpoint callback
self.checkpoint_callback = checkpoint_callback
self.weights_save_path = weights_save_path
# accumulated grads
self.configure_accumulated_gradients(accumulate_grad_batches)
# allow int, string and gpu list
self.data_parallel_device_ids = parse_gpu_ids(gpus)
self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
# distributed backend choice
self.use_ddp = False
self.use_ddp2 = False
self.use_dp = False
self.single_gpu = False
self.distributed_backend = distributed_backend
self.set_distributed_mode(distributed_backend, num_nodes)
# init flags for SLURM+ddp to work
self.proc_rank = 0
self.world_size = 1
self.node_rank = 0
self.configure_slurm_ddp(num_nodes)
# nvidia setup
self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
# can't init progress bar here because starting a new process
# means the progress_bar won't survive pickling
self.show_progress_bar = show_progress_bar
# logging
self.log_save_interval = log_save_interval
self.val_check_interval = val_check_interval
if add_row_log_interval is not None:
# backward compatibility
warnings.warn("`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
if not row_log_interval: # in case you did not set the proper value
row_log_interval = add_row_log_interval
self.row_log_interval = row_log_interval
# how much of the data to use
self.determine_data_use_amount(train_percent_check, val_percent_check,
test_percent_check, overfit_pct)
# 16 bit mixed precision training using apex
self.amp_level = amp_level
self.init_amp(use_amp)
@property
def slurm_job_id(self):
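        """The SLURM job id as an int, or None when not running under SLURM."""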
try:
job_id = os.environ['SLURM_JOB_ID']
job_id = int(job_id)
except Exception:
job_id = None
return job_id
def __parse_gpu_ids(self, gpus):
"""Parse GPUs id.
:param list|str|int gpus: input GPU ids
:return list(int):
"""
# if gpus = -1 then use all available devices
# otherwise, split the string using commas
        if gpus is not None:
            if isinstance(gpus, str):
                if gpus == '-1':
                    gpus = list(range(torch.cuda.device_count()))
                else:
                    gpus = [int(x.strip()) for x in gpus.split(',')]
            elif not isinstance(gpus, (list, int)):
                raise ValueError('`gpus` has to be a string, int or list of ints')
        return gpus
def __set_root_gpu(self, gpus):
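        """Return None when no GPUs are requested, the first id when a list is given, otherwise device 0."""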
if gpus is None:
return None
# set root gpu
root_gpu = 0
        if isinstance(gpus, list):
root_gpu = gpus[0]
return root_gpu
@property
def num_gpus(self):
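        """Number of GPUs this trainer will use (0 when training on CPU)."""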
gpus = self.data_parallel_device_ids
if gpus is None:
return 0
else:
return len(gpus)
@property
def data_parallel(self):
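        """True when any data-parallel backend (dp, ddp, ddp2) is in use."""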
return self.use_dp or self.use_ddp or self.use_ddp2
@property
def training_tqdm_dict(self):
"""Read-only for tqdm metrics.
:return:
"""
tqdm_dict = {
'loss': '{0:.3f}'.format(self.avg_loss),
'batch_idx': '{}'.format(self.batch_idx),
}
if self.truncated_bptt_steps is not None:
tqdm_dict['split_idx'] = self.split_idx
if self.logger is not None and self.logger.version is not None:
tqdm_dict['v_nb'] = self.logger.version
tqdm_dict.update(self.tqdm_metrics)
if self.on_gpu:
tqdm_dict['gpu'] = '{}'.format(torch.cuda.current_device())
return tqdm_dict
@property
def tng_tqdm_dic(self):
"""Read-only for tqdm metrics.
.. warning:: Deprecated in v0.5.0. use training_tqdm_dict instead.
:return:
"""
warnings.warn("`tng_tqdm_dic` has renamed to `training_tqdm_dict` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
return self.training_tqdm_dict
# -----------------------------
# MODEL TRAINING
# -----------------------------
def fit(self, model):
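        """Run the training routine for the given LightningModule.

        Dispatches to DDP2, DDP, DP, single-GPU or CPU training depending on the
        configured backend and returns 1 once training finishes.

        Example (``model`` is any user-defined LightningModule instance)::

            trainer = Trainer(max_num_epochs=1)
            trainer.fit(model)
        """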
# when using multi-node or DDP within a node start each module in a separate process
if self.use_ddp2:
task = int(os.environ['SLURM_LOCALID'])
self.ddp_train(task, model)
elif self.use_ddp:
if self.is_slurm_managing_tasks:
task = int(os.environ['SLURM_LOCALID'])
self.ddp_train(task, model)
else:
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
# 1 gpu or dp option triggers training using DP module
# easier to avoid NCCL issues
elif self.use_dp:
self.dp_train(model)
elif self.single_gpu:
self.single_gpu_train(model)
# ON CPU
else:
# run through amp wrapper
if self.use_amp:
raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.run_pretrain_routine(model)
# return 1 when finished
# used for testing or when we need to know that training succeeded
return 1
def init_optimizers(self, optimizers):
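        """Normalize whatever ``configure_optimizers`` returned.

        Accepts a single ``Optimizer``, a two-element list/tuple of
        ``([optimizers], [lr_schedulers])``, or a plain list/tuple of optimizers,
        and returns an ``(optimizers, lr_schedulers)`` pair.
        """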
# single optimizer
if isinstance(optimizers, Optimizer):
return [optimizers], []
# two lists
elif len(optimizers) == 2 and isinstance(optimizers[0], list):
optimizers, lr_schedulers = optimizers
lr_schedulers, self.reduce_lr_on_plateau_scheduler = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers
# single list or tuple
elif isinstance(optimizers, list) or isinstance(optimizers, tuple):
return optimizers, []
def configure_schedulers(self, schedulers):
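        # ReduceLROnPlateau must be stepped with a metric value, so it can't be treated
        # like the other schedulers; pop the first one found and return it separately
        # (the caller stores it as `reduce_lr_on_plateau_scheduler`)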
for i, scheduler in enumerate(schedulers):
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
reduce_lr_on_plateau_scheduler = schedulers.pop(i)
return schedulers, reduce_lr_on_plateau_scheduler
return schedulers, None
def run_pretrain_routine(self, model):
"""Sanity check a few things before starting actual training.
        :param model: the model to sanity check and then train
"""
ref_model = model
if self.data_parallel:
ref_model = model.module
# give model convenience properties
ref_model.trainer = self
# set local properties on the model
self.copy_trainer_model_properties(ref_model)
# link up experiment object
if self.logger is not None:
ref_model.logger = self.logger
# save exp to get started
if hasattr(ref_model, "hparams"):
self.logger.log_hyperparams(ref_model.hparams)
self.logger.save()
if self.use_ddp or self.use_ddp2:
dist.barrier()
# set up checkpoint callback
self.configure_checkpoint_callback()
# register auto-resubmit when on SLURM
self.register_slurm_signal_handlers()
# transfer data loaders from model
self.get_dataloaders(ref_model)
# print model summary
if self.proc_rank == 0 and self.weights_summary is not None:
if self.weights_summary in ['full', 'top']:
ref_model.summarize(mode=self.weights_summary)
else:
m = "weights_summary can be None, 'full' or 'top'"
raise MisconfigurationException(m)
# track model now.
# if cluster resets state, the model will update with the saved weights
self.model = model
# restore training and model before hpc call
self.restore_weights(model)
# when testing requested only run test and return
if self.testing:
self.run_evaluation(test=True)
return
# run tiny validation (if validation defined)
# to make sure program won't crash during val
ref_model.on_sanity_check_start()
if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0:
# init progress bars for validation sanity check
pbar = tqdm.tqdm(desc='Validation sanity check', total=self.num_sanity_val_steps,
leave=False, position=2 * self.process_position,
disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
self.main_progress_bar = pbar
# dummy validation progress bar
self.val_progress_bar = tqdm.tqdm(disable=True)
self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
# close progress bars
self.main_progress_bar.close()
self.val_progress_bar.close()
# init progress bar
pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
file=sys.stdout)
self.main_progress_bar = pbar
# clear cache before training
if self.on_gpu:
torch.cuda.empty_cache()
# CORE TRAINING LOOP
self.train()
def test(self, model=None):
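        """Run the test loop.

        If ``model`` is given it is passed through ``fit`` with ``self.testing`` enabled,
        so only the test evaluation runs; otherwise evaluation runs on the model already
        attached to the trainer.
        """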
self.testing = True
if model is not None:
self.fit(model)
else:
self.run_evaluation(test=True)