2019-03-31 01:45:16 +00:00
|
|
|
import torch
|
|
|
|
import tqdm
|
|
|
|
import numpy as np
|
2019-03-31 20:29:50 +00:00
|
|
|
from pytorch_lightning.root_module.memory import get_gpu_memory_map
|
2019-03-31 01:45:16 +00:00
|
|
|
import traceback
|
2019-03-31 20:29:50 +00:00
|
|
|
from pytorch_lightning.root_module.model_saving import TrainerIO
|
2019-03-31 01:45:16 +00:00
|
|
|
from torch.optim.lr_scheduler import MultiStepLR
|
2019-07-03 19:09:49 +00:00
|
|
|
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel
|
2019-05-14 00:44:25 +00:00
|
|
|
import pdb
|
2019-07-03 19:09:49 +00:00
|
|
|
import torch.multiprocessing as mp
|
|
|
|
import torch.distributed as dist
|
2019-07-08 13:27:16 +00:00
|
|
|
import os
|
|
|
|
import subprocess
|
|
|
|
from time import sleep
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2019-05-14 00:40:07 +00:00
|
|
|
try:
    # optional 16-bit (mixed precision) support via NVIDIA apex
    from apex import amp

    APEX_AVAILABLE = True
except ImportError:
    # ImportError is the parent of ModuleNotFoundError: it also covers an apex
    # install that exists but fails to import (e.g. a broken compiled extension)
    APEX_AVAILABLE = False
|
2019-03-31 01:45:16 +00:00
|
|
|
|
2019-07-01 22:38:07 +00:00
|
|
|
|
|
|
|
def reduce_distributed_output(output, nb_gpus):
    """
    Average per-GPU values of a (possibly nested) output dict, in place.

    A tensor whose first dimension equals ``nb_gpus`` is treated as holding one
    entry per GPU and is replaced by ``torch.mean`` of the whole tensor. Nested
    dicts are reduced recursively. All other values are left untouched: 0-dim
    tensors (no first dimension), tensors with a different leading size, and
    non-tensor values.

    :param output: dict of name -> tensor / nested dict / other value; mutated in place
    :param nb_gpus: number of GPUs the values were gathered from
    :return: the same ``output`` dict
    """
    for k, v in output.items():
        # recurse on nested dicts
        if isinstance(v, dict):
            output[k] = reduce_distributed_output(v, nb_gpus)

        # reduce only metrics that have the same nb of gpus; scalars and
        # non-tensors cannot be indexed by .size(0)
        elif isinstance(v, torch.Tensor) and v.dim() > 0 and v.size(0) == nb_gpus:
            output[k] = torch.mean(v)

    return output
