added arg docs
This commit is contained in:
parent
2ca0864ce8
commit
3be26dbb95
|
@ -30,9 +30,14 @@ except ModuleNotFoundError:
|
|||
|
||||
def reduce_distributed_output(output, nb_gpus):
|
||||
pdb.set_trace()
|
||||
if nb_gpus <= 1 or type(output) is torch.Tensor:
|
||||
if nb_gpus <= 1:
|
||||
return output
|
||||
|
||||
# when using DP, we get one output per gpu
|
||||
# average outputs and return
|
||||
if type(output) is torch.Tensor:
|
||||
return output.mean()
|
||||
|
||||
for k, v in output.items():
|
||||
# recurse on nested dics
|
||||
if isinstance(output[k], dict):
|
||||
|
|
Loading…
Reference in New Issue