updated args

This commit is contained in:
William Falcon 2019-06-25 20:27:17 -04:00
parent 7527167f69
commit 5eaaf82837
1 changed files with 2 additions and 5 deletions

View File

@ -198,13 +198,10 @@ class Trainer(TrainerIO):
# -----------------
if self.data_parallel:
output = model(data_batch, batch_i)
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
else:
output = model.validation_step(data_batch, batch_i)
# when DP, we need to aggregate the scalars we received as outputs
# use mean as the reduce function
if self.data_parallel:
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
output = reduce_distributed_output(output, 1)
outputs.append(output)