added arg docs

This commit is contained in:
William Falcon 2019-07-18 12:08:17 -04:00
parent 2ca0864ce8
commit 3be26dbb95
1 changed files with 6 additions and 1 deletions

View File

@ -30,9 +30,14 @@ except ModuleNotFoundError:
def reduce_distributed_output(output, nb_gpus):
pdb.set_trace()
if nb_gpus <= 1 or type(output) is torch.Tensor:
if nb_gpus <= 1:
return output
# when using DP, we get one output per gpu
# average outputs and return
if type(output) is torch.Tensor:
return output.mean()
for k, v in output.items():
# recurse on nested dics
if isinstance(output[k], dict):