diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 1f82dd1ed3..434e50609b 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -407,12 +407,11 @@ class Trainer(TrainerIO): # forward pass # return a scalar value and a dic with tqdm metrics - output = self.model(data_batch, batch_nb) - - # when DP, we need to aggregate the scalars we received as outputs - # use mean as the reduce function if self.data_parallel: + output = self.model(data_batch, batch_nb) output = reduce_distributed_output(output, len(self.data_parallel_device_ids)) + else: + output = self.model.training_step(data_batch, batch_nb) model_specific_tqdm_metrics_dic = output['tqdm_metrics'] loss = output['loss']