lightning/tests/metrics/test_ddp.py

import os
import sys

import pytest
import torch

from tests.metrics.test_metric import Dummy
from tests.metrics.utils import setup_ddp

torch.manual_seed(42)

def _test_ddp_sum(rank, worldsize):
    setup_ddp(rank, worldsize)
    dummy = Dummy()
    dummy._reductions = {"foo": torch.sum}
    dummy.foo = torch.tensor(1)
    dummy._sync_dist()

    # each rank contributes 1, so the summed state equals the world size
    assert dummy.foo == worldsize


def _test_ddp_cat(rank, worldsize):
    setup_ddp(rank, worldsize)
    dummy = Dummy()
    dummy._reductions = {"foo": torch.cat}
    dummy.foo = [torch.tensor([1])]
    dummy._sync_dist()

    # list states are gathered across ranks and concatenated
    assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))


def _test_ddp_sum_cat(rank, worldsize):
    setup_ddp(rank, worldsize)
    dummy = Dummy()
    dummy._reductions = {"foo": torch.cat, "bar": torch.sum}
    dummy.foo = [torch.tensor([1])]
    dummy.bar = torch.tensor(1)
    dummy._sync_dist()

    # mixed reductions: "foo" is concatenated, "bar" is summed
    assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
    assert dummy.bar == worldsize


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on Windows")
@pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat])
def test_ddp(process):
    torch.multiprocessing.spawn(process, args=(2,), nprocs=2)
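

# The helpers imported above live elsewhere in the metrics test suite. For
# context, a minimal sketch of what they are assumed to look like (names and
# bodies are assumptions for illustration, not the repository's verbatim code):
#
#   # tests/metrics/utils.py
#   def setup_ddp(rank, world_size):
#       """Join a gloo process group so _sync_dist() can all-gather states."""
#       os.environ["MASTER_ADDR"] = "localhost"
#       os.environ["MASTER_PORT"] = "8088"
#       torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
#
#   # tests/metrics/test_metric.py
#   class Dummy(Metric):
#       """Metric subclass with a single tensor state and no-op update/compute."""
#       def __init__(self):
#           super().__init__()
#           self.add_state("x", torch.tensor(0), dist_reduce_fx=None)
#
#       def update(self):
#           pass
#
#       def compute(self):
#           pass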