|
|
|
|
|
2019-07-03 20:39:33 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
class Trainer(TrainerIO):
|
|
|
|
|
|
|
|
def __init__(self,
             experiment,
             checkpoint_callback, early_stop_callback,
             gradient_clip=0,
             cluster=None,
             process_position=0,
             current_gpu_name=0,
             nb_gpu_nodes=1,
             gpus=None,
             progress_bar=True,
             overfit_pct=0.0,
             track_grad_norm=-1,
             check_val_every_n_epoch=1,
             fast_dev_run=False,
             accumulate_grad_batches=1,
             enable_early_stop=True, max_nb_epochs=1000, min_nb_epochs=1,
             train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=0.95,
             log_save_interval=100, add_log_row_interval=10,
             lr_scheduler_milestones=None,
             use_amp=False,
             print_nan_grads=False,
             amp_level='O2',
             nb_sanity_val_steps=5):
    """
    Configure a training run.

    :param experiment: experiment/logger; must provide ``get_data_path(name, version)``,
        ``name`` and ``version``
    :param checkpoint_callback: callback whose ``save_function`` is set to
        ``self.save_checkpoint``
    :param early_stop_callback: consulted at the end of each epoch when
        ``enable_early_stop`` is True
    :param gradient_clip: max gradient norm; 0 disables clipping
    :param cluster: optional cluster object (enables the HPC walltime manager)
    :param process_position: tqdm ``position`` of the progress bar
    :param current_gpu_name: value shown under ``gpu`` in the progress bar
    :param nb_gpu_nodes: number of nodes taking part in distributed training
    :param gpus: ``None`` for CPU; ``'-1'`` (or ``-1``) for every visible device;
        a comma-separated id string such as ``'0,1'``; or a list/tuple of int ids
    :param progress_bar: whether to show a tqdm progress bar
    :param overfit_pct: if > 0, overrides the train/val/test percent checks
    :param track_grad_norm: if > 0, log gradient norms of this order
    :param check_val_every_n_epoch: validate only every n-th epoch
    :param fast_dev_run: validate on every batch, using a single val batch
    :param accumulate_grad_batches: optimizer step every this many batches
    :param enable_early_stop: whether to consult ``early_stop_callback``
    :param max_nb_epochs: upper bound on epochs
    :param min_nb_epochs: lower bound on epochs before early stopping may apply
    :param train_percent_check: fraction of training batches to use
    :param val_percent_check: fraction of validation batches to use
    :param test_percent_check: fraction of test batches to use
    :param val_check_interval: fraction of an epoch between validation runs
    :param log_save_interval: save the experiment every this many batches
    :param add_log_row_interval: log metrics every this many batches
    :param lr_scheduler_milestones: comma-separated epoch milestones for MultiStepLR
    :param use_amp: use apex 16-bit training (only if apex is importable)
    :param print_nan_grads: print the sum of each parameter's gradient after backward
    :param amp_level: apex ``opt_level``
    :param nb_sanity_val_steps: validation batches to run before training starts
    """
    # Transfer params
    self.nb_gpu_nodes = nb_gpu_nodes
    self.gradient_clip = gradient_clip
    self.check_val_every_n_epoch = check_val_every_n_epoch
    self.enable_early_stop = enable_early_stop
    self.track_grad_norm = track_grad_norm
    self.fast_dev_run = fast_dev_run
    self.on_gpu = gpus is not None and torch.cuda.is_available()
    self.progress_bar = progress_bar
    self.experiment = experiment
    self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
    self.cluster = cluster
    self.process_position = process_position
    self.current_gpu_name = current_gpu_name
    self.checkpoint_callback = checkpoint_callback
    self.checkpoint_callback.save_function = self.save_checkpoint
    self.early_stop = early_stop_callback
    self.model = None
    self.max_nb_epochs = max_nb_epochs
    self.accumulate_grad_batches = accumulate_grad_batches
    self.early_stop_callback = early_stop_callback
    self.min_nb_epochs = min_nb_epochs
    self.nb_sanity_val_steps = nb_sanity_val_steps
    self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')]
    self.lr_schedulers = []
    self.amp_level = amp_level
    self.print_nan_grads = print_nan_grads
    self.data_parallel_device_ids = None

    # gpus: '-1' -> all available devices; '0,2' -> those ids; list/tuple -> those ids.
    # Any other scalar (e.g. an int) is interpreted through its string form.
    if gpus is not None:
        if isinstance(gpus, (list, tuple)):
            self.data_parallel_device_ids = [int(x) for x in gpus]
        elif str(gpus).strip() == '-1':
            self.data_parallel_device_ids = list(range(0, torch.cuda.device_count()))
        else:
            self.data_parallel_device_ids = [int(x.strip()) for x in str(gpus).split(',')]

        # set the correct cuda visible devices (using pci order)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in self.data_parallel_device_ids])

    # Data parallel requires usable CUDA. fit() only wraps the model (and so
    # creates model.module) when on_gpu is True; without this gate a CPU run
    # with `gpus` set would reach model.module and crash.
    self.data_parallel = self.on_gpu and bool(self.data_parallel_device_ids)

    # process info
    self.proc_rank = 0

    # training state
    self.optimizers = None
    self.prog_bar = None
    self.global_step = 0
    self.current_epoch = 0
    self.total_batches = 0

    # logging
    self.log_save_interval = log_save_interval
    self.val_check_interval = val_check_interval
    self.add_log_row_interval = add_log_row_interval

    # dataloaders
    self.tng_dataloader = None
    self.test_dataloader = None
    self.val_dataloader = None

    # how much of the data to use
    self.__determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct)
    print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))

    # apex test: 16-bit is only enabled when apex could be imported
    self.use_amp = use_amp and APEX_AVAILABLE
    if self.use_amp:
        print('using 16bit precision')
