From 5eaaf828371113a4db2a6c4da27697a5ebf75698 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 25 Jun 2019 20:27:17 -0400 Subject: [PATCH] updated args --- pytorch_lightning/models/trainer.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 509919b04b..74875e3610 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -198,13 +198,10 @@ class Trainer(TrainerIO): # ----------------- if self.data_parallel: output = model(data_batch, batch_i) + output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) else: output = model.validation_step(data_batch, batch_i) - - # when DP, we need to aggregate the scalars we received as outputs - # use mean as the reduce function - if self.data_parallel: - output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) + output = reduce_distributed_output(output, 1) outputs.append(output)