""" The trainer handles all the logic for running a val loop, training loop, distributing, etc... """ import subprocess import traceback import warnings import os import pdb import re import torch from torch.utils.data.distributed import DistributedSampler from torch.optim.lr_scheduler import MultiStepLR import torch.multiprocessing as mp import torch.distributed as dist import numpy as np import tqdm from pytorch_lightning.root_module.memory import get_gpu_memory_map from pytorch_lightning.root_module.model_saving import TrainerIO from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel, LightningDataParallel from pytorch_lightning.utils.debugging import ForkedPdb try: from apex import amp APEX_AVAILABLE = True except ModuleNotFoundError: APEX_AVAILABLE = False def reduce_distributed_output(output, nb_gpus): if nb_gpus <= 1: return output # when using DP, we get one output per gpu # average outputs and return if type(output) is torch.Tensor: return output.mean() for k, v in output.items(): # recurse on nested dics if isinstance(output[k], dict): output[k] = reduce_distributed_output(output[k], nb_gpus) # reduce only metrics that have the same nb of gpus elif output[k].size(0) == nb_gpus: reduced = torch.mean(output[k]) output[k] = reduced return output class Trainer(TrainerIO): def __init__(self, experiment, early_stop_callback=None, checkpoint_callback=None, gradient_clip=0, cluster=None, process_position=0, current_gpu_name=0, nb_gpu_nodes=1, gpus=None, progress_bar=True, overfit_pct=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_nb_epochs=1000, min_nb_epochs=1, train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=0.95, log_save_interval=100, add_log_row_interval=10, lr_scheduler_milestones=None, distributed_backend='dp', use_amp=False, print_nan_grads=False, print_weights_summary=True, amp_level='O2', nb_sanity_val_steps=5): """ :param experiment: Test-tube experiment :param early_stop_callback: from pytorch_lightning import EarlyStopping :param checkpoint_callback: from pytorch_lightning import Checkpoint :param gradient_clip: :param cluster: :param process_position: :param current_gpu_name: :param nb_gpu_nodes: :param gpus: :param progress_bar: :param overfit_pct: :param track_grad_norm: :param check_val_every_n_epoch: :param fast_dev_run: :param accumulate_grad_batches: :param max_nb_epochs: :param min_nb_epochs: :param train_percent_check: :param val_percent_check: :param test_percent_check: :param val_check_interval: :param log_save_interval: :param add_log_row_interval: :param lr_scheduler_milestones: :param distributed_backend: 'np' to use DistributedParallel, 'ddp' to use DistributedDataParallel :param use_amp: :param print_nan_grads: :param print_weights_summary: :param amp_level: :param nb_sanity_val_steps: """ # Transfer params self.nb_gpu_nodes = nb_gpu_nodes self.gradient_clip = gradient_clip self.check_val_every_n_epoch = check_val_every_n_epoch self.enable_early_stop = early_stop_callback is not None self.track_grad_norm = track_grad_norm self.fast_dev_run = fast_dev_run self.on_gpu = gpus is not None and torch.cuda.is_available() self.progress_bar = progress_bar self.experiment = experiment self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version) self.cluster = cluster self.process_position = process_position self.current_gpu_name = current_gpu_name self.print_weights_summary = print_weights_summary self.checkpoint_callback = checkpoint_callback if self.checkpoint_callback is not None: self.checkpoint_callback.save_function = self.save_checkpoint self.early_stop = early_stop_callback self.model = None self.max_nb_epochs = max_nb_epochs self.accumulate_grad_batches = accumulate_grad_batches self.early_stop_callback = early_stop_callback self.min_nb_epochs = min_nb_epochs self.nb_sanity_val_steps = nb_sanity_val_steps self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')] self.lr_schedulers = [] self.amp_level = amp_level self.print_nan_grads = print_nan_grads self.data_parallel_device_ids = None self.world_size = 1 self.node_rank = 0 self.use_ddp = False self.use_dp = False # training bookeeping self.total_batch_nb = 0 self.running_loss = [] self.avg_loss = 0 self.batch_nb = 0 self.tqdm_metrics = {} self.nb_val_batches = None self.nb_tng_batches = None self.nb_test_batches = None # gpus come in as a string. # if gpus = -1 then use all available devices # otherwise, split the string using commas if gpus is not None: if type(gpus) is list: self.data_parallel_device_ids = gpus elif type(gpus) is str: if gpus == '-1': self.data_parallel_device_ids = list(range(0, torch.cuda.device_count())) else: self.data_parallel_device_ids = [int(x.strip()) for x in gpus.split(',')] else: raise Exception('gpus has to be a string or list of ids') # set the correct cuda visible devices (using pci order) os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in self.data_parallel_device_ids]) print(f'VISIBLE GPUS: {os.environ["CUDA_VISIBLE_DEVICES"]}') # make DP and DDP mutually exclusive # single GPU will also use DP with devices=[0] have_gpus = self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) > 0 if have_gpus: self.use_dp = distributed_backend == 'dp' self.use_ddp = distributed_backend == 'ddp' # use ddp automatically if nb_gpu_nodes > 1 if nb_gpu_nodes > 1 and self.use_dp: self.use_ddp = True self.use_dp = False w = 'DataParallel does not support nb_gpu_nodes > 1. ' \ 'Switching to DistributedDataParallel for you. ' \ 'To silence this warning set distributed_backend=ddp' warnings.warn(w) # process info self.proc_rank = 0 # training state self.optimizers = None self.prog_bar = None self.global_step = 0 self.current_epoch = 0 self.total_batches = 0 # logging self.log_save_interval = log_save_interval self.val_check_interval = val_check_interval self.add_log_row_interval = add_log_row_interval # dataloaders self.tng_dataloader = None self.test_dataloader = None self.val_dataloader = None # how much of the data to use self.__determine_data_use_amount(train_percent_check, val_percent_check, test_percent_check, overfit_pct) print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu)) # 16 bit mixed precision training using apex self.use_amp = use_amp and APEX_AVAILABLE if self.use_amp: print('using 16bit precision') if use_amp and not APEX_AVAILABLE: msg = ''' You set use_amp=True but do not have apex installed. Install apex first using this guide and rerun with use_amp=True: https://github.com/NVIDIA/apex#linux this run will NOT use 16 bit precision ''' raise ModuleNotFoundError(msg) @property def data_parallel(self): return self.use_dp or self.use_ddp def __determine_data_use_amount(self, train_percent_check, val_percent_check, test_percent_check, overfit_pct): """ Use less data for debugging purposes """ self.train_percent_check = train_percent_check self.val_percent_check = val_percent_check self.test_percent_check = test_percent_check if overfit_pct > 0: self.train_percent_check = overfit_pct self.val_percent_check = overfit_pct self.test_percent_check = overfit_pct def __get_model(self): return self.model.module if self.data_parallel else self.model def __is_function_implemented(self, f_name): model = self.__get_model() f_op = getattr(model, f_name, None) return callable(f_op) @property def __tng_tqdm_dic(self): # ForkedPdb().set_trace() tqdm_dic = { 'tng_loss': '{0:.3f}'.format(self.avg_loss), 'v_nb': '{}'.format(self.experiment.version), 'epoch': '{}'.format(self.current_epoch), 'batch_nb':'{}'.format(self.batch_nb), } tqdm_dic.update(self.tqdm_metrics) if self.on_gpu: tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name) return tqdm_dic @property def tng_tqdm_dic(self): """ Read-only for tqdm metrics :return: """ return self.__tng_tqdm_dic def __layout_bookeeping(self): # determine number of training batches self.nb_tng_batches = len(self.tng_dataloader) self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check) # determine number of validation batches self.nb_val_batches = len(self.val_dataloader) self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check) self.nb_val_batches = max(1, self.nb_val_batches) self.nb_val_batches = self.nb_val_batches # determine number of test batches self.nb_test_batches = len(self.test_dataloader) self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check) # determine when to check validation self.val_check_batch = int(self.nb_tng_batches * self.val_check_interval) def __add_tqdm_metrics(self, metrics): for k, v in metrics.items(): if type(v) is torch.Tensor: v = v.item() self.tqdm_metrics[k] = v def validate(self, model, dataloader, max_batches): """ Run validation code :param model: PT model :param dataloader: PT dataloader :param max_batches: Scalar :return: """ # enable eval mode model.zero_grad() model.eval() # disable gradients to save memory torch.set_grad_enabled(False) # bookkeeping outputs = [] # run training for batch_i, data_batch in enumerate(dataloader): if data_batch is None: continue # stop short when on fast dev run if max_batches is not None and batch_i >= max_batches: break # ----------------- # RUN VALIDATION STEP # ----------------- if self.use_ddp: output = model(data_batch, batch_i) elif self.use_dp: output = model(data_batch, batch_i) output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) else: output = model.validation_step(data_batch, batch_i) outputs.append(output) # batch done if self.progress_bar and self.prog_bar is not None: self.prog_bar.update(1) # give model a chance to do something with the outputs if self.data_parallel: val_results = model.module.validation_end(outputs) else: val_results = model.validation_end(outputs) # enable train mode again model.train() # enable gradients to save memory torch.set_grad_enabled(True) return val_results def __get_dataloaders(self, model): """ Dataloaders are provided by the model :param model: :return: """ self.tng_dataloader = model.tng_dataloader self.test_dataloader = model.test_dataloader self.val_dataloader = model.val_dataloader if self.use_ddp and not isinstance(self.tng_dataloader.sampler, DistributedSampler): msg = ''' when using multiple gpus and multiple nodes you must pass a DistributedSampler to DataLoader(sampler). ie: this: dataset = myDataset() dataloader = Dataloader(dataset) becomes: dataset = myDataset() dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) dataloader = Dataloader(dataset, sampler=dist_sampler) ''' raise Exception(msg) # ----------------------------- # MODEL TRAINING # ----------------------------- def fit(self, model): # when using multi-node or DDP within a node start each module in a separate process if self.use_ddp: # must copy only the meta of the exp so it survives pickle/unpickle when going to new process self.experiment = self.experiment.get_meta_copy() # whenever we have the correct number of tasks, we let slurm manage processes # otherwise we launch the required number of processes nb_requested_gpus = len(self.data_parallel_device_ids) * self.nb_gpu_nodes nb_slurm_tasks = 0 try: nb_slurm_tasks = int(os.environ['SLURM_NTASKS']) is_slurm_managing_tasks = nb_slurm_tasks == nb_requested_gpus except Exception as e: # likely not on slurm, so set the slurm managed flag to false is_slurm_managing_tasks = False if is_slurm_managing_tasks: task = int(os.environ['SLURM_LOCALID']) self.ddp_train(task, model) else: msg = f""" You requested {nb_requested_gpus} GPUs but launched {nb_slurm_tasks} slurm tasks. We will launch {nb_requested_gpus} processes for you. We recommend you let slurm manage the processes by setting: --ntasks-per-node={nb_requested_gpus} If you're not using SLURM, ignore this message! """ warnings.warn(msg) mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model, )) # 1 gpu or dp option triggers training using DP module # easier to avoid NCCL issues elif self.use_dp: self.__dp_train(model) # ON CPU else: # CHOOSE OPTIMIZER # filter out the weights that were done on gpu so we can load on good old cpus self.optimizers = model.configure_optimizers() # run through amp wrapper if self.use_amp: # An example model, optimizers = amp.initialize( model, self.optimizers, opt_level=self.amp_level, ) self.optimizers = optimizers self.__run_pretrain_routine(model) # return 1 when finished # used for testing or when we need to know that training succeeded return 1 def __dp_train(self, model): # CHOOSE OPTIMIZER # filter out the weights that were done on gpu so we can load on good old cpus self.optimizers = model.configure_optimizers() # run through amp wrapper if self.use_amp: # An example model, optimizers = amp.initialize( model, self.optimizers, opt_level=self.amp_level, ) self.optimizers = optimizers model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids) self.__run_pretrain_routine(model) def ddp_train(self, gpu_nb, model): """ Entry point into a DP thread :param gpu_nb: :param model: :param cluster_obj: :return: """ # node rank using relative slurm id # otherwise default to node rank 0 try: node_id = os.environ['SLURM_NODEID'] self.node_rank = int(node_id) except Exception as e: self.node_rank = 0 # recover original exp before went into process # init in write mode only on proc 0 self.experiment.debug = self.proc_rank > 0 self.experiment = self.experiment.get_non_ddp_exp() # show progbar only on prog_rank 0 self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0 # determine which process we are and world size self.proc_rank = self.node_rank * len(self.data_parallel_device_ids) + gpu_nb self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids) # set up server using proc 0's ip address # try to init for 20 times at max in case ports are taken # where to store ip_table self.__init_tcp_connection() # CHOOSE OPTIMIZER # filter out the weights that were done on gpu so we can load on good old cpus self.optimizers = model.configure_optimizers() # MODEL # copy model to each gpu torch.cuda.set_device(gpu_nb) model.cuda(gpu_nb) # AMP # run through amp wrapper before going to distributed DP if self.use_amp: # An example model, optimizers = amp.initialize( model, self.optimizers, opt_level=self.amp_level, ) self.optimizers = optimizers model = LightningDistributedDataParallel(model, device_ids=[gpu_nb], find_unused_parameters=True) # continue training routine self.__run_pretrain_routine(model) def __init_tcp_connection(self): """ Connect all procs in the world using the env:// init Use the first node as the root address :param port: :param tries: :return: """ try: port = os.environ['MASTER_PORT'] except Exception as e: port = 12910 os.environ['MASTER_PORT'] = f'{port}' root_node = self.__resolve_root_node_address() os.environ['MASTER_ADDR'] = root_node dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size) def __resolve_root_node_address(self): try: root_node = os.environ['SLURM_NODELIST'].split(' ')[0] if '[' in root_node: name = root_node.split('[')[0] number = root_node.split(',')[0] if '-' in number: number = number.split('-')[0] number = re.sub('[^0-9]', '', number) root_node = name + number except Exception as e: root_node = '127.0.0.2' return root_node def __run_pretrain_routine(self, model): """ Sanity check a few things before starting actual training :param model: :return: """ ref_model = model if self.data_parallel: ref_model = model.module ref_model.trainer = self # set local properties on the model ref_model.on_gpu = self.on_gpu # transfer data loaders from model self.__get_dataloaders(ref_model) # init training constants self.__layout_bookeeping() # add lr schedulers if self.lr_scheduler_milestones is not None: for optimizer in self.optimizers: scheduler = MultiStepLR(optimizer, self.lr_scheduler_milestones) self.lr_schedulers.append(scheduler) # print model summary if self.proc_rank == 0 and self.print_weights_summary: ref_model.summarize() # give model convenience properties ref_model.trainer = self ref_model.experiment = self.experiment # run tiny validation to make sure program won't crash during val _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps) # save exp to get started if self.proc_rank == 0: self.experiment.save() # enable cluster checkpointing if self.cluster is not None: self.enable_auto_hpc_walltime_manager() # --------------------------- # CORE TRAINING LOOP # --------------------------- self.model = model self.__train() def __train(self): # run all epochs for epoch_nb in range(self.current_epoch, self.max_nb_epochs): # update the lr scheduler for lr_scheduler in self.lr_schedulers: lr_scheduler.step() model = self.__get_model() model.current_epoch = epoch_nb # hook if self.__is_function_implemented('on_epoch_start'): model = self.__get_model() model.on_epoch_start() self.current_epoch = epoch_nb self.total_batches = self.nb_tng_batches + self.nb_val_batches self.batch_loss_value = 0 # accumulated grads # init progbar when requested if self.progress_bar: self.prog_bar = tqdm.tqdm(range(self.total_batches), position=self.process_position) for batch_nb, data_batch in enumerate(self.tng_dataloader): self.batch_nb = batch_nb self.global_step += 1 model = self.__get_model() model.global_step = self.global_step # stop when the flag is changed or we've gone past the amount requested in the batches self.total_batch_nb += 1 met_batch_limit = batch_nb > self.nb_tng_batches if met_batch_limit: break # --------------- # RUN TRAIN STEP # --------------- batch_result = self.__run_tng_batch(data_batch, batch_nb) early_stop_epoch = batch_result == -1 # --------------- # RUN VAL STEP # --------------- is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0 if self.fast_dev_run or is_val_check_batch or early_stop_epoch: self.__run_validation() # when batch should be saved if (batch_nb + 1) % self.log_save_interval == 0 or early_stop_epoch: if self.proc_rank == 0: self.experiment.save() # when metrics should be logged if batch_nb % self.add_log_row_interval == 0 or early_stop_epoch: # count items in memory # nb_params, nb_tensors = count_mem_items() model = self.__get_model() metrics = model.update_tng_log_metrics(self.__tng_tqdm_dic) # add gpu memory if self.on_gpu: mem_map = get_gpu_memory_map() metrics.update(mem_map) # add norms if self.track_grad_norm > 0: model = self.__get_model() grad_norm_dic = model.grad_norm(self.track_grad_norm) metrics.update(grad_norm_dic) if self.__is_function_implemented('on_tng_metrics'): model.on_tng_metrics(metrics) # log metrics scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist()) if self.proc_rank == 0: self.experiment.log(scalar_metrics, global_step=self.global_step) self.experiment.save() # hook if self.__is_function_implemented('on_batch_end'): model = self.__get_model() model.on_batch_end() # end epoch early if early_stop_epoch: break # hook if self.__is_function_implemented('on_epoch_end'): model = self.__get_model() model.on_epoch_end() # early stopping met_min_epochs = epoch_nb > self.min_nb_epochs if self.enable_early_stop and met_min_epochs: should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb, logs=self.__tng_tqdm_dic) # stop training stop = should_stop and met_min_epochs if stop: return def __metrics_to_scalars(self, metrics, blacklist=[]): new_metrics = {} for k, v in metrics.items(): if type(v) is torch.Tensor: v = v.item() if type(v) is dict: v = self.__metrics_to_scalars(v) if k not in blacklist: new_metrics[k] = float(v) return new_metrics def __log_vals_blacklist(self): """avoid logging some vals lightning uses to maintain state""" blacklist = {'batch_nb', 'v_nb', 'gpu'} return blacklist def __run_tng_batch(self, data_batch, batch_nb): if data_batch is None: return 0 # hook if self.__is_function_implemented('on_batch_start'): model_ref = self.__get_model() response = model_ref.on_batch_start(data_batch) if response == -1: return -1 if self.progress_bar: self.prog_bar.update(1) # forward pass # return a scalar value and a dic with tqdm metrics if self.use_ddp: output = self.model(data_batch, batch_nb) elif self.use_dp: output = self.model(data_batch, batch_nb) output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) else: output = self.model.training_step(data_batch, batch_nb) try: model_specific_tqdm_metrics_dic = output['tqdm_metrics'] except Exception as e: model_specific_tqdm_metrics_dic = {} # if output dict doesn't have the keyword loss # then assume the output=loss if scalar try: loss = output['loss'] except Exception as e: if type(output) is torch.Tensor: loss = output self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic) # backward pass if self.use_amp: # scale loss when using amp for optimizer in self.optimizers: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() # insert after step hook if self.__is_function_implemented('on_after_backward'): model_ref = self.__get_model() response = model_ref.on_after_backward() if self.print_nan_grads: model = self.__get_model() for param in model.parameters(): print(param.grad.float().sum()) # avoid memory leaks self.batch_loss_value += loss.item() # gradient update with accumulated gradients if (self.batch_nb + 1) % self.accumulate_grad_batches == 0: # clip gradients if self.gradient_clip > 0: model = self.__get_model() torch.nn.utils.clip_grad_norm(model.parameters(), self.gradient_clip) # update gradients across all optimizers for optimizer in self.optimizers: optimizer.step() # insert after step hook if self.__is_function_implemented('on_before_zero_grad'): model_ref = self.__get_model() response = model_ref.on_before_zero_grad(optimizer) # clear gradients optimizer.zero_grad() # queuing loss across batches blows it up proportionally... divide out the number accumulated self.batch_loss_value = self.batch_loss_value / self.accumulate_grad_batches # track loss self.running_loss.append(self.batch_loss_value) self.batch_loss_value = 0 self.avg_loss = np.mean(self.running_loss[-100:]) # update progbar if self.progress_bar: # add model specific metrics tqdm_metrics = self.__tng_tqdm_dic self.prog_bar.set_postfix(**tqdm_metrics) # activate batch end hook if self.__is_function_implemented('on_batch_end'): model = self.__get_model() model.on_batch_end() return 0 def __run_validation(self): # decide if can check epochs can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 if self.fast_dev_run: print('skipping to check performance bc of --fast_dev_run') elif not can_check_epoch: return try: # hook if self.__is_function_implemented('on_pre_performance_check'): model = self.__get_model() model.on_pre_performance_check() # use full val set on end of epoch # use a small portion otherwise max_batches = None if not self.fast_dev_run else 1 model_specific_tqdm_metrics_dic = self.validate( self.model, self.val_dataloader, max_batches ) self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic) # hook if self.__is_function_implemented('on_post_performance_check'): model = self.__get_model() model.on_post_performance_check() except Exception as e: print(e) print(traceback.print_exc()) if self.progress_bar: # add model specific metrics tqdm_metrics = self.__tng_tqdm_dic self.prog_bar.set_postfix(**tqdm_metrics) # model checkpointing if self.proc_rank == 0 and self.checkpoint_callback: print('save callback...') self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, logs=self.__tng_tqdm_dic)