updated args

parent fea10fc792
commit d33048c67b
@@ -15,6 +15,18 @@ except ModuleNotFoundError:
     APEX_AVAILABLE = False
 
 
+def reduce_distributed_output(output, nb_gpus):
+    for k, v in output.items():
+        # recurse on nested dicts
+        if isinstance(output[k], dict):
+            output[k] = reduce_distributed_output(output[k], nb_gpus)
+
+        # reduce only metrics that have the same nb of gpus
+        elif output[k].size(0) == nb_gpus:
+            reduced = torch.mean(output[k])
+            output[k] = reduced
+    return output
+
 class Trainer(TrainerIO):
 
     def __init__(self,
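For reference, a minimal standalone sketch of what the new helper does, assuming toy values and a two-GPU setup (both are illustrative, not from this commit): under DP, each scalar the model returns comes back stacked along dim 0 with one entry per GPU, and the helper mean-reduces any entry whose first dimension matches the GPU count, recursing into nested dicts.

import torch

def reduce_distributed_output(output, nb_gpus):
    for k, v in output.items():
        # recurse on nested dicts
        if isinstance(output[k], dict):
            output[k] = reduce_distributed_output(output[k], nb_gpus)
        # reduce only metrics that have the same nb of gpus
        elif output[k].size(0) == nb_gpus:
            output[k] = torch.mean(output[k])
    return output

# two GPUs each reported one scalar; DP gathers them along dim 0
out = {'val_loss': torch.tensor([0.5, 0.7]),
       'log': {'val_acc': torch.tensor([0.8, 0.9])}}
reduced = reduce_distributed_output(out, nb_gpus=2)
print(reduced['val_loss'])        # tensor(0.6000)
print(reduced['log']['val_acc'])  # tensor(0.8500)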
@@ -188,13 +200,20 @@ class Trainer(TrainerIO):
             # RUN VALIDATION STEP
             # -----------------
             output = model(data_batch, batch_i)
+            pdb.set_trace()
+
+            # when DP, we need to aggregate the scalars we received as outputs
+            # use mean as the reduce function
+            if self.data_parallel:
+                output = reduce_distributed_output(output, len(self.gpus))
 
             outputs.append(output)
 
             # batch done
             if self.enable_tqdm and self.prog_bar is not None:
                 self.prog_bar.update(1)
 
+        pdb.set_trace()
 
         # give model a chance to do something with the outputs
         val_results = model.validation_end(outputs)
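And a sketch of the validation-loop change in isolation, with hypothetical stand-ins for the model and batches (ToyModule, the two-GPU shapes, and the three-batch loop are all assumptions for illustration; the pdb.set_trace() debug stops are omitted):

import torch

def reduce_distributed_output(output, nb_gpus):
    # same helper as in the first hunk above
    for k, v in output.items():
        if isinstance(output[k], dict):
            output[k] = reduce_distributed_output(output[k], nb_gpus)
        elif output[k].size(0) == nb_gpus:
            output[k] = torch.mean(output[k])
    return output

class ToyModule:
    # hypothetical stand-in for the user's model; not part of the commit
    def __call__(self, data_batch, batch_i):
        # under DP, each of the 2 GPUs contributes one scalar per batch
        return {'val_loss': torch.rand(2)}

    def validation_end(self, outputs):
        # average the already-reduced per-batch scalars
        return {'avg_val_loss': torch.stack([o['val_loss'] for o in outputs]).mean()}

model = ToyModule()
outputs = []
for batch_i, data_batch in enumerate([None] * 3):  # 3 fake batches
    output = model(data_batch, batch_i)
    # when DP, aggregate the per-GPU scalars; use mean as the reduce function
    output = reduce_distributed_output(output, nb_gpus=2)
    outputs.append(output)

# give the model a chance to do something with the outputs
val_results = model.validation_end(outputs)
print(val_results['avg_val_loss'])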