|
2019-05-14 00:40:07 +00:00
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
def __determine_data_use_amount(self, train_percent_check, val_percent_check, test_percent_check, overfit_pct):
    """
    Record what fraction of each dataset split to use (useful for debugging).

    A positive ``overfit_pct`` replaces all three split-specific fractions.
    """
    if overfit_pct > 0:
        train_percent_check = val_percent_check = test_percent_check = overfit_pct

    self.train_percent_check = train_percent_check
    self.val_percent_check = val_percent_check
    self.test_percent_check = test_percent_check
|
|
|
|
|
|
|
|
def __is_function_implemented(self, f_name):
    """
    Tell whether ``self.model`` has a callable attribute named ``f_name``.

    NOTE(review): the lookup is done on ``self.model`` as stored. When that is a
    data-parallel wrapper, hooks defined only on the inner module may not be
    found; confirm this is intended.
    """
    candidate = getattr(self.model, f_name, None)
    return callable(candidate)
|
|
|
|
|
|
|
|
@property
def __tng_tqdm_dic(self):
    """
    Strings shown in the training progress bar.

    Holds the running training loss, experiment version, epoch and batch index,
    then any model-specific metrics gathered so far. When training on GPU it
    also holds the current GPU name.
    """
    shown = dict(
        tng_loss='{0:.3f}'.format(self.avg_loss),
        v_nb='{}'.format(self.experiment.version),
        epoch='{}'.format(self.current_epoch),
        batch_nb='{}'.format(self.batch_nb),
    )
    shown.update(self.tqdm_metrics)

    if not self.on_gpu:
        return shown

    shown['gpu'] = '{}'.format(self.current_gpu_name)
    return shown
|
|
|
|
|
2019-07-01 22:38:07 +00:00
|
|
|
def __layout_bookeeping(self, model):
    """
    Reset training bookkeeping and compute the per-split batch budgets.

    Each budget is ``model.nb_batches(dataloader)`` scaled by the matching
    ``*_percent_check`` fraction. Validation always gets at least one batch, and
    the validation-check interval is at least one training batch.

    :param model: the (unwrapped) model; must provide ``nb_batches(dataloader)``
    """
    # training bookkeeping
    self.total_batch_nb = 0
    self.running_loss = []
    self.avg_loss = 0
    self.batch_nb = 0
    self.tqdm_metrics = {}

    # determine number of training batches
    self.nb_tng_batches = int(model.nb_batches(self.tng_dataloader) * self.train_percent_check)

    # determine number of validation batches (never fewer than one)
    self.nb_val_batches = max(1, int(model.nb_batches(self.val_dataloader) * self.val_percent_check))

    # determine number of test batches
    self.nb_test_batches = int(model.nb_batches(self.test_dataloader) * self.test_percent_check)

    # determine when to check validation. The training loop uses this as a
    # modulus, so it must be >= 1: a tiny dataset or small val_check_interval
    # would otherwise make it 0 and raise ZeroDivisionError.
    self.val_check_batch = max(1, int(self.nb_tng_batches * self.val_check_interval))
