diff --git a/tests/test_models.py b/tests/test_models.py index 590caa8f40..552739eb52 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -26,7 +26,7 @@ def test_dp_output_reduce(): # test identity when we have a single gpu out = torch.rand(3, 1) - assert reduce_distributed_output(out, nb_gpus=1) == out + assert reduce_distributed_output(out, nb_gpus=1) is out # average when we have multiples assert reduce_distributed_output(out, nb_gpus=2) == out.mean()