diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 5bb222e3a9..c44afcec4d 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -5,7 +5,7 @@ from pytorch_lightning.root_module.memory import get_gpu_memory_map import traceback from pytorch_lightning.root_module.model_saving import TrainerIO from torch.optim.lr_scheduler import MultiStepLR -from torch.nn import DataParallel +from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel import pdb try: @@ -14,37 +14,53 @@ try: except ModuleNotFoundError: APEX_AVAILABLE = False + +def reduce_distributed_output(output, nb_gpus): + for k, v in output.items(): + # recurse on nested dics + if isinstance(output[k], dict): + output[k] = reduce_distributed_output(output[k], nb_gpus) + + # reduce only metrics that have the same nb of gpus + elif output[k].size(0) == nb_gpus: + reduced = torch.mean(output[k]) + output[k] = reduced + return output + + class Trainer(TrainerIO): def __init__(self, experiment, checkpoint_callback, early_stop_callback, + gradient_clip=0, cluster=None, process_position=0, current_gpu_name=0, - on_gpu=False, - enable_tqdm=True, + gpus=None, + progress_bar=True, overfit_pct=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, - enable_early_stop=True, max_nb_epochs=5, min_nb_epochs=1, + enable_early_stop=True, max_nb_epochs=1000, min_nb_epochs=1, train_percent_check=1.0, val_percent_check=1.0, test_percent_check=1.0, val_check_interval=0.95, - log_save_interval=1, add_log_row_interval=1, + log_save_interval=100, add_log_row_interval=10, lr_scheduler_milestones=None, use_amp=False, - check_grad_nans=False, + print_nan_grads=False, amp_level='O2', nb_sanity_val_steps=5): # Transfer params + self.gradient_clip = gradient_clip self.check_val_every_n_epoch = check_val_every_n_epoch self.enable_early_stop = enable_early_stop self.track_grad_norm = track_grad_norm self.fast_dev_run = fast_dev_run - self.on_gpu = on_gpu - self.enable_tqdm = enable_tqdm + self.on_gpu = gpus is not None and torch.cuda.is_available() + self.progress_bar = progress_bar self.experiment = experiment self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version) self.cluster = cluster @@ -62,9 +78,9 @@ class Trainer(TrainerIO): self.lr_scheduler_milestones = [] if lr_scheduler_milestones is None else [int(x.strip()) for x in lr_scheduler_milestones.split(',')] self.lr_schedulers = [] self.amp_level = amp_level - self.check_grad_nans = check_grad_nans - self.data_parallel_device_ids = [0] - self.data_parallel = False + self.print_nan_grads = print_nan_grads + self.data_parallel_device_ids = gpus + self.data_parallel = gpus is not None and len(gpus) > 0 # training state self.optimizers = None @@ -112,15 +128,18 @@ class Trainer(TrainerIO): def __tng_tqdm_dic(self): tqdm_dic = { 'tng_loss': '{0:.3f}'.format(self.avg_loss), - 'gpu': '{}'.format(self.current_gpu_name), 'v_nb': '{}'.format(self.experiment.version), 'epoch': '{}'.format(self.current_epoch), 'batch_nb':'{}'.format(self.batch_nb), } tqdm_dic.update(self.tqdm_metrics) + + if self.on_gpu: + tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name) + return tqdm_dic - def __layout_bookeeping(self): + def __layout_bookeeping(self, model): # training bookeeping self.total_batch_nb = 0 self.running_loss = [] @@ -129,17 +148,17 @@ class Trainer(TrainerIO): self.tqdm_metrics = {} # determine number of training batches - self.nb_tng_batches = self.model.nb_batches(self.tng_dataloader) + self.nb_tng_batches = model.nb_batches(self.tng_dataloader) self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check) # determine number of validation batches - self.nb_val_batches = self.model.nb_batches(self.val_dataloader) + self.nb_val_batches = model.nb_batches(self.val_dataloader) self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check) self.nb_val_batches = max(1, self.nb_val_batches) self.nb_val_batches = self.nb_val_batches # determine number of test batches - self.nb_test_batches = self.model.nb_batches(self.test_dataloader) + self.nb_test_batches = model.nb_batches(self.test_dataloader) self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check) # determine when to check validation @@ -147,6 +166,9 @@ class Trainer(TrainerIO): def __add_tqdm_metrics(self, metrics): for k, v in metrics.items(): + if type(v) is torch.Tensor: + v = v.item() + self.tqdm_metrics[k] = v def validate(self, model, dataloader, max_batches): @@ -162,6 +184,7 @@ class Trainer(TrainerIO): # enable eval mode model.zero_grad() model.eval() + model.from_lightning = True # disable gradients to save memory torch.set_grad_enabled(False) @@ -182,21 +205,30 @@ class Trainer(TrainerIO): # ----------------- # RUN VALIDATION STEP # ----------------- - output = model.validation_step(data_batch, batch_i) + if self.data_parallel: + output = model(data_batch, batch_i) + output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) + else: + output = model.validation_step(data_batch, batch_i) + outputs.append(output) # batch done - if self.enable_tqdm and self.prog_bar is not None: + if self.progress_bar and self.prog_bar is not None: self.prog_bar.update(1) # give model a chance to do something with the outputs - val_results = model.validation_end(outputs) + if self.data_parallel: + val_results = model.module.validation_end(outputs) + else: + val_results = model.validation_end(outputs) # enable train mode again model.train() # enable gradients to save memory torch.set_grad_enabled(True) + return val_results def __get_dataloaders(self, model): @@ -213,14 +245,16 @@ class Trainer(TrainerIO): # MODEL TRAINING # ----------------------------- def fit(self, model): - self.model = model + + # give model convenience properties model.trainer = self + model.experiment = self.experiment # transfer data loaders from model self.__get_dataloaders(model) # init training constants - self.__layout_bookeeping() + self.__layout_bookeeping(model) # CHOOSE OPTIMIZER # filter out the weights that were done on gpu so we can load on good old cpus @@ -228,8 +262,8 @@ class Trainer(TrainerIO): if self.use_amp: # An example - self.model, optimizer = amp.initialize( - self.model, self.optimizers[0], opt_level=self.amp_level, + model, optimizer = amp.initialize( + model, self.optimizers[0], opt_level=self.amp_level, ) self.optimizers[0] = optimizer model.trainer = self @@ -245,9 +279,7 @@ class Trainer(TrainerIO): # put on gpu if needed if self.on_gpu: - if self.data_parallel: - model = DataParallel(model, device_ids=self.data_parallel_device_ids) - + model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids) model.cuda(self.data_parallel_device_ids[0]) # run tiny validation to make sure program won't crash during val @@ -263,6 +295,7 @@ class Trainer(TrainerIO): # --------------------------- # CORE TRAINING LOOP # --------------------------- + self.model = model self.__train() def __train(self): @@ -272,24 +305,28 @@ class Trainer(TrainerIO): for lr_scheduler in self.lr_schedulers: lr_scheduler.step() - self.model.current_epoch = epoch_nb + model = self.model.module if self.data_parallel else self.model + model.current_epoch = epoch_nb # hook if self.__is_function_implemented('on_epoch_start'): - self.model.on_epoch_start() + model = self.model.module if self.data_parallel else self.model + model.on_epoch_start() self.current_epoch = epoch_nb self.total_batches = self.nb_tng_batches + self.nb_val_batches self.batch_loss_value = 0 # accumulated grads # init progbar when requested - if self.enable_tqdm: + if self.progress_bar: self.prog_bar = tqdm.tqdm(range(self.total_batches), position=self.process_position) for batch_nb, data_batch in enumerate(self.tng_dataloader): self.batch_nb = batch_nb self.global_step += 1 - self.model.global_step = self.global_step + + model = self.model.module if self.data_parallel else self.model + model.global_step = self.global_step # stop when the flag is changed or we've gone past the amount requested in the batches self.total_batch_nb += 1 @@ -319,7 +356,10 @@ class Trainer(TrainerIO): # count items in memory # nb_params, nb_tensors = count_mem_items() - metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic) + if self.data_parallel: + metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic) + else: + metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic) # add gpu memory if self.on_gpu: @@ -328,16 +368,20 @@ class Trainer(TrainerIO): # add norms if self.track_grad_norm > 0: - grad_norm_dic = self.model.grad_norm(self.track_grad_norm) + model = self.model.module if self.data_parallel else self.model + grad_norm_dic = model.grad_norm(self.track_grad_norm) + metrics.update(grad_norm_dic) # log metrics - self.experiment.log(metrics) + scalar_metrics = self.__metrics_to_scalars(metrics, blacklist=self.__log_vals_blacklist()) + self.experiment.log(scalar_metrics, global_step=self.global_step) self.experiment.save() # hook if self.__is_function_implemented('on_batch_end'): - self.model.on_batch_end() + model = self.model.module if self.data_parallel else self.model + model.on_batch_end() # end epoch early if early_stop_epoch: @@ -345,7 +389,8 @@ class Trainer(TrainerIO): # hook if self.__is_function_implemented('on_epoch_end'): - self.model.on_epoch_end() + model = self.model.module if self.data_parallel else self.model + model.on_epoch_end() # early stopping if self.enable_early_stop: @@ -357,6 +402,24 @@ class Trainer(TrainerIO): if stop: return + def __metrics_to_scalars(self, metrics, blacklist=[]): + new_metrics = {} + for k, v in metrics.items(): + if type(v) is torch.Tensor: + v = v.item() + + if type(v) is dict: + v = self.__metrics_to_scalars(v) + + if k not in blacklist: + new_metrics[k] = float(v) + + return new_metrics + + def __log_vals_blacklist(self): + """avoid logging some vals lightning uses to maintain state""" + blacklist = {'batch_nb', 'v_nb', 'epoch', 'gpu'} + return blacklist def __run_tng_batch(self, data_batch, batch_nb): if data_batch is None: @@ -364,16 +427,26 @@ class Trainer(TrainerIO): # hook if self.__is_function_implemented('on_batch_start'): - response = self.model.on_batch_start(data_batch) + model = self.model.module if self.data_parallel else self.model + response = model.on_batch_start(data_batch) + if response == -1: return -1 - if self.enable_tqdm: + if self.progress_bar: self.prog_bar.update(1) # forward pass # return a scalar value and a dic with tqdm metrics - loss, model_specific_tqdm_metrics_dic = self.model.training_step(data_batch, batch_nb) + if self.data_parallel: + output = self.model(data_batch, batch_nb) + output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) + else: + output = self.model.training_step(data_batch, batch_nb) + + model_specific_tqdm_metrics_dic = output['tqdm_metrics'] + loss = output['loss'] + self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic) # backward pass @@ -384,8 +457,9 @@ class Trainer(TrainerIO): else: loss.backward() - if self.check_grad_nans: - for param in self.model.parameters(): + if self.print_nan_grads: + model = self.model.module if self.data_parallel else self.model + for param in model.parameters(): print(param.grad.float().sum()) self.batch_loss_value += loss.item() @@ -393,6 +467,11 @@ class Trainer(TrainerIO): # gradient update with accumulated gradients if (self.batch_nb + 1) % self.accumulate_grad_batches == 0: + # clip gradients + if self.gradient_clip > 0: + model = self.model.module if self.data_parallel else self.model + torch.nn.utils.clip_grad_norm(model.parameters(), self.gradient_clip) + # update gradients across all optimizers for optimizer in self.optimizers: optimizer.step() @@ -409,7 +488,7 @@ class Trainer(TrainerIO): self.avg_loss = np.mean(self.running_loss[-100:]) # update progbar - if self.enable_tqdm: + if self.progress_bar: # add model specific metrics tqdm_metrics = self.__tng_tqdm_dic self.prog_bar.set_postfix(**tqdm_metrics) @@ -451,7 +530,7 @@ class Trainer(TrainerIO): print(e) print(traceback.print_exc()) - if self.enable_tqdm: + if self.progress_bar: # add model specific metrics tqdm_metrics = self.__tng_tqdm_dic self.prog_bar.set_postfix(**tqdm_metrics)