|
2019-03-31 01:45:16 +00:00
|
|
|
|
|
|
|
def __add_tqdm_metrics(self, metrics):
    """
    Merge ``metrics`` into ``self.tqdm_metrics`` for display in the progress bar.

    Tensor values, including Tensor subclasses such as ``nn.Parameter``, are
    converted to Python scalars with ``.item()``. Other values are stored as-is.

    :param metrics: dict of metric name -> value
    """
    for k, v in metrics.items():
        if isinstance(v, torch.Tensor):
            v = v.item()
        self.tqdm_metrics[k] = v
|
|
|
|
|
|
|
|
def validate(self, model, dataloader, max_batches):
    """
    Run the validation loop and return the model's aggregated results.

    The model is put in eval mode and autograd is disabled while batches run.
    Afterwards train mode is restored, and so is the previous autograd state
    (via ``torch.no_grad``). This happens even when a validation step raises:
    callers that catch the error and continue training must not be left with
    gradients disabled.

    :param model: PT model (the data-parallel wrapper when ``self.data_parallel``)
    :param dataloader: PT dataloader yielding validation batches; ``None`` batches are skipped
    :param max_batches: stop once the batch index reaches this value; None means no limit
    :return: whatever the model's ``validation_end`` returns for the collected outputs
    """
    print('validating...')

    # enable eval mode
    model.zero_grad()
    model.eval()
    model.from_lightning = True

    try:
        # disable gradients to save memory; restored automatically on exit
        with torch.no_grad():
            outputs = []

            # run validation batches
            for batch_i, data_batch in enumerate(dataloader):
                if data_batch is None:
                    continue

                # stop short (sanity check / fast dev run)
                if max_batches is not None and batch_i >= max_batches:
                    break

                # -----------------
                # RUN VALIDATION STEP
                # -----------------
                if self.data_parallel:
                    output = model(data_batch, batch_i)
                else:
                    output = model.validation_step(data_batch, batch_i)

                outputs.append(output)

                # batch done
                if self.progress_bar and self.prog_bar is not None:
                    self.prog_bar.update(1)

            # give model a chance to do something with the outputs
            if self.data_parallel:
                val_results = model.module.validation_end(outputs)
            else:
                val_results = model.validation_end(outputs)
    finally:
        # enable train mode again, even if validation failed
        model.train()

    return val_results
|
|
|
|
|
|
|
|
def __get_dataloaders(self, model):
    """
    Copy the dataloaders exposed by ``model`` onto the trainer.

    Reads ``tng_dataloader``, ``test_dataloader`` and ``val_dataloader`` from the
    model, in that order.

    :param model: object exposing the three dataloader attributes
    """
    for attr_name in ('tng_dataloader', 'test_dataloader', 'val_dataloader'):
        setattr(self, attr_name, getattr(model, attr_name))

    # TODO: with distributed data parallel, the dataset still needs to be split
    # across nodes/processes (not implemented yet)
|
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
# -----------------------------
# MODEL TRAINING
# -----------------------------
def fit(self, model):
    """
    Train ``model``.

    Steps:
    1. Get the optimizer(s) from ``model.configure_optimizers()``.
    2. When 16-bit is enabled, pass model and optimizers through apex amp.
    3. On GPU, spawn one worker per selected device, each entering ``dp_train``.
       On CPU, run the pre-train routine and the training loop in this process.

    :param model: model providing ``configure_optimizers`` and the training hooks
    """
    optimizers = model.configure_optimizers()
    self.optimizers = optimizers

    if self.use_amp:
        # amp.initialize may return new model / optimizer objects.
        # NOTE(review): this runs before dp_train moves the model to a GPU; apex
        # documents that amp.initialize expects CUDA parameters, so confirm.
        amp_model, amp_optimizers = amp.initialize(
            model, self.optimizers, opt_level=self.amp_level,
        )
        model = amp_model
        self.optimizers = amp_optimizers
        model.trainer = self

    if not self.on_gpu:
        # CPU: train right here in this process
        self.__run_pretrain_routine(model)
        return

    # GPU (single gpu, multi-gpu and multi-node): one process per selected device.
    # dp_train undoes this swap with get_non_ddp_exp(); presumably the meta copy
    # is what can be shipped to the child processes.
    self.experiment = self.experiment.get_meta_copy()
    mp.spawn(self.dp_train, nprocs=len(self.data_parallel_device_ids), args=(model, ))
