lightning/pytorch_lightning/models/trainer.py

874 lines
30 KiB
Python
Raw Normal View History

2019-07-09 00:12:27 +00:00
"""
The trainer handles all the logic for running a val loop, training loop, distributing, etc...
"""
2019-07-09 00:11:20 +00:00
import subprocess
import traceback
import warnings
import os
2019-07-18 15:21:35 +00:00
import pdb
import re
2019-07-09 00:11:20 +00:00
2019-03-31 01:45:16 +00:00
import torch
2019-07-09 00:11:20 +00:00
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import MultiStepLR
import torch.multiprocessing as mp
import torch.distributed as dist
2019-03-31 01:45:16 +00:00
import numpy as np
2019-07-09 00:11:20 +00:00
import tqdm
2019-03-31 20:29:50 +00:00
from pytorch_lightning.root_module.memory import get_gpu_memory_map
from pytorch_lightning.root_module.model_saving import TrainerIO
2019-07-18 15:08:48 +00:00
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel, LightningDataParallel
2019-07-08 22:02:41 +00:00
2019-03-31 01:45:16 +00:00
2019-05-14 00:40:07 +00:00
# optional NVIDIA apex dependency for 16-bit mixed precision training
try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    # ImportError (not just ModuleNotFoundError, its subclass) so that a partially
    # installed apex degrades to fp32 training instead of breaking this module's import
    APEX_AVAILABLE = False
def reduce_distributed_output(output, nb_gpus):
    """
    Average per-GPU outputs produced by a DataParallel forward pass.

    When using DP we get one output per gpu. A bare tensor is averaged. In a (nested)
    dict, every tensor whose first dimension equals ``nb_gpus`` is averaged, and nested
    dicts are reduced recursively. Anything else (0-dim tensors, tensors of another
    length, non-tensor values) is left untouched. Dicts are updated in place.

    :param output: tensor or (nested) dict returned by the DP forward
    :param nb_gpus: number of GPUs the output was gathered from
    :return: the reduced output (same dict object when a dict is passed)
    """
    if nb_gpus <= 1:
        return output

    # average outputs and return
    if isinstance(output, torch.Tensor):
        return output.mean()

    for k, v in output.items():
        # recurse on nested dicts
        if isinstance(v, dict):
            output[k] = reduce_distributed_output(v, nb_gpus)

        # reduce only metrics that have the same nb of gpus; 0-dim tensors have no size(0)
        elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.size(0) == nb_gpus:
            output[k] = torch.mean(v)

    return output
