updated args
This commit is contained in:
parent
d33048c67b
commit
e3f96d6f3a
|
@ -79,9 +79,6 @@ class Trainer(TrainerIO):
|
|||
self.data_parallel_device_ids = gpus
|
||||
self.data_parallel = gpus is not None and len(gpus) > 0
|
||||
|
||||
# TODO: remove
|
||||
self.on_gpu = True
|
||||
|
||||
# training state
|
||||
self.optimizers = None
|
||||
self.prog_bar = None
|
||||
|
@ -204,7 +201,7 @@ class Trainer(TrainerIO):
|
|||
# when DP, we need to aggregate the scalars we received as outputs
|
||||
# use mean as the reduce function
|
||||
if self.data_parallel:
|
||||
output = reduce_distributed_output(output, len(self.gpus))
|
||||
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
|
||||
|
||||
outputs.append(output)
|
||||
|
||||
|
|
Loading…
Reference in New Issue