|
|
|
|
|
2019-07-08 17:48:59 +00:00
|
|
|
def dp_train(self, gpu_nb, model):
    """
    Entry point of one worker process started by ``mp.spawn`` in ``fit``.

    The worker:
    - computes its global rank from the SLURM node id and ``gpu_nb``;
    - joins the NCCL process group, using the root process's IP;
    - moves the model to GPU ``gpu_nb`` and wraps it for distributed data parallel;
    - runs the pre-train routine (which then runs training).

    :param gpu_nb: process index supplied by ``mp.spawn``, used as the local GPU index
    :param model: the model to train
    """
    # node rank using relative slurm id; otherwise default to node rank 0
    try:
        node_rank = int(os.environ['SLURM_NODEID'])
    except KeyError:
        node_rank = 0

    # recover original exp before went into process
    self.experiment = self.experiment.get_non_ddp_exp()

    # Show the progress bar only on the first process in the world. The tqdm bar
    # itself does not exist yet (self.prog_bar is still None); it is created later
    # from self.progress_bar, so that is the flag that must be switched off.
    self.progress_bar = self.progress_bar and node_rank == 0 and gpu_nb == 0

    # determine which process we are and world size
    self.proc_rank = node_rank * len(self.data_parallel_device_ids) + gpu_nb
    world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)

    # set up server using proc 0's ip address
    ip = self.__get_root_node_ip(self.proc_rank, self.nb_gpu_nodes)
    dist.init_process_group("nccl", init_method=f'tcp://{ip}:12001', rank=self.proc_rank, world_size=world_size)
    print(f"GPU: {gpu_nb} - Rank: {self.proc_rank}")

    # copy model to each gpu
    torch.cuda.set_device(gpu_nb)
    model.cuda(gpu_nb)
    model = LightningDistributedDataParallel(model, device_ids=[gpu_nb])

    # continue training routine
    self.__run_pretrain_routine(model)
|
|
|
|
|
2019-07-08 21:29:46 +00:00
|
|
|
def __get_root_node_ip(self, world_gpu_nb, nb_gpu_nodes):
    """
    Resolve the IP address of the global rank-0 process.

    On a single node this is always localhost. On several nodes:
    - Rank 0 takes the first address printed by ``hostname -I`` and publishes it
      in ``<cluster.log_path>/ip_tables/.ip_meta_<SLURM_JOB_ID>``. It writes a
      temp file and renames it, so readers never see a partial write.
    - Every other process polls for that file for up to ~120 seconds and
      returns the stripped address.

    NOTE(review): a file left over from an earlier run with the same
    SLURM_JOB_ID (e.g. a requeue) would be read as-is.

    :param world_gpu_nb: gpu number amongst all the world gpus (global rank)
    :param nb_gpu_nodes: number of nodes taking part in training
    :return: IP address of the rank-0 process
    :raises RuntimeError: if rank 0 cannot determine its own address
    :raises TimeoutError: if a non-root process never sees the published address
    """
    # on one node we use localhost
    if nb_gpu_nodes == 1:
        return '127.0.0.1'

    # where to store ip_table (the cluster logging dir, shared by all nodes)
    ip_file_dir = os.path.join(self.cluster.log_path, 'ip_tables')
    ip_table_name = '.ip_meta_' + os.environ['SLURM_JOB_ID']
    ip_file = os.path.join(ip_file_dir, ip_table_name)
    os.makedirs(ip_file_dir, exist_ok=True)

    # the first gpu in the world (global rank 0) becomes the host
    if world_gpu_nb == 0:
        # `hostname -I` prints whitespace-separated addresses plus a newline
        out = subprocess.run(['hostname', '-I'], stdout=subprocess.PIPE).stdout.decode('utf-8')
        addresses = out.split()
        if not addresses:
            raise RuntimeError('could not determine the root node IP from `hostname -I`')
        root_ip = addresses[0]

        # publish atomically: write a temp file, then rename it into place
        tmp_file = ip_file + '.tmp'
        with open(tmp_file, mode='w') as f:
            f.write(root_ip)
        os.replace(tmp_file, ip_file)

        return root_ip

    # everyone else waits up to 120 seconds for proc 0 to publish its address
    for _ in range(120):
        if os.path.exists(ip_file):
            with open(ip_file, mode='r') as f:
                ip = f.read().strip()
            if ip:
                return ip
        sleep(1.0)

    raise TimeoutError(f'timed out waiting for the root node IP in {ip_file}')
