updated args

William Falcon 2019-06-25 20:16:59 -04:00
parent fea10fc792
commit d33048c67b
1 changed file with 20 additions and 1 deletion


@@ -15,6 +15,18 @@ except ModuleNotFoundError:
    APEX_AVAILABLE = False


def reduce_distributed_output(output, nb_gpus):
    for k, v in output.items():
        # recurse on nested dicts
        if isinstance(output[k], dict):
            output[k] = reduce_distributed_output(output[k], nb_gpus)

        # reduce only metrics that have the same nb of gpus
        elif output[k].size(0) == nb_gpus:
            reduced = torch.mean(output[k])
            output[k] = reduced

    return output


class Trainer(TrainerIO):
    def __init__(self,
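For context, here is a minimal usage sketch of the new reduce_distributed_output helper (the metric names and tensor values are illustrative, not part of the commit). Under DP, each scalar metric comes back as a tensor with one entry per GPU; the helper mean-reduces those entries and recurses into nested dicts:

import torch

# hypothetical output gathered from 2 GPUs under DataParallel
dp_output = {
    'val_loss': torch.tensor([0.25, 0.35]),          # one entry per GPU
    'log': {'val_acc': torch.tensor([0.90, 0.80])},  # nested dict gets recursed into
}

reduced = reduce_distributed_output(dp_output, nb_gpus=2)
print(reduced['val_loss'])        # tensor(0.3000) -- mean across the 2 GPUs
print(reduced['log']['val_acc'])  # tensor(0.8500)

Metrics whose first dimension does not equal the number of GPUs are left untouched, so per-sample outputs are not accidentally averaged away.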
@@ -188,13 +200,20 @@ class Trainer(TrainerIO):
            # RUN VALIDATION STEP
            # -----------------
            output = model(data_batch, batch_i)
            pdb.set_trace()

            # when using DP, we need to aggregate the scalars we received as outputs
            # use mean as the reduce function
            if self.data_parallel:
                output = reduce_distributed_output(output, len(self.gpus))

            outputs.append(output)

            # batch done
            if self.enable_tqdm and self.prog_bar is not None:
                self.prog_bar.update(1)
        pdb.set_trace()

        # give the model a chance to do something with the outputs
        val_results = model.validation_end(outputs)
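The reduced per-batch dicts accumulated in outputs are what validation_end receives. A hedged sketch of the kind of aggregation that hook typically performs (summarize_validation and val_loss are assumed names for illustration; the actual hook lives on the user's model):

import torch

def summarize_validation(outputs):
    # `outputs` is the list built above: one (already DP-reduced) dict of
    # scalar tensors per validation batch
    avg_loss = torch.stack([o['val_loss'] for o in outputs]).mean()
    return {'avg_val_loss': avg_loss}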