From 8cd764a15149376a1829b8bfe056e1129bab3c65 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 8 Aug 2019 12:06:29 -0400 Subject: [PATCH] removed reduce on non-loss outputs from dp (#78) * removed reduce on non-loss outputs from dp * fixed val reduce * fixed val reduce * fixed val reduce * fixed val reduce --- .../lightning_module_template.py | 17 +++++++++++--- pytorch_lightning/models/trainer.py | 23 ++++++++++++++----- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py index c11dd1b335..94e3407d96 100644 --- a/examples/new_project_templates/lightning_module_template.py +++ b/examples/new_project_templates/lightning_module_template.py @@ -151,12 +151,23 @@ class LightningTemplateModel(LightningModule): val_loss_mean = 0 val_acc_mean = 0 for output in outputs: - val_loss_mean += output['val_loss'] - val_acc_mean += output['val_acc'] + val_loss = output['val_loss'] + + # reduce manually when using dp + if self.trainer.use_dp: + val_loss = torch.mean(val_loss) + val_loss_mean += val_loss + + # reduce manually when using dp + val_acc = output['val_acc'] + if self.trainer.use_dp: + val_acc_mean = torch.mean(val_acc) + + val_acc_mean += val_acc_mean val_loss_mean /= len(outputs) val_acc_mean /= len(outputs) - tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()} + tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean} return tqdm_dic # --------------------- diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index e67a90cbd7..1c646a1221 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -399,13 +399,14 @@ class Trainer(TrainerIO): output = model(data_batch, batch_i) elif self.use_dp: output = model(data_batch, batch_i) - output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) - elif self.single_gpu: + # put inputs on gpu manually gpu_id = self.data_parallel_device_ids[0] for i, x in enumerate(data_batch): if isinstance(x, torch.Tensor): data_batch[i] = x.cuda(gpu_id) + + # do non dp, ddp step output = model.validation_step(data_batch, batch_i) else: @@ -862,7 +863,6 @@ We recommend you switch to ddp if you want to use amp output = self.model(data_batch, batch_nb) elif self.use_dp: output = self.model(data_batch, batch_nb) - output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) elif self.single_gpu: gpu_id = self.data_parallel_device_ids[0] for i, x in enumerate(data_batch): @@ -874,7 +874,14 @@ We recommend you switch to ddp if you want to use amp output = self.model.training_step(data_batch, batch_nb) try: - model_specific_tqdm_metrics_dic = output['prog'] + prog_output = output['prog'] + + # reduce prog metrics for tqdm when using dp + if self.use_dp: + nb_gpus = len(self.data_parallel_device_ids) + prog_output = reduce_distributed_output(prog_output, nb_gpus) + + model_specific_tqdm_metrics_dic = prog_output except Exception: model_specific_tqdm_metrics_dic = {} @@ -886,6 +893,10 @@ We recommend you switch to ddp if you want to use amp if type(output) is torch.Tensor: loss = output + # when using dp need to reduce the loss + if self.use_dp: + loss = reduce_distributed_output(loss, len(self.data_parallel_device_ids)) + self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic) # backward pass @@ -968,12 +979,12 @@ We recommend you switch to ddp if you want to use amp # use full val set on end of epoch # use a small portion otherwise max_batches = None if not self.fast_dev_run else 1 - model_specific_tqdm_metrics_dic = self.validate( + validation_results = self.validate( self.model, self.val_dataloader, max_batches ) - self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic) + self.__add_tqdm_metrics(validation_results) # hook if self.__is_function_implemented('on_post_performance_check'):