|
2019-07-08 13:27:16 +00:00
|
|
|
|
2019-07-03 19:09:49 +00:00
|
|
|
def __run_pretrain_routine(self, model):
    """
    Prepare everything the training loop needs, then run it.

    Steps:
    - wire dataloaders and batch budgets from the model;
    - build one MultiStepLR per optimizer when milestones were given;
    - print a model summary (rank 0 only);
    - run a short sanity validation so validation bugs surface before training;
    - save the experiment (rank 0 only) and enable the cluster walltime manager
      if a cluster was configured;
    - call ``__train``.

    :param model: the model to train; on GPU this is the distributed wrapper and
        the user's model is ``model.module``
    """
    # on GPU we receive the DDP wrapper; the user's model lives in .module
    ref_model = model.module if self.on_gpu else model

    # set local properties on the model
    ref_model.on_gpu = self.on_gpu

    # transfer data loaders from model
    self.__get_dataloaders(ref_model)

    # init training constants
    self.__layout_bookeeping(ref_model)

    # add lr schedulers. lr_scheduler_milestones is always a list (empty when no
    # milestones were passed), so a `is not None` check would always be true.
    if self.lr_scheduler_milestones:
        for optimizer in self.optimizers:
            scheduler = MultiStepLR(optimizer, self.lr_scheduler_milestones)
            self.lr_schedulers.append(scheduler)

    # print model summary
    if self.proc_rank == 0:
        ref_model.summarize()

    # give model convenience properties
    ref_model.trainer = self
    ref_model.experiment = self.experiment

    # run tiny validation to make sure program won't crash during val
    _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)

    # save exp to get started
    if self.proc_rank == 0:
        self.experiment.save()

    # enable cluster checkpointing
    if self.cluster is not None:
        self.enable_auto_hpc_walltime_manager()

    # ---------------------------
    # CORE TRAINING LOOP
    # ---------------------------
    self.model = model
    self.__train()
