updated args

This commit is contained in:
William Falcon 2019-06-25 20:17:50 -04:00
parent d33048c67b
commit e3f96d6f3a
1 changed files with 1 additions and 4 deletions

View File

@ -79,9 +79,6 @@ class Trainer(TrainerIO):
self.data_parallel_device_ids = gpus
self.data_parallel = gpus is not None and len(gpus) > 0
# TODO: remove
self.on_gpu = True
# training state
self.optimizers = None
self.prog_bar = None
@ -204,7 +201,7 @@ class Trainer(TrainerIO):
# when DP, we need to aggregate the scalars we received as outputs
# use mean as the reduce function
if self.data_parallel:
output = reduce_distributed_output(output, len(self.gpus))
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
outputs.append(output)