From 3be26dbb95ae80d09b52956744739a89e39ebd27 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 18 Jul 2019 12:08:17 -0400 Subject: [PATCH] added arg docs --- pytorch_lightning/models/trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index cc44fa392d..a6801158bf 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -30,9 +30,14 @@ except ModuleNotFoundError: def reduce_distributed_output(output, nb_gpus): pdb.set_trace() - if nb_gpus <= 1 or type(output) is torch.Tensor: + if nb_gpus <= 1: return output + # when using DP, we get one output per gpu + # average outputs and return + if type(output) is torch.Tensor: + return output.mean() + for k, v in output.items(): # recurse on nested dics if isinstance(output[k], dict):