|
|
|
|
|
|
|
|
def __train(self):
    """
    Core training loop over epochs ``self.current_epoch .. self.max_nb_epochs - 1``.

    For each epoch:
    - step the LR schedulers and call the model's epoch hooks;
    - run at most ``self.nb_tng_batches`` training batches;
    - validate every ``self.val_check_batch`` batches, on fast_dev_run, or when
      the model requests an early epoch end;
    - periodically save the experiment and log scalar metrics (rank 0 only);
    - finally consult the early-stopping callback.
    """
    # the user's model (unwrapped when training data parallel); hooks and
    # bookkeeping attributes are set on it
    model = self.model.module if self.data_parallel else self.model

    # run all epochs
    for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
        # update the lr scheduler
        for lr_scheduler in self.lr_schedulers:
            lr_scheduler.step()

        model.current_epoch = epoch_nb

        # hook
        if self.__is_function_implemented('on_epoch_start'):
            model.on_epoch_start()

        self.current_epoch = epoch_nb
        self.total_batches = self.nb_tng_batches + self.nb_val_batches
        self.batch_loss_value = 0  # accumulated grads

        # init progbar when requested
        if self.progress_bar:
            self.prog_bar = tqdm.tqdm(range(self.total_batches), position=self.process_position)

        for batch_nb, data_batch in enumerate(self.tng_dataloader):
            self.batch_nb = batch_nb
            self.global_step += 1
            model.global_step = self.global_step

            # Stop once nb_tng_batches batches have run. batch_nb is 0-based, so
            # index nb_tng_batches is already past the budget (`>` ran one extra).
            self.total_batch_nb += 1
            met_batch_limit = batch_nb >= self.nb_tng_batches
            if met_batch_limit:
                break

            # ---------------
            # RUN TRAIN STEP
            # ---------------
            batch_result = self.__run_tng_batch(data_batch, batch_nb)
            early_stop_epoch = batch_result == -1

            # ---------------
            # RUN VAL STEP
            # ---------------
            is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
            if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
                self.__run_validation()

            # when batch should be saved
            if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch:
                if self.proc_rank == 0:
                    self.experiment.save()

            # when metrics should be logged
            if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch:
                metrics = model.update_tng_log_metrics(self.__tng_tqdm_dic)

                # add gpu memory
                if self.on_gpu:
                    metrics.update(get_gpu_memory_map())

                # add norms
                if self.track_grad_norm > 0:
                    metrics.update(model.grad_norm(self.track_grad_norm))

                # log metrics
                scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist())
                if self.proc_rank == 0:
                    self.experiment.log(scalar_metrics, global_step=self.global_step)
                    self.experiment.save()

            # hook
            if self.__is_function_implemented('on_batch_end'):
                model.on_batch_end()

            # end epoch early
            if early_stop_epoch:
                break

        # hook
        if self.__is_function_implemented('on_epoch_end'):
            model.on_epoch_end()

        # early stopping
        if self.enable_early_stop:
            should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb, logs=self.__tng_tqdm_dic)

            # epoch_nb is 0-based: epoch_nb + 1 epochs are complete at this point.
            # (`epoch_nb > min_nb_epochs` required two more epochs than asked for.)
            met_min_epochs = epoch_nb + 1 >= self.min_nb_epochs

            # stop training
            if should_stop and met_min_epochs:
                return
|
|
|
|
|
2019-07-01 22:38:07 +00:00
|
|
|
def __metrics_to_scalars(self, metrics, blacklist=None):
    """
    Convert a metrics dict into Python floats for logging.

    Tensors are converted with ``.item()`` and every other value with ``float()``.
    Nested dicts are converted recursively and kept as dicts, since ``float()``
    cannot be applied to a dict. Keys in ``blacklist`` are dropped, at the top
    level only.

    :param metrics: dict of metric name -> tensor / number / numeric string / nested dict
    :param blacklist: optional collection of top-level keys to skip
    :return: a new dict of floats (nested dicts hold floats too)
    """
    def _to_scalars(values, skip):
        converted = {}
        for k, v in values.items():
            if k in skip:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            if isinstance(v, dict):
                converted[k] = _to_scalars(v, ())
            else:
                converted[k] = float(v)
        return converted

    return _to_scalars(metrics, () if blacklist is None else blacklist)
|
|
|
|
|
|
|
|
def __log_vals_blacklist(self):
    """Return the progress-bar keys that are trainer bookkeeping and should not be logged."""
    bookkeeping_keys = ('batch_nb', 'v_nb', 'epoch', 'gpu')
    return set(bookkeeping_keys)