class Trainer(TrainerIO):
    def __init__(self,
                 experiment,
                 early_stop_callback=None,
                 checkpoint_callback=None,
                 gradient_clip=0,
                 cluster=None,
                 process_position=0,
                 current_gpu_name=0,
                 nb_gpu_nodes=1,
                 gpus=None,
                 progress_bar=True,
                 overfit_pct=0.0,
                 track_grad_norm=-1,
                 check_val_every_n_epoch=1,
                 fast_dev_run=False,
                 accumulate_grad_batches=1,
                 max_nb_epochs=1000, min_nb_epochs=1,
                 train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0,
                 val_check_interval=0.95,
                 log_save_interval=100, add_log_row_interval=10,
                 lr_scheduler_milestones=None,
                 distributed_backend='dp',
                 use_amp=False,
                 print_nan_grads=False,
                 print_weights_summary=True,
                 amp_level='O2',
                 nb_sanity_val_steps=5):
        """
        :param experiment: Test-tube experiment (used for logging and saving)
        :param early_stop_callback: from pytorch_lightning import EarlyStopping; None disables early stopping
        :param checkpoint_callback: from pytorch_lightning import Checkpoint; its save_function is bound to
            self.save_checkpoint
        :param gradient_clip: max gradient norm; 0 disables clipping
        :param cluster: when not None, enable_auto_hpc_walltime_manager() is called before training
        :param process_position: tqdm position of the progress bar
        :param current_gpu_name: value shown under 'gpu' in the progress bar when on gpu
        :param nb_gpu_nodes: number of nodes; > 1 switches 'dp' to 'ddp'
        :param gpus: list of gpu ids, a comma separated string of ids ('0,1'), or '-1' for all devices
        :param progress_bar: show a tqdm progress bar
        :param overfit_pct: if > 0, overrides train/val/test percent checks with this value
        :param track_grad_norm: norm type passed to model.grad_norm for logging; <= 0 disables it
        :param check_val_every_n_epoch: run validation only every n epochs
        :param fast_dev_run: debug mode; validate after every training batch using a single val batch
        :param accumulate_grad_batches: number of batches to accumulate gradients over before stepping
        :param max_nb_epochs: max number of epochs
        :param min_nb_epochs: early stopping is only considered once the epoch index exceeds this value
        :param train_percent_check: fraction of training batches to use per epoch
        :param val_percent_check: fraction of validation batches to use
        :param test_percent_check: fraction of test batches to use
        :param val_check_interval: validate every int(nb_tng_batches * val_check_interval) training batches
        :param log_save_interval: save the experiment every n training batches
        :param add_log_row_interval: log a metrics row every n training batches
        :param lr_scheduler_milestones: comma separated epoch milestones for MultiStepLR, e.g. '10,20'
        :param distributed_backend: 'dp' to use DataParallel, 'ddp' to use DistributedDataParallel
        :param use_amp: 16 bit mixed precision via NVIDIA apex (ignored with a warning if apex is missing)
        :param print_nan_grads: print the sum of every parameter's gradient after each backward pass
        :param print_weights_summary: print the model summary before training (proc 0 only)
        :param amp_level: apex opt_level, e.g. 'O1', 'O2'
        :param nb_sanity_val_steps: number of validation batches run before training as a sanity check
        """
        # Transfer params
        self.nb_gpu_nodes = nb_gpu_nodes
        self.gradient_clip = gradient_clip
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.enable_early_stop = early_stop_callback is not None
        self.track_grad_norm = track_grad_norm
        self.fast_dev_run = fast_dev_run
        self.on_gpu = gpus is not None and torch.cuda.is_available()
        self.progress_bar = progress_bar
        self.experiment = experiment
        self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
        self.cluster = cluster
        self.process_position = process_position
        self.current_gpu_name = current_gpu_name
        self.print_weights_summary = print_weights_summary
        self.checkpoint_callback = checkpoint_callback
        if self.checkpoint_callback is not None:
            self.checkpoint_callback.save_function = self.save_checkpoint
        self.early_stop = early_stop_callback
        self.model = None
        self.max_nb_epochs = max_nb_epochs
        self.accumulate_grad_batches = accumulate_grad_batches
        self.early_stop_callback = early_stop_callback
        self.min_nb_epochs = min_nb_epochs
        self.nb_sanity_val_steps = nb_sanity_val_steps
        if lr_scheduler_milestones is None:
            self.lr_scheduler_milestones = []
        else:
            # ignore empty entries such as a trailing comma in '10,20,'
            self.lr_scheduler_milestones = [int(x.strip()) for x in lr_scheduler_milestones.split(',')
                                            if x.strip()]
        self.lr_schedulers = []
        self.amp_level = amp_level
        self.print_nan_grads = print_nan_grads
        self.data_parallel_device_ids = None
        self.world_size = 1
        self.node_rank = 0
        self.use_ddp = False
        self.use_dp = False

        if gpus is not None:
            self.data_parallel_device_ids = self.__parse_gpu_ids(gpus)

            # set the correct cuda visible devices (using pci order)
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in self.data_parallel_device_ids])
            print(f'VISIBLE GPUS: {os.environ["CUDA_VISIBLE_DEVICES"]}')

        # make DP and DDP mutually exclusive
        # single GPU will also use DP with devices=[0]
        have_gpus = self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) > 0
        if have_gpus:
            self.use_dp = distributed_backend == 'dp'
            self.use_ddp = distributed_backend == 'ddp'

        # DataParallel cannot span nodes: use ddp automatically if nb_gpu_nodes > 1
        if nb_gpu_nodes > 1 and self.use_dp:
            self.use_dp = False
            self.use_ddp = True
            w = 'DataParallel does not support nb_gpu_nodes > 1. ' \
                'Switching to DistributedDataParallel for you. ' \
                'To silence this warning set distributed_backend=ddp'
            warnings.warn(w)

        # process info
        self.proc_rank = 0

        # training state
        self.optimizers = None
        self.prog_bar = None
        self.global_step = 0
        self.current_epoch = 0
        self.total_batches = 0

        # logging
        self.log_save_interval = log_save_interval
        self.val_check_interval = val_check_interval
        self.add_log_row_interval = add_log_row_interval

        # dataloaders
        self.tng_dataloader = None
        self.test_dataloader = None
        self.val_dataloader = None

        # how much of the data to use
        self.__determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct)
        print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))

        # 16 bit mixed precision training using apex
        self.use_amp = use_amp and APEX_AVAILABLE
        if self.use_amp:
            print('using 16bit precision')

        if use_amp and not APEX_AVAILABLE:
            msg = '''
You set use_amp=True but do not have apex installed.
Install apex first using this guide and rerun with use_amp=True:
https://github.com/NVIDIA/apex#linux
this run will NOT use 16 bit precision
'''
            warnings.warn(msg)

    @staticmethod
    def __parse_gpu_ids(gpus):
        """
        Convert the ``gpus`` argument into a list of device ids.

        :param gpus: list/tuple of ids, '-1' for all visible devices, or a comma separated string like '0,1'
        :return: list of device ids
        :raises TypeError: if gpus is neither a string nor a list/tuple
        """
        if isinstance(gpus, (list, tuple)):
            return list(gpus)
        if isinstance(gpus, str):
            if gpus.strip() == '-1':
                return list(range(0, torch.cuda.device_count()))
            return [int(x.strip()) for x in gpus.split(',') if x.strip()]
        raise TypeError('gpus has to be a string or list of ids')

    @property
    def data_parallel(self):
        """True when training wraps the model in DataParallel or DistributedDataParallel."""
        return self.use_dp or self.use_ddp

    def __determine_data_use_amount(self, train_percent_check, val_percent_check, test_percent_check, overfit_pct):
        """
        Use less data for debugging purposes.
        A positive overfit_pct overrides all three percent checks.
        """
        self.train_percent_check = train_percent_check
        self.val_percent_check = val_percent_check
        self.test_percent_check = test_percent_check
        if overfit_pct > 0:
            self.train_percent_check = overfit_pct
            self.val_percent_check = overfit_pct
            self.test_percent_check = overfit_pct

    def __get_model(self):
        """Return the user's model, unwrapped from its DP/DDP container when data parallel."""
        return self.model.module if self.data_parallel else self.model

    def __is_function_implemented(self, f_name):
        """True if the (unwrapped) model has a callable attribute named ``f_name``."""
        model = self.__get_model()
        f_op = getattr(model, f_name, None)
        return callable(f_op)

    @property
    def __tng_tqdm_dic(self):
        """
        Metrics shown in the progress bar; also passed as ``logs`` to the early stopping
        and checkpoint callbacks and to the model's update_tng_log_metrics.
        """
        tqdm_dic = {
            'tng_loss': '{0:.3f}'.format(self.avg_loss),
            'v_nb': '{}'.format(self.experiment.version),
            'epoch': '{}'.format(self.current_epoch),
            'batch_nb': '{}'.format(self.batch_nb),
        }
        tqdm_dic.update(self.tqdm_metrics)

        if self.on_gpu:
            tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)

        return tqdm_dic

    def __layout_bookeeping(self):
        """Reset training counters and compute how many train/val/test batches to use."""
        # training bookeeping
        self.total_batch_nb = 0
        self.running_loss = []
        self.avg_loss = 0
        self.batch_nb = 0
        self.tqdm_metrics = {}

        # determine number of training batches (at least 1 so the val-check interval below is never 0)
        self.nb_tng_batches = int(len(self.tng_dataloader) * self.train_percent_check)
        self.nb_tng_batches = max(1, self.nb_tng_batches)

        # determine number of validation batches
        self.nb_val_batches = int(len(self.val_dataloader) * self.val_percent_check)
        self.nb_val_batches = max(1, self.nb_val_batches)

        # determine number of test batches
        self.nb_test_batches = int(len(self.test_dataloader) * self.test_percent_check)

        # determine when to check validation; clamped to 1 because __train uses it as a modulus
        self.val_check_batch = max(1, int(self.nb_tng_batches * self.val_check_interval))

    def __add_tqdm_metrics(self, metrics):
        """Merge ``metrics`` into the progress bar metrics, converting tensors to Python numbers."""
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = v.item()

            self.tqdm_metrics[k] = v

    def validate(self, model, dataloader, max_batches):
        """
        Run validation code.

        Puts the model in eval mode with gradients disabled, runs the validation step on the
        dataloader (stopping at ``max_batches``), and hands the collected outputs to the model's
        ``validation_end``. Train mode and gradient tracking are restored even if a step raises.

        :param model: PT model (possibly wrapped in DP/DDP)
        :param dataloader: PT dataloader
        :param max_batches: Scalar; None means run the whole dataloader
        :return: whatever the model's validation_end returns
        """
        # enable eval mode
        model.zero_grad()
        model.eval()

        # disable gradients to save memory
        torch.set_grad_enabled(False)
        try:
            # bookkeeping
            outputs = []

            # run validation
            for batch_i, data_batch in enumerate(dataloader):
                if data_batch is None:
                    continue

                # stop short when on fast dev run (or during the sanity check)
                if max_batches is not None and batch_i >= max_batches:
                    break

                # -----------------
                # RUN VALIDATION STEP
                # -----------------
                if self.use_ddp:
                    output = model(data_batch, batch_i)
                elif self.use_dp:
                    output = model(data_batch, batch_i)
                    output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
                else:
                    output = model.validation_step(data_batch, batch_i)

                outputs.append(output)

                # batch done
                if self.progress_bar and self.prog_bar is not None:
                    self.prog_bar.update(1)

            # give model a chance to do something with the outputs
            if self.data_parallel:
                val_results = model.module.validation_end(outputs)
            else:
                val_results = model.validation_end(outputs)
        finally:
            # enable train mode again
            model.train()

            # re-enable gradients for training
            torch.set_grad_enabled(True)

        return val_results

    def __get_dataloaders(self, model):
        """
        Dataloaders are provided by the model.

        :param model: the unwrapped model exposing tng/test/val dataloaders
        :raises Exception: under ddp when the training dataloader does not use a DistributedSampler
        """
        self.tng_dataloader = model.tng_dataloader
        self.test_dataloader = model.test_dataloader
        self.val_dataloader = model.val_dataloader

        if self.use_ddp and not isinstance(self.tng_dataloader.sampler, DistributedSampler):
            msg = '''
when using multiple gpus and multiple nodes you must pass a DistributedSampler to DataLoader(sampler).
ie: this:
dataset = myDataset()
dataloader = DataLoader(dataset)
becomes:
dataset = myDataset()
dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = DataLoader(dataset, sampler=dist_sampler)
'''
            raise Exception(msg)

    def __init_amp(self, model):
        """
        Run model and self.optimizers through apex amp.initialize(opt_level=self.amp_level).
        Stores the returned optimizers on self and returns the returned model.
        """
        model, optimizers = amp.initialize(
            model, self.optimizers, opt_level=self.amp_level,
        )
        self.optimizers = optimizers
        return model

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    def fit(self, model):
        """
        Train ``model`` with the backend selected in __init__.

        - ddp: one process per gpu; slurm launches them when SLURM_NTASKS equals the total number
          of requested gpus, otherwise this process spawns one per local gpu.
        - dp: a single process driving all local gpus.
        - otherwise: train on cpu.

        :param model: the model to train
        """
        # when using multi-node or DDP within a node start each module in a separate process
        if self.use_ddp:
            # must copy only the meta of the exp so it survives pickle/unpickle when going to new process
            self.experiment = self.experiment.get_meta_copy()

            # whenever we have the correct number of tasks, we let slurm manage processes
            # otherwise we launch the required number of processes
            nb_gpus_per_node = len(self.data_parallel_device_ids)
            nb_requested_gpus = nb_gpus_per_node * self.nb_gpu_nodes
            nb_slurm_tasks = 0
            try:
                nb_slurm_tasks = int(os.environ['SLURM_NTASKS'])
                is_slurm_managing_tasks = nb_slurm_tasks == nb_requested_gpus
            except (KeyError, ValueError):
                # likely not on slurm, so set the slurm managed flag to false
                is_slurm_managing_tasks = False

            if is_slurm_managing_tasks:
                task = int(os.environ['SLURM_LOCALID'])
                self.ddp_train(task, model)
            else:
                msg = f"""
You requested {nb_requested_gpus} GPUs but launched {nb_slurm_tasks} slurm tasks.
We will launch {nb_gpus_per_node} processes per node for you.
We recommend you let slurm manage the processes by setting: --ntasks-per-node={nb_gpus_per_node}
If you're not using SLURM, ignore this message!
"""
                warnings.warn(msg)
                mp.spawn(self.ddp_train, nprocs=nb_gpus_per_node, args=(model, ))

        # 1 gpu or dp option triggers training using DP module
        # easier to avoid NCCL issues
        elif self.use_dp:
            self.dp_train(model)

        # ON CPU
        else:
            # CHOOSE OPTIMIZER
            self.optimizers = model.configure_optimizers()

            # run through amp wrapper
            if self.use_amp:
                model = self.__init_amp(model)

            self.__run_pretrain_routine(model)

    def dp_train(self, model):
        """
        Train on this node with DataParallel over self.data_parallel_device_ids.
        The model is moved to the first listed device before wrapping.
        """
        # CHOOSE OPTIMIZER
        self.optimizers = model.configure_optimizers()

        model.cuda(self.data_parallel_device_ids[0])

        # run through amp wrapper before wrapping in DataParallel (same order as ddp_train)
        if self.use_amp:
            model = self.__init_amp(model)

        model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)

        self.__run_pretrain_routine(model)

    def ddp_train(self, gpu_nb, model):
        """
        Entry point into a DDP process (one per gpu).

        :param gpu_nb: local gpu index of this process on its node
        :param model: the model to train
        """
        # node rank using relative slurm id
        # otherwise default to node rank 0
        try:
            self.node_rank = int(os.environ['SLURM_NODEID'])
        except (KeyError, ValueError):
            self.node_rank = 0

        # determine which process we are and world size
        # (computed first because the experiment and progress bar setup below depend on it)
        self.proc_rank = self.node_rank * len(self.data_parallel_device_ids) + gpu_nb
        self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)

        # recover original exp before went into process
        # init in write mode only on proc 0
        self.experiment.debug = self.proc_rank > 0
        self.experiment = self.experiment.get_non_ddp_exp()

        # show progbar only on proc_rank 0
        self.progress_bar = self.progress_bar and self.proc_rank == 0

        # set up server using proc 0's ip address
        self.__init_tcp_connection()

        # CHOOSE OPTIMIZER
        self.optimizers = model.configure_optimizers()

        # MODEL
        # copy model to this process's gpu
        torch.cuda.set_device(gpu_nb)
        model.cuda(gpu_nb)

        # AMP
        # run through amp wrapper before going to distributed DP
        if self.use_amp:
            model = self.__init_amp(model)

        model = LightningDistributedDataParallel(model, device_ids=[gpu_nb])

        # continue training routine
        self.__run_pretrain_routine(model)

    def __init_tcp_connection(self):
        """
        Connect all procs in the world using the env:// init.
        Use the first node as the root address; MASTER_PORT defaults to 12910 when unset.
        """
        os.environ.setdefault('MASTER_PORT', '12910')
        os.environ['MASTER_ADDR'] = self.__resolve_root_node_address()
        dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size)

    def __resolve_root_node_address(self):
        """
        Return the first host in SLURM_NODELIST, used as MASTER_ADDR.

        Handles plain lists ('node1,node2' -> 'node1') and bracketed ranges
        ('node[01-04,07]' -> 'node01'). Falls back to '127.0.0.2' when the variable is
        missing or empty.
        """
        nodelist = os.environ.get('SLURM_NODELIST', '')

        # first entry: a host name prefix, optionally followed by a [range] suffix
        match = re.match(r'\s*([^,\[\s]+)(?:\[([^\]]*)\])?', nodelist)
        if match is None:
            return '127.0.0.2'

        name, ranges = match.group(1), match.group(2)
        if ranges is None:
            return name

        # first number of the first range, keeping zero padding: '01-04,07' -> '01'
        number = ranges.split(',')[0].split('-')[0]
        number = re.sub('[^0-9]', '', number)
        return name + number

    def __run_pretrain_routine(self, model):
        """
        Sanity check a few things before starting actual training.

        :param model: the model to train, possibly wrapped in DP/DDP
        """
        ref_model = model
        if self.data_parallel:
            ref_model = model.module

        ref_model.trainer = self

        # set local properties on the model
        ref_model.on_gpu = self.on_gpu

        # transfer data loaders from model
        self.__get_dataloaders(ref_model)

        # init training constants
        self.__layout_bookeeping()

        # add lr schedulers (only when milestones were given)
        if self.lr_scheduler_milestones:
            for optimizer in self.optimizers:
                scheduler = MultiStepLR(optimizer, self.lr_scheduler_milestones)
                self.lr_schedulers.append(scheduler)

        # print model summary
        if self.proc_rank == 0 and self.print_weights_summary:
            ref_model.summarize()

        # give model convenience properties
        ref_model.experiment = self.experiment

        # run tiny validation to make sure program won't crash during val
        _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)

        # save exp to get started
        if self.proc_rank == 0:
            self.experiment.save()

        # enable cluster checkpointing
        if self.cluster is not None:
            self.enable_auto_hpc_walltime_manager()

        # ---------------------------
        # CORE TRAINING LOOP
        # ---------------------------
        self.model = model
        self.__train()

    def __train(self):
        """Run epochs from current_epoch up to max_nb_epochs, with validation, logging and early stopping."""
        # run all epochs
        for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
            # update the lr scheduler
            for lr_scheduler in self.lr_schedulers:
                lr_scheduler.step()

            model = self.__get_model()
            model.current_epoch = epoch_nb

            # hook
            if self.__is_function_implemented('on_epoch_start'):
                model.on_epoch_start()

            self.current_epoch = epoch_nb
            self.total_batches = self.nb_tng_batches + self.nb_val_batches
            self.batch_loss_value = 0  # accumulated grads

            # init progbar when requested
            if self.progress_bar:
                self.prog_bar = tqdm.tqdm(range(self.total_batches), position=self.process_position)

            for batch_nb, data_batch in enumerate(self.tng_dataloader):
                # stop once the requested number of batches (batch_nb is 0-based) has been run
                if batch_nb >= self.nb_tng_batches:
                    break

                self.batch_nb = batch_nb
                self.global_step += 1
                self.total_batch_nb += 1

                model = self.__get_model()
                model.global_step = self.global_step

                # ---------------
                # RUN TRAIN STEP
                # ---------------
                batch_result = self.__run_tng_batch(data_batch, batch_nb)
                early_stop_epoch = batch_result == -1

                # ---------------
                # RUN VAL STEP
                # ---------------
                is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
                if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
                    self.__run_validation()

                # when batch should be saved
                if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
                    if self.proc_rank == 0:
                        self.experiment.save()

                # when metrics should be logged
                if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
                    model = self.__get_model()
                    metrics = model.update_tng_log_metrics(self.__tng_tqdm_dic)

                    # add gpu memory
                    if self.on_gpu:
                        mem_map = get_gpu_memory_map()
                        metrics.update(mem_map)

                    # add norms
                    if self.track_grad_norm > 0:
                        grad_norm_dic = model.grad_norm(self.track_grad_norm)
                        metrics.update(grad_norm_dic)

                    if self.__is_function_implemented('on_tng_metrics'):
                        model.on_tng_metrics(metrics)

                    # log metrics
                    scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
                    if self.proc_rank == 0:
                        self.experiment.log(scalar_metrics, global_step=self.global_step)
                        self.experiment.save()

                # the on_batch_end hook is invoked from __run_tng_batch

                # end epoch early
                if early_stop_epoch:
                    break

            # hook
            if self.__is_function_implemented('on_epoch_end'):
                model = self.__get_model()
                model.on_epoch_end()

            # early stopping
            met_min_epochs = epoch_nb > self.min_nb_epochs
            if self.enable_early_stop and met_min_epochs:
                should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb, logs=self.__tng_tqdm_dic)

                # stop training
                if should_stop:
                    return

    def __metrics_to_scalars(self, metrics, blacklist=None):
        """
        Convert metric values to floats for logging.

        Tensors are converted with .item(); nested dicts are converted recursively and kept
        as nested dicts. Keys in ``blacklist`` are dropped at the top level.

        :param metrics: dict of metric name -> number, tensor or nested dict
        :param blacklist: container of keys to skip (default: none)
        :return: new dict of floats (or nested dicts of floats)
        """
        if blacklist is None:
            blacklist = ()

        new_metrics = {}
        for k, v in metrics.items():
            if k in blacklist:
                continue

            if isinstance(v, torch.Tensor):
                v = v.item()

            if isinstance(v, dict):
                new_metrics[k] = self.__metrics_to_scalars(v)
            else:
                new_metrics[k] = float(v)

        return new_metrics

    def __log_vals_blacklist(self):
        """avoid logging some vals lightning uses to maintain state"""
        blacklist = {'batch_nb', 'v_nb', 'gpu'}
        return blacklist

    def __run_tng_batch(self, data_batch, batch_nb):
        """
        Run one training batch: forward, backward and, every accumulate_grad_batches batches,
        an optimizer step.

        :param data_batch: batch from the training dataloader; None batches are skipped
        :param batch_nb: index of the batch within the epoch
        :return: -1 if the model's on_batch_start hook returned -1 (end the epoch), otherwise 0
        :raises ValueError: if the training step returns neither a tensor nor a dict with 'loss'
        """
        if data_batch is None:
            return 0

        # hook
        if self.__is_function_implemented('on_batch_start'):
            model_ref = self.__get_model()
            response = model_ref.on_batch_start(data_batch)

            if response == -1:
                return -1

        if self.progress_bar:
            self.prog_bar.update(1)

        # forward pass
        # return a scalar value and a dic with tqdm metrics
        if self.use_ddp:
            output = self.model(data_batch, batch_nb)
        elif self.use_dp:
            output = self.model(data_batch, batch_nb)
            output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
        else:
            output = self.model.training_step(data_batch, batch_nb)

        # the step returns either a dict (with 'loss' and optional 'tqdm_metrics') or the loss tensor itself
        if isinstance(output, dict):
            model_specific_tqdm_metrics_dic = output.get('tqdm_metrics', {})
            loss = output.get('loss')
        else:
            model_specific_tqdm_metrics_dic = {}
            loss = output if isinstance(output, torch.Tensor) else None

        if loss is None:
            raise ValueError("training_step must return a loss tensor or a dict with a 'loss' key")

        self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)

        # backward pass
        if self.use_amp:
            # scale loss when using amp; pass all optimizers at once so backward runs a single time
            with amp.scale_loss(loss, self.optimizers) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # hook after the backward pass
        if self.__is_function_implemented('on_after_backward'):
            model_ref = self.__get_model()
            model_ref.on_after_backward()

        if self.print_nan_grads:
            model = self.__get_model()
            for param in model.parameters():
                # parameters that received no gradient (e.g. frozen ones) have grad None
                if param.grad is not None:
                    print(param.grad.float().sum())

        # avoid memory leaks
        self.batch_loss_value += loss.item()

        # gradient update with accumulated gradients
        if (batch_nb + 1) % self.accumulate_grad_batches == 0:

            # clip gradients
            if self.gradient_clip > 0:
                model = self.__get_model()
                torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip)

            # update gradients across all optimizers
            for optimizer in self.optimizers:
                optimizer.step()

                # hook before gradients are cleared
                if self.__is_function_implemented('on_before_zero_grad'):
                    model_ref = self.__get_model()
                    model_ref.on_before_zero_grad(optimizer)

                # clear gradients
                optimizer.zero_grad()

            # queuing loss across batches blows it up proportionally... divide out the number accumulated
            self.batch_loss_value = self.batch_loss_value / self.accumulate_grad_batches

            # track loss
            self.running_loss.append(self.batch_loss_value)
            self.batch_loss_value = 0
            self.avg_loss = np.mean(self.running_loss[-100:])

            # update progbar
            if self.progress_bar:
                # add model specific metrics
                tqdm_metrics = self.__tng_tqdm_dic
                self.prog_bar.set_postfix(**tqdm_metrics)

        # activate batch end hook
        if self.__is_function_implemented('on_batch_end'):
            model = self.__get_model()
            model.on_batch_end()

        return 0

    def __run_validation(self):
        """
        Run a validation pass (respecting check_val_every_n_epoch), refresh the progress bar
        and trigger checkpointing on proc 0.

        Exceptions raised while validating are reported with their traceback and otherwise
        ignored so that training continues.
        """
        # decide if can check epochs
        can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
        if self.fast_dev_run:
            print('skipping to check performance bc of --fast_dev_run')
        elif not can_check_epoch:
            return

        try:
            # hook
            if self.__is_function_implemented('on_pre_performance_check'):
                model = self.__get_model()
                model.on_pre_performance_check()

            # use the val_percent_check portion of the val set, or a single batch on fast_dev_run
            max_batches = 1 if self.fast_dev_run else self.nb_val_batches
            model_specific_tqdm_metrics_dic = self.validate(
                self.model,
                self.val_dataloader,
                max_batches
            )
            self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)

            # hook
            if self.__is_function_implemented('on_post_performance_check'):
                model = self.__get_model()
                model.on_post_performance_check()

        except Exception:
            # best effort: report the error with its traceback and keep training
            traceback.print_exc()

        if self.progress_bar:
            # add model specific metrics
            tqdm_metrics = self.__tng_tqdm_dic
            self.prog_bar.set_postfix(**tqdm_metrics)

        # model checkpointing
        if self.proc_rank == 0 and self.checkpoint_callback:
            print('save callback...')
            self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)