added dp reduce out test
This commit is contained in:
parent
d6e7994922
commit
23e7521300
|
@ -26,7 +26,7 @@ def test_dp_output_reduce():
|
|||
|
||||
# test identity when we have a single gpu
|
||||
out = torch.rand(3, 1)
|
||||
assert reduce_distributed_output(out, nb_gpus=1) == out
|
||||
assert reduce_distributed_output(out, nb_gpus=1) is out
|
||||
|
||||
# average when we have multiples
|
||||
assert reduce_distributed_output(out, nb_gpus=2) == out.mean()
|
||||
|
|
Loading…
Reference in New Issue