|
2019-04-23 12:57:58 +00:00
|
|
|
|
2019-05-14 10:36:26 +00:00
|
|
|
def __run_tng_batch(self, data_batch, batch_nb):
    """
    Run forward and backward for one training batch, stepping optimizers when due.

    The model's ``on_batch_end`` hook is not called here: the training loop
    already calls it after every batch.

    :param data_batch: batch from the training dataloader; ``None`` is skipped
    :param batch_nb: index of the batch within the epoch
    :return: -1 if the model's ``on_batch_start`` hook returned -1 (end the epoch), else 0
    """
    if data_batch is None:
        return 0

    # the user's model (unwrapped when training data parallel)
    model = self.model.module if self.data_parallel else self.model

    # hook
    if self.__is_function_implemented('on_batch_start'):
        response = model.on_batch_start(data_batch)
        if response == -1:
            return -1

    if self.progress_bar:
        self.prog_bar.update(1)

    # forward pass
    # return a scalar value and a dic with tqdm metrics
    if self.data_parallel:
        output = self.model(data_batch, batch_nb)
    else:
        output = self.model.training_step(data_batch, batch_nb)

    model_specific_tqdm_metrics_dic = output['tqdm_metrics']
    loss = output['loss']
    self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)

    # backward pass
    if self.use_amp:
        # A single backward covering all optimizers: apex's scale_loss accepts a
        # list. One backward per optimizer would traverse an already-freed graph
        # on the second optimizer.
        with amp.scale_loss(loss, list(self.optimizers)) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    if self.print_nan_grads:
        for param in model.parameters():
            # parameters unused in this forward pass have no gradient
            if param.grad is not None:
                print(param.grad.float().sum())

    self.batch_loss_value += loss.item()

    # gradient update with accumulated gradients
    if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:

        # clip gradients (clip_grad_norm_ is the in-place, non-deprecated API)
        if self.gradient_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip)

        # update gradients across all optimizers
        for optimizer in self.optimizers:
            optimizer.step()

            # clear gradients
            optimizer.zero_grad()

        # queuing loss across batches blows it up proportionally... divide out the number accumulated
        self.batch_loss_value = self.batch_loss_value / self.accumulate_grad_batches

        # track loss: avg_loss is the mean of the last 100 recorded values
        self.running_loss.append(self.batch_loss_value)
        self.batch_loss_value = 0
        self.avg_loss = np.mean(self.running_loss[-100:])

        # update progbar
        if self.progress_bar:
            # add model specific metrics
            tqdm_metrics = self.__tng_tqdm_dic
            self.prog_bar.set_postfix(**tqdm_metrics)

    return 0
|
|
|
|
|
2019-03-31 01:45:16 +00:00
|
|
|
def __run_validation(self):
    """
    Validate if due, refresh the progress bar, and run checkpointing.

    When it runs:
    - when ``(current_epoch + 1)`` is a multiple of ``check_val_every_n_epoch``;
    - always in fast_dev_run mode, which uses a single validation batch.

    Errors raised during validation are printed and swallowed so training can
    continue. The progress-bar refresh and the checkpoint callback (rank 0
    only) still run afterwards.
    """
    # decide if can check epochs
    can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
    if self.fast_dev_run:
        print('skipping to check performance bc of --fast_dev_run')
    elif not can_check_epoch:
        return

    # hooks are defined on the user's model, not on the data-parallel wrapper
    model = self.model.module if self.data_parallel else self.model

    try:
        # hook
        if self.__is_function_implemented('on_pre_performance_check'):
            model.on_pre_performance_check()

        # use the full val set normally; a single batch on fast_dev_run
        max_batches = None if not self.fast_dev_run else 1
        model_specific_tqdm_metrics_dic = self.validate(
            self.model,
            self.val_dataloader,
            max_batches
        )
        self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)

        # hook
        if self.__is_function_implemented('on_post_performance_check'):
            model.on_post_performance_check()

    except Exception:
        # best effort: report the failure (message + traceback) and keep training.
        # print_exc() prints by itself and returns None, so it is not wrapped in print().
        traceback.print_exc()

    if self.progress_bar:
        # add model specific metrics
        tqdm_metrics = self.__tng_tqdm_dic
        self.prog_bar.set_postfix(**tqdm_metrics)

    # model checkpointing
    print('save callback...')
    if self.proc_rank == 0:
        self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)
|